Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 124 additions & 0 deletions docs/source/contributor-guide/spark_expressions_support.md

Large diffs are not rendered by default.

198 changes: 103 additions & 95 deletions spark/src/main/scala/org/apache/comet/serde/strings.scala
Original file line number Diff line number Diff line change
Expand Up @@ -76,54 +76,54 @@ object CometUpper extends CometCaseConversionBase[Upper]("upper")
object CometLower extends CometCaseConversionBase[Lower]("lower")

/**
 * Serde for Spark's `Length` expression, registered under the scalar function name "length".
 *
 * `BinaryType` input is reported as unsupported; every other input type is reported as
 * compatible.
 */
object CometLength extends CometScalarFunction[Length]("length") {

  // Shared by the documented reason list and the per-expression check so the two messages
  // cannot drift apart.
  private val binaryUnsupportedReason = "`BinaryType` input is not supported"

  /** Lists every reason this expression may be reported as unsupported. */
  override def getUnsupportedReasons(): Seq[String] = Seq(binaryUnsupportedReason)

  /**
   * Decides support from the data type of the child expression.
   *
   * @param expr the `Length` expression being planned
   * @return `Unsupported` for a `BinaryType` child, `Compatible` otherwise
   */
  override def getSupportLevel(expr: Length): SupportLevel = expr.child.dataType match {
    case _: BinaryType => Unsupported(Some(binaryUnsupportedReason))
    case _ => Compatible()
  }
}

/**
 * Serde for Spark's `InitCap` expression, registered under the scalar function name "initcap".
 *
 * The expression is always reported as incompatible: for the input "robert rose-smith" Spark
 * produces "Robert Rose-smith" while Comet produces "Robert Rose-Smith".
 * See https://github.com/apache/datafusion-comet/issues/1052
 */
object CometInitCap extends CometScalarFunction[InitCap]("initcap") {

  // Single message reused by the documented reason list and by getSupportLevel.
  private val incompatReason =
    "Treats hyphen as a word separator (e.g. `robert rose-smith` produces `Robert Rose-Smith`" +
      " instead of Spark's `Robert Rose-smith`)" +
      " (https://github.com/apache/datafusion-comet/issues/1052)"

  /** Lists every reason this expression may be reported as incompatible. */
  override def getIncompatibleReasons(): Seq[String] = Seq(incompatReason)

  /**
   * Reports the expression as incompatible regardless of its arguments.
   *
   * @param expr the `InitCap` expression being planned (not inspected)
   * @return `Incompatible` carrying the hyphen-handling reason
   */
  override def getSupportLevel(expr: InitCap): SupportLevel = Incompatible(Some(incompatReason))
}

object CometSubstring extends CometExpressionSerde[Substring] {

private val literalArgsReason = "`pos` and `len` arguments must be literal values"

override def getUnsupportedReasons(): Seq[String] = Seq(literalArgsReason)

/**
 * Decides support from whether both positional arguments are literals.
 *
 * @param expr the `Substring` expression being planned
 * @return `Compatible` when `pos` and `len` are both `Literal`, otherwise `Unsupported`
 *         carrying `literalArgsReason`
 */
override def getSupportLevel(expr: Substring): SupportLevel = {
  // Only the all-literal combination is accepted; a non-literal `pos` or `len` is rejected.
  val argsAreLiteral = (expr.pos, expr.len) match {
    case (_: Literal, _: Literal) => true
    case _ => false
  }
  if (argsAreLiteral) Compatible() else Unsupported(Some(literalArgsReason))
}

/**
 * Serializes a `Substring` expression into a `Substring` proto message.
 *
 * @param expr the `Substring` expression to serialize
 * @param inputs input attributes forwarded to `exprToProtoInternal`
 * @param binding binding flag forwarded to `exprToProtoInternal`
 * @return the serialized expression, or `None` when the string child cannot be serialized or
 *         when `pos`/`len` are not literals
 */
override def convert(
    expr: Substring,
    inputs: Seq[Attribute],
    binding: Boolean): Option[Expr] = {
  // Match rather than destructure with `val Literal(..) = ..`: a pattern definition throws
  // scala.MatchError on a non-literal argument, whereas this falls back with a reason even if
  // convert is reached without getSupportLevel having been consulted first.
  (expr.pos, expr.len) match {
    case (Literal(pos, _), Literal(len, _)) =>
      exprToProtoInternal(expr.str, inputs, binding) match {
        case Some(strExpr) =>
          val builder = ExprOuterClass.Substring.newBuilder()
          builder.setChild(strExpr)
          builder.setStart(pos.asInstanceOf[Int])
          builder.setLen(len.asInstanceOf[Int])
          Some(ExprOuterClass.Expr.newBuilder().setSubstring(builder).build())
        case None =>
          // The string child could not be serialized; tag expr with that child.
          withInfo(expr, expr.str)
          None
      }
    case _ =>
      withInfo(expr, literalArgsReason)
      None
  }
}
Expand All @@ -147,78 +147,82 @@ object CometSubstringIndex extends CometExpressionSerde[SubstringIndex] {

/**
 * Serde for Spark's `Left` expression.
 *
 * `Left(str, len)` is serialized as a `Substring` proto message with start position 1 and
 * length `len`. Only `BinaryType` and `StringType` inputs with a literal `len` are reported as
 * compatible.
 */
object CometLeft extends CometExpressionSerde[Left] {

  private val literalLenReason = "The `length` argument must be a literal value"
  private val unsupportedDataTypeReason = "Only supports `BinaryType` and `StringType` input"

  /** Lists every reason this expression may be reported as unsupported. */
  override def getUnsupportedReasons(): Seq[String] =
    Seq(unsupportedDataTypeReason, literalLenReason)

  /**
   * Serializes a `Left` expression into a `Substring` proto message.
   *
   * @param expr the `Left` expression to serialize
   * @param inputs input attributes forwarded to `exprToProtoInternal`
   * @param binding binding flag forwarded to `exprToProtoInternal`
   * @return the serialized expression, or `None` when the string child cannot be serialized or
   *         when `len` is not a literal
   */
  override def convert(expr: Left, inputs: Seq[Attribute], binding: Boolean): Option[Expr] = {
    // Match rather than destructure with `val Literal(..) = ..`: a pattern definition throws
    // scala.MatchError on a non-literal `len`, whereas this falls back with a reason.
    expr.len match {
      case Literal(lenValue, _) =>
        exprToProtoInternal(expr.str, inputs, binding) match {
          case Some(strExpr) =>
            val builder = ExprOuterClass.Substring.newBuilder()
            builder.setChild(strExpr)
            builder.setStart(1)
            builder.setLen(lenValue.asInstanceOf[Int])
            Some(ExprOuterClass.Expr.newBuilder().setSubstring(builder).build())
          case None =>
            // The string child could not be serialized; tag expr with that child.
            withInfo(expr, expr.str)
            None
        }
      case _ =>
        withInfo(expr, literalLenReason)
        None
    }
  }

  /**
   * Decides support from the input data type and from whether `len` is a literal.
   *
   * @param expr the `Left` expression being planned
   * @return `Compatible` for `BinaryType`/`StringType` input with a literal `len`;
   *         `Unsupported` otherwise
   */
  override def getSupportLevel(expr: Left): SupportLevel = {
    expr.str.dataType match {
      case _: BinaryType | _: StringType =>
        expr.len match {
          case _: Literal => Compatible()
          case _ => Unsupported(Some(literalLenReason))
        }
      case _ => Unsupported(Some(s"LEFT does not support ${expr.str.dataType}"))
    }
  }
}

object CometRight extends CometExpressionSerde[Right] {

private val literalLenReason = "The `length` argument must be a literal value"
private val unsupportedDataTypeReason = "Only supports `StringType` input"

/**
 * Serializes a `Right` expression.
 *
 * A non-positive literal length is serialized as `If(IsNull(str), NULL, "")`; a positive one is
 * serialized as a `Substring` proto message with start `-len` and length `len`.
 *
 * @param expr the `Right` expression to serialize
 * @param inputs input attributes forwarded to `exprToProtoInternal`
 * @param binding binding flag forwarded to `exprToProtoInternal`
 * @return the serialized expression, or `None` when a child cannot be serialized or when `len`
 *         is not a literal
 */
override def convert(expr: Right, inputs: Seq[Attribute], binding: Boolean): Option[Expr] = {
  // Match rather than destructure with `val Literal(..) = ..`: a pattern definition throws
  // scala.MatchError on a non-literal `len`, whereas this falls back with a reason.
  expr.len match {
    case Literal(lenValue, _) =>
      val lenInt = lenValue.asInstanceOf[Int]
      if (lenInt <= 0) {
        // Match Spark's behavior: If(IsNull(str), NULL, "")
        // This ensures NULL propagation: RIGHT(NULL, 0) -> NULL, RIGHT("hello", 0) -> ""
        val isNullExpr = IsNull(expr.str)
        val nullLiteral = Literal.create(null, StringType)
        val emptyStringLiteral = Literal(UTF8String.EMPTY_UTF8, StringType)
        val ifExpr = If(isNullExpr, nullLiteral, emptyStringLiteral)
        exprToProtoInternal(ifExpr, inputs, binding)
      } else {
        exprToProtoInternal(expr.str, inputs, binding) match {
          case Some(strExpr) =>
            val builder = ExprOuterClass.Substring.newBuilder()
            builder.setChild(strExpr)
            // A negative start of -len with length len selects the last `len` characters.
            builder.setStart(-lenInt)
            builder.setLen(lenInt)
            Some(ExprOuterClass.Expr.newBuilder().setSubstring(builder).build())
          case None =>
            // The string child could not be serialized; tag expr with that child.
            withInfo(expr, expr.str)
            None
        }
      }
    case _ =>
      withInfo(expr, literalLenReason)
      None
  }
}

/** Lists every reason `Right` may be reported as unsupported: input type and non-literal length. */
override def getUnsupportedReasons(): Seq[String] =
  Seq(unsupportedDataTypeReason, literalLenReason)

/**
 * Decides support from the input data type and from whether `len` is a literal.
 *
 * @param expr the `Right` expression being planned
 * @return `Compatible` for `StringType` input with a literal `len`; `Unsupported` otherwise
 */
override def getSupportLevel(expr: Right): SupportLevel = {
  expr.str.dataType match {
    case _: StringType =>
      expr.len match {
        case _: Literal => Compatible()
        case _ => Unsupported(Some(literalLenReason))
      }
    case _ => Unsupported(Some(s"RIGHT does not support ${expr.str.dataType}"))
  }
}
Expand Down Expand Up @@ -309,18 +313,22 @@ object CometRLike extends CometExpressionSerde[RLike] {
}
}

/**
 * Fallback messages shared by the `StringRPad` and `StringLPad` serdes, kept in one place so
 * both report identical wording.
 */
private object PadReasons {
  // Reported when the `pad` argument is not a literal.
  val nonLiteralPadReason: String = "Only scalar values are supported for the `pad` argument."

  // Reported when the `str` argument is itself a literal.
  val literalStrReason: String = "Scalar values are not supported for the `str` argument."
}

object CometStringRPad extends CometExpressionSerde[StringRPad] {

/** Lists every reason `StringRPad` may be reported as unsupported, one entry per reason. */
override def getUnsupportedReasons(): Seq[String] =
  Seq(PadReasons.literalStrReason, PadReasons.nonLiteralPadReason)

/**
 * Decides support from which arguments are literals.
 *
 * @param expr the `StringRPad` expression being planned
 * @return `Unsupported` when `str` is a literal, or when `pad` is not a literal; `Compatible`
 *         otherwise
 */
override def getSupportLevel(expr: StringRPad): SupportLevel =
  // Case order matters: a literal `str` is rejected before `pad` is examined.
  (expr.str, expr.pad) match {
    case (_: Literal, _) => Unsupported(Some(PadReasons.literalStrReason))
    case (_, _: Literal) => Compatible()
    case _ => Unsupported(Some(PadReasons.nonLiteralPadReason))
  }
Expand All @@ -340,16 +348,15 @@ object CometStringRPad extends CometExpressionSerde[StringRPad] {

object CometStringLPad extends CometExpressionSerde[StringLPad] {

/** Lists every reason `StringLPad` may be reported as unsupported, one entry per reason. */
override def getUnsupportedReasons(): Seq[String] =
  Seq(PadReasons.literalStrReason, PadReasons.nonLiteralPadReason)

/**
 * Decides support from which arguments are literals.
 *
 * @param expr the `StringLPad` expression being planned
 * @return `Unsupported` when `str` is a literal, or when `pad` is not a literal; `Compatible`
 *         otherwise
 */
override def getSupportLevel(expr: StringLPad): SupportLevel =
  // Case order matters: a literal `str` is rejected before `pad` is examined.
  (expr.str, expr.pad) match {
    case (_: Literal, _) => Unsupported(Some(PadReasons.literalStrReason))
    case (_, _: Literal) => Compatible()
    case _ => Unsupported(Some(PadReasons.nonLiteralPadReason))
  }
Expand All @@ -367,11 +374,13 @@ object CometStringLPad extends CometExpressionSerde[StringLPad] {
}

object CometRegExpReplace extends CometExpressionSerde[RegExpReplace] {
// Messages shared between the documented reason lists below and getSupportLevel.
private val incompatReason = "Regexp pattern may not be compatible with Spark"
private val offsetUnsupportedReason =
  "Only supports `regexp_replace` with an offset of 1 (no offset)"

/** Lists every reason `RegExpReplace` may be reported as incompatible. */
override def getIncompatibleReasons(): Seq[String] = Seq(incompatReason)

/** Lists every reason `RegExpReplace` may be reported as unsupported. */
override def getUnsupportedReasons(): Seq[String] = Seq(offsetUnsupportedReason)

override def getSupportLevel(expr: RegExpReplace): SupportLevel = {
if (!RegExp.isSupportedPattern(expr.regexp.toString) &&
Expand All @@ -381,12 +390,11 @@ object CometRegExpReplace extends CometExpressionSerde[RegExpReplace] {
s"Regexp pattern ${expr.regexp} is not compatible with Spark. " +
s"Set ${CometConf.getExprAllowIncompatConfigKey("regexp")}=true " +
"to allow it anyway.")
return Incompatible()
return Incompatible(Some(incompatReason))
}
expr.pos match {
case Literal(value, DataTypes.IntegerType) if value == 1 => Compatible()
case _ =>
Unsupported(Some("Comet only supports regexp_replace with an offset of 1 (no offset)."))
case _ => Unsupported(Some(offsetUnsupportedReason))
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ CREATE TABLE test_str_left(s string, n int) USING parquet
statement
INSERT INTO test_str_left VALUES ('hello', 3), ('hello', 0), ('hello', -1), ('hello', 10), ('', 3), (NULL, 3), ('hello', NULL)

query expect_fallback(Substring pos and len must be literals)
query expect_fallback(arguments must be literal values)
SELECT left(s, n) FROM test_str_left

-- column + literal
Expand All @@ -40,7 +40,7 @@ query
SELECT left(s, 10) FROM test_str_left

-- literal + column
query expect_fallback(Substring pos and len must be literals)
query expect_fallback(arguments must be literal values)
SELECT left('hello', n) FROM test_str_left

-- literal + literal
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ CREATE TABLE test_lpad(s string, len int, pad string) USING parquet
statement
INSERT INTO test_lpad VALUES ('hi', 5, 'x'), ('hello', 3, 'x'), ('hi', 5, 'xy'), ('', 3, 'a'), (NULL, 5, 'x'), ('hi', 0, 'x'), ('hi', -1, 'x')

query expect_fallback(Only scalar values are supported for the pad argument)
query expect_fallback(Only scalar values are supported for the `pad` argument)
SELECT lpad(s, len, pad) FROM test_lpad

query
Expand All @@ -32,5 +32,5 @@ query
SELECT lpad(s, 5, 'x') FROM test_lpad

-- literal + literal + literal
query expect_fallback(Scalar values are not supported for the str argument)
query expect_fallback(Scalar values are not supported for the `str` argument)
SELECT lpad('hi', 5, 'x'), lpad('hello', 3, 'x'), lpad('', 3, 'a'), lpad(NULL, 5, 'x')
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ CREATE TABLE test_rpad(s string, len int, pad string) USING parquet
statement
INSERT INTO test_rpad VALUES ('hi', 5, 'x'), ('hello', 3, 'x'), ('hi', 5, 'xy'), ('', 3, 'a'), (NULL, 5, 'x'), ('hi', 0, 'x'), ('hi', -1, 'x')

query expect_fallback(Only scalar values are supported for the pad argument)
query expect_fallback(Only scalar values are supported for the `pad` argument)
SELECT rpad(s, len, pad) FROM test_rpad

query
Expand All @@ -32,5 +32,5 @@ query
SELECT rpad(s, 5, 'x') FROM test_rpad

-- literal + literal + literal
query expect_fallback(Scalar values are not supported for the str argument)
query expect_fallback(Scalar values are not supported for the `str` argument)
SELECT rpad('hi', 5, 'x'), rpad('hello', 3, 'x'), rpad('', 3, 'a'), rpad(NULL, 5, 'x')
Original file line number Diff line number Diff line change
Expand Up @@ -89,11 +89,11 @@ class CometStringExpressionSuite extends CometTestBase {
} else if (isLiteralStr) {
checkSparkAnswerAndFallbackReason(
sql,
"Scalar values are not supported for the str argument")
"Scalar values are not supported for the `str` argument")
} else if (!isLiteralPad) {
checkSparkAnswerAndFallbackReason(
sql,
"Only scalar values are supported for the pad argument")
"Only scalar values are supported for the `pad` argument")
} else {
checkSparkAnswerAndOperator(sql)
}
Expand Down
Loading