5 changes: 0 additions & 5 deletions docs/source/user-guide/latest/compatibility.md
@@ -76,11 +76,6 @@ the [Comet Supported Expressions Guide](expressions.md) for more information on
timezone is UTC.
[#2649](https://github.com/apache/datafusion-comet/issues/2649)

### Aggregate Expressions

- **Corr**: Returns null instead of NaN in some edge cases.
[#2646](https://github.com/apache/datafusion-comet/issues/2646)

### Struct Expressions

- **StructsToJson (to_json)**: Does not support `+Infinity` and `-Infinity` for numeric types (float, double).
42 changes: 21 additions & 21 deletions docs/source/user-guide/latest/expressions.md
@@ -195,27 +195,27 @@ Expressions that are not Spark-compatible will fall back to Spark by default and

## Aggregate Expressions

| Expression | SQL | Spark-Compatible? | Compatibility Notes |
| ------------- | ---------- | ------------------------- | ---------------------------------------------------------------------------------------------------------------- |
| Average | | Yes, except for ANSI mode | |
| BitAndAgg | | Yes | |
| BitOrAgg | | Yes | |
| BitXorAgg | | Yes | |
| BoolAnd | `bool_and` | Yes | |
| BoolOr | `bool_or` | Yes | |
| Corr | | No | Returns null instead of NaN in some edge cases ([#2646](https://github.com/apache/datafusion-comet/issues/2646)) |
| Count | | Yes | |
| CovPopulation | | Yes | |
| CovSample | | Yes | |
| First | | No | This function is not deterministic. Results may not match Spark. |
| Last | | No | This function is not deterministic. Results may not match Spark. |
| Max | | Yes | |
| Min | | Yes | |
| StddevPop | | Yes | |
| StddevSamp | | Yes | |
| Sum | | Yes, except for ANSI mode | |
| VariancePop | | Yes | |
| VarianceSamp | | Yes | |
| Expression | SQL | Spark-Compatible? | Compatibility Notes |
| ------------- | ---------- | ------------------------- | ---------------------------------------------------------------- |
| Average | | Yes, except for ANSI mode | |
| BitAndAgg | | Yes | |
| BitOrAgg | | Yes | |
| BitXorAgg | | Yes | |
| BoolAnd | `bool_and` | Yes | |
| BoolOr | `bool_or` | Yes | |
| Corr | | Yes | |
| Count | | Yes | |
| CovPopulation | | Yes | |
| CovSample | | Yes | |
| First | | No | This function is not deterministic. Results may not match Spark. |
| Last | | No | This function is not deterministic. Results may not match Spark. |
| Max | | Yes | |
| Min | | Yes | |
| StddevPop | | Yes | |
| StddevSamp | | Yes | |
| Sum | | Yes, except for ANSI mode | |
| VariancePop | | Yes | |
| VarianceSamp | | Yes | |

## Window Functions

16 changes: 7 additions & 9 deletions spark/src/main/scala/org/apache/comet/serde/aggregates.scala
@@ -21,7 +21,7 @@ package org.apache.comet.serde

import scala.jdk.CollectionConverters._

import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.expressions.{And, Attribute, If, IsNaN, Literal}
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, BitAndAgg, BitOrAgg, BitXorAgg, BloomFilterAggregate, CentralMomentAgg, Corr, Count, Covariance, CovPopulation, CovSample, First, Last, Max, Min, StddevPop, StddevSamp, Sum, VariancePop, VarianceSamp}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{ByteType, DataTypes, DecimalType, IntegerType, LongType, ShortType, StringType}
@@ -584,20 +584,18 @@ object CometStddevPop extends CometAggregateExpressionSerde[StddevPop] with Come
}

object CometCorr extends CometAggregateExpressionSerde[Corr] {

override def getSupportLevel(expr: Corr): SupportLevel =
Contributor

Claude flagged some edge cases we can document -

1. Legacy mode: when `spark.sql.legacy.statisticalAggregate=true`, `nullOnDivideByZero` is false and Spark returns NaN for the n=1 case. With this workaround, Comet would return null instead (because the NaN row gets skipped → n=0). Should we add a `getSupportLevel` guard that returns `Incompatible` when `corr.nullOnDivideByZero` is false? Or at least document this?
2. Mixed groups: for a group containing (NaN, NaN) alongside valid pairs like (1.0, 2.0), Spark returns NaN (NaN contaminates the accumulator), while this workaround would skip the NaN row and compute a valid correlation over the remaining rows. Is that a known limitation we're OK with?

Incompatible(
Some(
"Returns null instead of NaN in some edge cases" +
" (https://github.com/apache/datafusion-comet/issues/2646)"))

override def convert(
aggExpr: AggregateExpression,
corr: Corr,
inputs: Seq[Attribute],
binding: Boolean,
conf: SQLConf): Option[ExprOuterClass.AggExpr] = {
val child1Expr = exprToProto(corr.x, inputs, binding)
// When both inputs are NaN, convert one input to null in order to return null.
// This matches Spark's behavior where corr(NaN, NaN) returns null.
val wrappedX =
If(And(IsNaN(corr.x), IsNaN(corr.y)), Literal.create(null, corr.x.dataType), corr.x)

val child1Expr = exprToProto(wrappedX, inputs, binding)
val child2Expr = exprToProto(corr.y, inputs, binding)
val dataType = serializeDataType(corr.dataType)

@@ -15,7 +15,6 @@
-- specific language governing permissions and limitations
-- under the License.

-- Config: spark.comet.expression.Corr.allowIncompatible=true
-- ConfigMatrix: parquet.enable.dictionary=false,true

statement
@@ -29,3 +28,13 @@ SELECT corr(x, y) FROM test_corr

query tolerance=1e-6
SELECT grp, corr(x, y) FROM test_corr GROUP BY grp ORDER BY grp

-- Test permutations of NULL and NaN
statement
CREATE TABLE test_corr_nan(x double, y double, grp string) USING parquet

statement
INSERT INTO test_corr_nan VALUES (cast('NaN' as double), cast('NaN' as double), 'both_nan'), (cast('NaN' as double), 1.0, 'nan_val'), (1.0, cast('NaN' as double), 'val_nan'), (NULL, cast('NaN' as double), 'null_nan'), (cast('NaN' as double), NULL, 'nan_null'), (NULL, NULL, 'both_null'), (NULL, 1.0, 'null_val'), (1.0, NULL, 'val_null')
Contributor

Maybe add a group with mixed NaN and valid rows, e.g. [(NaN, NaN), (1.0, 2.0), (3.0, 4.0)].


query tolerance=1e-6
SELECT grp, corr(x, y) FROM test_corr_nan GROUP BY grp ORDER BY grp
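
Following the reviewer's suggestion, a mixed group could be added along these lines (a sketch, not part of this change — the group name and values are hypothetical; note that Spark returns NaN for such a group while the workaround would compute a finite correlation, so the expected answer would need to be pinned accordingly):

```
-- Hypothetical mixed group: one (NaN, NaN) pair alongside valid pairs
statement
INSERT INTO test_corr_nan VALUES (cast('NaN' as double), cast('NaN' as double), 'mixed'), (1.0, 2.0, 'mixed'), (3.0, 4.0, 'mixed')
```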
@@ -24,7 +24,6 @@ import scala.util.Random
import org.apache.hadoop.fs.Path
import org.apache.spark.sql.{CometTestBase, DataFrame, Row}
import org.apache.spark.sql.catalyst.expressions.Cast
import org.apache.spark.sql.catalyst.expressions.aggregate.Corr
import org.apache.spark.sql.catalyst.optimizer.EliminateSorts
import org.apache.spark.sql.comet.CometHashAggregateExec
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
@@ -1306,9 +1305,7 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}

test("covariance & correlation") {
withSQLConf(
CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true",
CometConf.getExprAllowIncompatConfigKey(classOf[Corr]) -> "true") {
withSQLConf(CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true") {
Seq("jvm", "native").foreach { cometShuffleMode =>
withSQLConf(CometConf.COMET_SHUFFLE_MODE.key -> cometShuffleMode) {
Seq(true, false).foreach { dictionary =>
@@ -1379,6 +1376,27 @@
}
}

test("corr - nan/null") {
withTable("t") {
sql("""create table t using parquet as
select cast(null as float) f1, CAST('NaN' AS float) f2, cast(null as double) d1, CAST('NaN' AS double) d2
from range(1)
""")

checkSparkAnswerAndOperator("""
|select
| corr(f1, f2) c1,
| corr(f1, f1) c2,
| corr(f2, f1) c3,
| corr(f2, f2) c4,
| corr(d1, d2) c5,
| corr(d1, d1) c6,
| corr(d2, d1) c7,
| corr(d2, d2) c8
| FROM t""".stripMargin)
}
}

test("var_pop and var_samp") {
withSQLConf(CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true") {
Seq("native", "jvm").foreach { cometShuffleMode =>