Fail to get group attributes in VectorDisassembler.transformSchema #4

@shiyuan

Hi,

I'm trying to use VectorDisassembler to disassemble a one-hot encoded vector. The code below shows that the VectorDisassembler stage fails to generate an accurate schema:

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.feature.{OneHotEncoder, RFormula, StringIndexer, VectorDisassembler}
import org.apache.spark.sql.SparkSession

object Test {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("test").getOrCreate()

    // https://raw.githubusercontent.com/uiuc-cse/data-fa14/gh-pages/data/iris.csv
    val df = spark.read.option("header", "true").option("inferSchema", "true").csv("iris.csv")
    val indexer = new StringIndexer().setInputCol("species").setOutputCol("species_idx")
    val encoder = new OneHotEncoder().setInputCol(indexer.getOutputCol).setOutputCol("species_enc")
    val disassembler = new VectorDisassembler().setInputCol(encoder.getOutputCol)
    val formula = new RFormula().setFormula("petal_width ~ petal_length + versicolor + virginica")
    val pipeline = new Pipeline().setStages(Array(indexer, encoder, disassembler, formula))
    val model = pipeline.fit(df)

    // java.lang.IllegalArgumentException: Field "versicolor" does not exist.
    model.transform(df)
  }
}

I think transformSchema causes the error. The snippets below show the difference:

scala> new Pipeline().setStages(Array(indexer, encoder, disassembler)).fit(df).transform(df).schema
res1: org.apache.spark.sql.types.StructType = StructType(StructField(sepal_length,DoubleType,true), StructField(sepal_width,DoubleType,true), StructField(petal_length,DoubleType,true), StructField(petal_width,DoubleType,true), StructField(species,StringType,true), StructField(species_idx,DoubleType,true), StructField(species_enc,org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7,true), StructField(versicolor,DoubleType,true), StructField(virginica,DoubleType,true))

scala> new Pipeline().setStages(Array(indexer, encoder, disassembler)).fit(df).transformSchema(df.schema)
res2: org.apache.spark.sql.types.StructType = StructType(StructField(sepal_length,DoubleType,true), StructField(sepal_width,DoubleType,true), StructField(petal_length,DoubleType,true), StructField(petal_width,DoubleType,true), StructField(species,StringType,true), StructField(species_idx,DoubleType,false), StructField(species_enc,org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7,false), StructField(species_enc_0,DoubleType,false))

Clearly, transform().schema contains versicolor and virginica, while transformSchema returns only species_enc_0.
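
To make the mismatch easier to see than in the two long StructType dumps, here is a minimal sketch that diffs only the field names. `SchemaDiff` and `fieldNameDiff` are hypothetical helpers written for illustration, not part of Spark or of this library, and the name lists are copied by hand from res1 and res2 above. It is plain Scala, meant to be pasted into spark-shell or the Scala REPL.

```scala
// Hypothetical helper for illustration; not part of Spark or of this library.
object SchemaDiff {
  /** Returns (names only in `predicted`, names only in `actual`), keeping input order. */
  def fieldNameDiff(predicted: Seq[String], actual: Seq[String]): (Seq[String], Seq[String]) = {
    val predictedSet = predicted.toSet
    val actualSet = actual.toSet
    (predicted.filterNot(actualSet), actual.filterNot(predictedSet))
  }
}

// Field names shared by res1 (transform) and res2 (transformSchema).
val common = Seq(
  "sepal_length", "sepal_width", "petal_length", "petal_width",
  "species", "species_idx", "species_enc")

// What transformSchema predicted (res2) vs. what transform produced (res1).
val predicted = common :+ "species_enc_0"
val actual = common ++ Seq("versicolor", "virginica")

val (onlyPredicted, onlyActual) = SchemaDiff.fieldNameDiff(predicted, actual)
println(s"only in transformSchema: $onlyPredicted")
println(s"only in transform:       $onlyActual")
```

With a fitted pipeline model, the same comparison could be run on `model.transformSchema(df.schema).fieldNames.toSeq` and `model.transform(df).schema.fieldNames.toSeq`. This compares names only; it ignores the nullable differences that are also visible between res1 and res2.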
