-
Notifications
You must be signed in to change notification settings - Fork 18
VectorDisassembler.transformSchema fails to get group attributes #4
Copy link
Copy link
Open
Description
Hi,
I'm trying to use VectorDisassembler to disassemble a one-hot encoded vector. The code below shows that the VectorDisassembler stage fails to generate an accurate schema:
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.feature.{OneHotEncoder, RFormula, StringIndexer, VectorDisassembler}
import org.apache.spark.sql.SparkSession
/** Minimal reproduction for the VectorDisassembler schema mismatch.
  *
  * Builds a Pipeline of StringIndexer -> OneHotEncoder -> VectorDisassembler -> RFormula, where the
  * RFormula refers to the columns `versicolor` and `virginica` that VectorDisassembler is expected to
  * produce from the one-hot vector's attribute names. Calling `transform` on the fitted model fails
  * with `IllegalArgumentException: Field "versicolor" does not exist.`
  */
object Test {

  /** Runs the reproduction against a local `iris.csv`.
    *
    * @param args unused
    */
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("test").getOrCreate()
    try {
      // https://raw.githubusercontent.com/uiuc-cse/data-fa14/gh-pages/data/iris.csv
      val df = spark.read.option("header", "true").option("inferSchema", "true").csv("iris.csv")
      val indexer = new StringIndexer().setInputCol("species").setOutputCol("species_idx")
      val encoder = new OneHotEncoder().setInputCol(indexer.getOutputCol).setOutputCol("species_enc")
      // No output columns are set; the disassembler derives column names from the vector itself.
      val disassembler = new VectorDisassembler().setInputCol(encoder.getOutputCol)
      val formula = new RFormula().setFormula("petal_width ~ petal_length + versicolor + virginica")
      val pipeline = new Pipeline().setStages(Array(indexer, encoder, disassembler, formula))
      val model = pipeline.fit(df)
      // java.lang.IllegalArgumentException: Field "versicolor" does not exist.
      // The returned DataFrame is discarded; per the report, this call itself throws.
      model.transform(df)
    } finally {
      // Release the SparkSession (and its SparkContext) even when the reproduction throws.
      spark.stop()
    }
  }
}
I think `transformSchema` is what causes the error; the REPL session below shows the difference:
scala> new Pipeline().setStages(Array(indexer, encoder, disassembler)).fit(df).transform(df).schema
res1: org.apache.spark.sql.types.StructType = StructType(StructField(sepal_length,DoubleType,true), StructField(sepal_width,DoubleType,true), StructField(petal_length,DoubleType,true), StructField(petal_width,DoubleType,true), StructField(species,StringType,true), StructField(species_idx,DoubleType,true), StructField(species_enc,org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7,true), StructField(versicolor,DoubleType,true), StructField(virginica,DoubleType,true))
scala> new Pipeline().setStages(Array(indexer, encoder, disassembler)).fit(df).transformSchema(df.schema)
res2: org.apache.spark.sql.types.StructType = StructType(StructField(sepal_length,DoubleType,true), StructField(sepal_width,DoubleType,true), StructField(petal_length,DoubleType,true), StructField(petal_width,DoubleType,true), StructField(species,StringType,true), StructField(species_idx,DoubleType,false), StructField(species_enc,org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7,false), StructField(species_enc_0,DoubleType,false))
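As shown, `transform(df).schema` contains the columns `versicolor` and `virginica`, while `transformSchema(df.schema)` contains only `species_enc_0`. Because the downstream RFormula stage is checked against the schema predicted by `transformSchema`, it cannot find `versicolor`, which produces the error above.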
Reactions are currently unavailable
Metadata
Metadata
Assignees
Labels
No labels