Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1234,6 +1234,22 @@ public static StringSearch getStringSearch(
return new StringSearch(patternString, target, (RuleBasedCollator) collator);
}

/**
* Returns a StringSearch object for the given pattern string only (with a placeholder target),
* under collation rules corresponding to the given collationId. The returned StringSearch can be
* reused across multiple target strings by calling setTarget(), avoiding the cost of rebuilding
* the collation-aware search object for each row.
*/
public static StringSearch getStringSearchForPattern(
final String patternString,
final int collationId) {
Collator collator = CollationFactory.fetchCollation(collationId).getCollator();
// ICU StringSearch requires a non-empty target; use a placeholder that will be replaced
// by setTarget() before each search.
return new StringSearch(patternString, new StringCharacterIterator(" "),
(RuleBasedCollator) collator);
}

/**
* Returns a collation-unaware StringSearch object for the given pattern and target strings.
* While this object does not respect collation, it can be used to find occurrences of the pattern
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,14 @@
*/
package org.apache.spark.sql.catalyst.util;

import java.text.StringCharacterIterator;
import java.util.Map;
import java.util.regex.Pattern;

import com.ibm.icu.text.StringSearch;

import org.apache.spark.unsafe.types.UTF8String;

import java.util.Map;
import java.util.regex.Pattern;

/**
* Static entry point for collation-aware expressions (StringExpressions, RegexpExpressions, and
* other expressions that require custom collation support), as well as private utility methods for
Expand Down Expand Up @@ -104,6 +105,14 @@ public static boolean execICU(final UTF8String l, final UTF8String r,
StringSearch stringSearch = CollationFactory.getStringSearch(l, r, collationId);
return stringSearch.first() != StringSearch.DONE;
}
public static boolean execICU(final UTF8String l, final StringSearch stringSearch) {
if (l.numBytes() == 0) return false;
stringSearch.setTarget(new StringCharacterIterator(l.toValidString()));
return stringSearch.first() != StringSearch.DONE;
}
public static String genCode(final String l, final String cachedSearch) {
return String.format("CollationSupport.Contains.execICU(%s, %s)", l, cachedSearch);
}
}

public static class StartsWith {
Expand Down Expand Up @@ -144,6 +153,14 @@ public static boolean execICU(final UTF8String l, final UTF8String r,
StringSearch stringSearch = CollationFactory.getStringSearch(l, r, collationId);
return stringSearch.first() == 0;
}
public static boolean execICU(final UTF8String l, final StringSearch stringSearch) {
if (l.numBytes() == 0) return false;
stringSearch.setTarget(new StringCharacterIterator(l.toValidString()));
return stringSearch.first() == 0;
}
public static String genCode(final String l, final String cachedSearch) {
return String.format("CollationSupport.StartsWith.execICU(%s, %s)", l, cachedSearch);
}
}

public static class EndsWith {
Expand Down Expand Up @@ -183,6 +200,15 @@ public static boolean execICU(final UTF8String l, final UTF8String r,
int endIndex = stringSearch.getTarget().getEndIndex();
return stringSearch.last() == endIndex - stringSearch.getMatchLength();
}
public static boolean execICU(final UTF8String l, final StringSearch stringSearch) {
if (l.numBytes() == 0) return false;
stringSearch.setTarget(new StringCharacterIterator(l.toValidString()));
int endIndex = stringSearch.getTarget().getEndIndex();
return stringSearch.last() == endIndex - stringSearch.getMatchLength();
}
public static String genCode(final String l, final String cachedSearch) {
return String.format("CollationSupport.EndsWith.execICU(%s, %s)", l, cachedSearch);
}
}

public static class Upper {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ import java.util.{Base64 => JBase64, HashMap, Locale, Map => JMap}

import scala.collection.mutable.ArrayBuffer

import com.ibm.icu.text.StringSearch
import org.apache.commons.text.StringEscapeUtils

import org.apache.spark.QueryContext
import org.apache.spark.network.util.JavaUtils
import org.apache.spark.sql.catalyst.InternalRow
Expand Down Expand Up @@ -614,13 +617,86 @@ object ContainsExpressionBuilder extends StringBinaryPredicateExpressionBuilderB
}

case class Contains(left: Expression, right: Expression) extends StringPredicate {

@transient private lazy val isICUCollation: Boolean = {
val collation = CollationFactory.fetchCollation(collationId)
!collation.isUtf8BinaryType && !collation.isUtf8LcaseType
}

@transient private lazy val cachedStringSearch: StringSearch = {
if (isICUCollation && right.foldable) {
val pattern = right.eval().asInstanceOf[UTF8String]
if (pattern != null && pattern.numBytes() > 0) {
val collation = CollationFactory.fetchCollation(collationId)
val patternStr = if (collation.supportsSpaceTrimming) {
CollationFactory.applyTrimmingPolicy(pattern, collationId).toValidString()
} else {
pattern.toValidString()
}
CollationFactory.getStringSearchForPattern(patternStr, collationId)
} else null
} else null
}

override def compare(l: UTF8String, r: UTF8String): Boolean = {
CollationSupport.Contains.exec(l, r, collationId)
if (cachedStringSearch != null) {
val collation = CollationFactory.fetchCollation(collationId)
val target = if (collation.supportsSpaceTrimming) {
CollationFactory.applyTrimmingPolicy(l, collationId)
} else l
CollationSupport.Contains.execICU(target, cachedStringSearch)
} else {
CollationSupport.Contains.exec(l, r, collationId)
}
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
if (isICUCollation && right.foldable) {
val rVal = right.eval()
if (rVal != null) {
val pattern = rVal.asInstanceOf[UTF8String]
if (pattern.numBytes() > 0) {
val collation = CollationFactory.fetchCollation(collationId)
val patternStr = if (collation.supportsSpaceTrimming) {
CollationFactory.applyTrimmingPolicy(pattern, collationId).toValidString()
} else {
pattern.toValidString()
}
val escapedPattern = StringEscapeUtils.escapeJava(patternStr)
val searchClass = classOf[StringSearch].getName
val factoryClass = classOf[CollationFactory].getName
val searchInit = s"""$factoryClass""" +
s""".getStringSearchForPattern(""" +
s""""$escapedPattern", $collationId)"""
val cachedSearch = ctx.addMutableState(
searchClass, "cachedStringSearch",
v => s"""$v = $searchInit;""")
val eval = left.genCode(ctx)
val targetExpr = if (collation.supportsSpaceTrimming) {
s"""$factoryClass.applyTrimmingPolicy(""" +
s"""${eval.value}, $collationId)"""
} else {
s"${eval.value}"
}
val execCode = s"""CollationSupport""" +
s""".Contains.execICU(""" +
s"""$targetExpr, $cachedSearch)"""
return ev.copy(code = code"""
${eval.code}
boolean ${ev.isNull} = ${eval.isNull};
${CodeGenerator.javaType(dataType)} ${ev.value} =
${CodeGenerator.defaultValue(dataType)};
if (!${ev.isNull}) {
${ev.value} = $execCode;
}
""")
}
}
}
defineCodeGen(ctx, ev, (c1, c2) =>
CollationSupport.Contains.genCode(c1, c2, collationId))
}

override def inputTypes : Seq[AbstractDataType] =
Seq(StringTypeNonCSAICollation(supportsTrimCollation = true),
StringTypeNonCSAICollation(supportsTrimCollation = true)
Expand Down Expand Up @@ -658,11 +734,82 @@ object StartsWithExpressionBuilder extends StringBinaryPredicateExpressionBuilde
}

case class StartsWith(left: Expression, right: Expression) extends StringPredicate {

@transient private lazy val isICUCollation: Boolean = {
val collation = CollationFactory.fetchCollation(collationId)
!collation.isUtf8BinaryType && !collation.isUtf8LcaseType
}

@transient private lazy val cachedStringSearch: StringSearch = {
if (isICUCollation && right.foldable) {
val pattern = right.eval().asInstanceOf[UTF8String]
if (pattern != null && pattern.numBytes() > 0) {
val collation = CollationFactory.fetchCollation(collationId)
val patternStr = if (collation.supportsSpaceTrimming) {
CollationFactory.applyTrimmingPolicy(pattern, collationId).toValidString()
} else {
pattern.toValidString()
}
CollationFactory.getStringSearchForPattern(patternStr, collationId)
} else null
} else null
}

override def compare(l: UTF8String, r: UTF8String): Boolean = {
CollationSupport.StartsWith.exec(l, r, collationId)
if (cachedStringSearch != null) {
val collation = CollationFactory.fetchCollation(collationId)
val target = if (collation.supportsSpaceTrimming) {
CollationFactory.applyTrimmingPolicy(l, collationId)
} else l
CollationSupport.StartsWith.execICU(target, cachedStringSearch)
} else {
CollationSupport.StartsWith.exec(l, r, collationId)
}
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
if (isICUCollation && right.foldable) {
val rVal = right.eval()
if (rVal != null) {
val pattern = rVal.asInstanceOf[UTF8String]
if (pattern.numBytes() > 0) {
val collation = CollationFactory.fetchCollation(collationId)
val patternStr = if (collation.supportsSpaceTrimming) {
CollationFactory.applyTrimmingPolicy(pattern, collationId).toValidString()
} else {
pattern.toValidString()
}
val escapedPattern = StringEscapeUtils.escapeJava(patternStr)
val searchClass = classOf[StringSearch].getName
val factoryClass = classOf[CollationFactory].getName
val searchInit = s"""$factoryClass""" +
s""".getStringSearchForPattern(""" +
s""""$escapedPattern", $collationId)"""
val cachedSearch = ctx.addMutableState(
searchClass, "cachedStringSearch",
v => s"""$v = $searchInit;""")
val eval = left.genCode(ctx)
val targetExpr = if (collation.supportsSpaceTrimming) {
s"""$factoryClass.applyTrimmingPolicy(""" +
s"""${eval.value}, $collationId)"""
} else {
s"${eval.value}"
}
val execCode = s"""CollationSupport""" +
s""".StartsWith.execICU(""" +
s"""$targetExpr, $cachedSearch)"""
return ev.copy(code = code"""
${eval.code}
boolean ${ev.isNull} = ${eval.isNull};
${CodeGenerator.javaType(dataType)} ${ev.value} =
${CodeGenerator.defaultValue(dataType)};
if (!${ev.isNull}) {
${ev.value} = $execCode;
}
""")
}
}
}
defineCodeGen(ctx, ev, (c1, c2) =>
CollationSupport.StartsWith.genCode(c1, c2, collationId))
}
Expand Down Expand Up @@ -707,11 +854,82 @@ object EndsWithExpressionBuilder extends StringBinaryPredicateExpressionBuilderB
}

case class EndsWith(left: Expression, right: Expression) extends StringPredicate {

@transient private lazy val isICUCollation: Boolean = {
val collation = CollationFactory.fetchCollation(collationId)
!collation.isUtf8BinaryType && !collation.isUtf8LcaseType
}

@transient private lazy val cachedStringSearch: StringSearch = {
if (isICUCollation && right.foldable) {
val pattern = right.eval().asInstanceOf[UTF8String]
if (pattern != null && pattern.numBytes() > 0) {
val collation = CollationFactory.fetchCollation(collationId)
val patternStr = if (collation.supportsSpaceTrimming) {
CollationFactory.applyTrimmingPolicy(pattern, collationId).toValidString()
} else {
pattern.toValidString()
}
CollationFactory.getStringSearchForPattern(patternStr, collationId)
} else null
} else null
}

override def compare(l: UTF8String, r: UTF8String): Boolean = {
CollationSupport.EndsWith.exec(l, r, collationId)
if (cachedStringSearch != null) {
val collation = CollationFactory.fetchCollation(collationId)
val target = if (collation.supportsSpaceTrimming) {
CollationFactory.applyTrimmingPolicy(l, collationId)
} else l
CollationSupport.EndsWith.execICU(target, cachedStringSearch)
} else {
CollationSupport.EndsWith.exec(l, r, collationId)
}
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
if (isICUCollation && right.foldable) {
val rVal = right.eval()
if (rVal != null) {
val pattern = rVal.asInstanceOf[UTF8String]
if (pattern.numBytes() > 0) {
val collation = CollationFactory.fetchCollation(collationId)
val patternStr = if (collation.supportsSpaceTrimming) {
CollationFactory.applyTrimmingPolicy(pattern, collationId).toValidString()
} else {
pattern.toValidString()
}
val escapedPattern = StringEscapeUtils.escapeJava(patternStr)
val searchClass = classOf[StringSearch].getName
val factoryClass = classOf[CollationFactory].getName
val searchInit = s"""$factoryClass""" +
s""".getStringSearchForPattern(""" +
s""""$escapedPattern", $collationId)"""
val cachedSearch = ctx.addMutableState(
searchClass, "cachedStringSearch",
v => s"""$v = $searchInit;""")
val eval = left.genCode(ctx)
val targetExpr = if (collation.supportsSpaceTrimming) {
s"""$factoryClass.applyTrimmingPolicy(""" +
s"""${eval.value}, $collationId)"""
} else {
s"${eval.value}"
}
val execCode = s"""CollationSupport""" +
s""".EndsWith.execICU(""" +
s"""$targetExpr, $cachedSearch)"""
return ev.copy(code = code"""
${eval.code}
boolean ${ev.isNull} = ${eval.isNull};
${CodeGenerator.javaType(dataType)} ${ev.value} =
${CodeGenerator.defaultValue(dataType)};
if (!${ev.isNull}) {
${ev.value} = $execCode;
}
""")
}
}
}
defineCodeGen(ctx, ev, (c1, c2) =>
CollationSupport.EndsWith.genCode(c1, c2, collationId))
}
Expand Down
Loading