Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,13 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
// Expected error text for division by zero; presumably matches Spark's ANSI-mode
// divide-by-zero message — verify against the Spark version under test if assertions drift.
val DIVIDE_BY_ZERO_EXCEPTION_MSG =
"""Division by zero. Use `try_divide` to tolerate divisor being 0 and return NULL instead"""

// Temporary, deliberately-ignored test: enable it manually to eyeball that a
// checkCometAnswer mismatch report attributes each side to Comet/Spark correctly.
ignore("check output labels on mismatch") {
  val inputRows = Seq((1, "apple"), (2, "banana"), (3, "cherry"))
  val cometDf = inputRows.toDF("id", "fruit")
  // Row 2 is intentionally wrong ("BANANA") so the failure output can be inspected.
  val sparkAnswer = Seq(Row(1, "apple"), Row(2, "BANANA"), Row(3, "cherry"))
  checkCometAnswer(cometDf, sparkAnswer)
}

test("sort floating point with negative zero") {
val schema = StructType(
Seq(
Expand Down
46 changes: 45 additions & 1 deletion spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ import org.apache.parquet.hadoop.example.{ExampleParquetWriter, GroupWriteSuppor
import org.apache.parquet.schema.{MessageType, MessageTypeParser}
import org.apache.spark._
import org.apache.spark.internal.config.{MEMORY_OFFHEAP_ENABLED, MEMORY_OFFHEAP_SIZE, SHUFFLE_MANAGER}
import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.util.sideBySide
import org.apache.spark.sql.comet.CometPlanChecker
import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometNativeShuffle, CometShuffleExchangeExec}
import org.apache.spark.sql.execution._
Expand Down Expand Up @@ -128,7 +130,7 @@ abstract class CometTestBase
if (withTol.isDefined) {
checkAnswerWithTolerance(dfComet, expected, withTol.get)
} else {
checkAnswer(dfComet, expected)
checkCometAnswer(dfComet, expected)
}

if (assertCometNative) {
Expand Down Expand Up @@ -358,6 +360,48 @@ abstract class CometTestBase
}
}

/**
 * Compares the result of executing `cometDf` under Comet against the rows Spark is expected to
 * produce, labelling each side accurately in any failure report. Spark's built-in `checkAnswer`
 * would tag the Comet result as the "Spark Answer", which is misleading when triaging
 * Comet-vs-Spark mismatches.
 */
protected def checkCometAnswer(cometDf: DataFrame, sparkAnswer: Seq[Row]): Unit = {
  // A Sort anywhere in the logical plan means row order is significant for the comparison.
  val hasSortOperator = cometDf.logicalPlan.collect { case s: logical.Sort => s }.nonEmpty

  // Execute the query under Comet; surface execution failures with the plan and stack trace.
  val actualAnswer: Seq[Row] =
    try cometDf.collect().toSeq
    catch {
      case e: Exception =>
        fail(s"""Exception thrown while executing query in Comet:
|${cometDf.queryExecution}
|== Exception ==
|$e
|${org.apache.spark.sql.catalyst.util.stackTraceToString(e)}
""".stripMargin)
    }

  // Normalize both sides once (sorted unless row order matters) and reuse the
  // prepared rows for both the comparison and, on mismatch, the report.
  val preparedExpected = QueryTest.prepareAnswer(sparkAnswer, hasSortOperator)
  val preparedActual = QueryTest.prepareAnswer(actualAnswer, hasSortOperator)

  if (!QueryTest.compare(preparedExpected, preparedActual)) {
    // Schema of a side's first row, or "struct<>" when the side is empty/schemaless.
    def schemaDescription(row: Option[Row]): String = row match {
      case Some(r) if r.schema != null => r.schema.catalogString
      case _ => "struct<>"
    }
    fail(s"""Results do not match for query:
|Timezone: ${java.util.TimeZone.getDefault}
|Timezone Env: ${sys.env.getOrElse("TZ", "")}
|
|${cometDf.queryExecution}
|== Results ==
|${sideBySide(
          s"== Spark Answer - ${sparkAnswer.size} ==" +:
            schemaDescription(sparkAnswer.headOption) +:
            preparedExpected.map(_.toString()),
          s"== Comet Answer - ${actualAnswer.size} ==" +:
            schemaDescription(actualAnswer.headOption) +:
            preparedActual.map(_.toString())).mkString("\n")}
""".stripMargin)
  }
}

/**
* A helper function for comparing Comet DataFrame with Spark result using absolute tolerance.
*/
Expand Down
Loading