@@ -236,15 +236,26 @@ abstract class InMemoryBaseTable(
         case (v, t) =>
           throw new IllegalArgumentException(s"Match: unsupported argument(s) type - ($v, $t)")
       }
+    // the result should be consistent with BucketFunction defined in transformFunctions.scala
     case BucketTransform(numBuckets, cols, _) =>
-      val valueTypePairs = cols.map(col => extractor(col.fieldNames, cleanedSchema, row))
-      var valueHashCode = 0
-      valueTypePairs.foreach( pair =>
-        if ( pair._1 != null) valueHashCode += pair._1.hashCode()
-      )
-      var dataTypeHashCode = 0
-      valueTypePairs.foreach(dataTypeHashCode += _._2.hashCode())
-      ((valueHashCode + 31 * dataTypeHashCode) & Integer.MAX_VALUE) % numBuckets
+      val hash: Long = cols.foldLeft(0L) { (acc, col) =>
+        val valueHash = extractor(col.fieldNames, cleanedSchema, row) match {
+          case (value: Byte, _: ByteType) => value.toLong
+          case (value: Short, _: ShortType) => value.toLong
+          case (value: Int, _: IntegerType) => value.toLong
+          case (value: Long, _: LongType) => value
+          case (value: Long, _: TimestampType) => value
+          case (value: Long, _: TimestampNTZType) => value
+          case (value: UTF8String, _: StringType) =>
+            value.hashCode.toLong
+          case (value: Array[Byte], BinaryType) =>
+            util.Arrays.hashCode(value).toLong
+          case (v, t) =>
+            throw new IllegalArgumentException(s"Match: unsupported argument(s) type - ($v, $t)")
+        }
+        acc + valueHash
+      }
+      Math.floorMod(hash, numBuckets)
     case NamedTransform("truncate", Seq(ref: NamedReference, length: V2Literal[_])) =>
       extractor(ref.fieldNames, cleanedSchema, row) match {
         case (str: UTF8String, StringType) =>
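
The rewritten transform and BucketFunction (below, in transformFunctions.scala) both reduce the key with Math.floorMod. A minimal standalone sketch, not part of the diff, of why floorMod is used instead of %: it always yields a bucket id in [0, numBuckets), even when the accumulated hash is negative, so both sides agree on the sign convention.

object FloorModSketch {
  def main(args: Array[String]): Unit = {
    val numBuckets = 4
    val hash = -7L
    // % keeps the dividend's sign, so a negative hash would give a negative bucket id.
    println(hash % numBuckets)               // -3
    // floorMod maps the same hash into [0, numBuckets).
    println(Math.floorMod(hash, numBuckets)) // 1
  }
}
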
@@ -34,7 +34,7 @@ trait KeyGroupedPartitionedScan[T] {
   def getOutputKeyGroupedPartitioning(
       basePartitioning: KeyGroupedPartitioning,
       spjParams: StoragePartitionJoinParams): KeyGroupedPartitioning = {
-    val expressions = spjParams.joinKeyPositions match {
+    val projectedExpressions = spjParams.joinKeyPositions match {
       case Some(projectionPositions) =>
         projectionPositions.map(i => basePartitioning.expressions(i))
       case _ => basePartitioning.expressions
@@ -50,14 +50,14 @@
       case None =>
         spjParams.joinKeyPositions match {
           case Some(projectionPositions) => basePartitioning.partitionValues.map { r =>
-            val projectedRow = KeyGroupedPartitioning.project(expressions,
+            val projectedRow = KeyGroupedPartitioning.project(basePartitioning.expressions,
               projectionPositions, r)
-            InternalRowComparableWrapper(projectedRow, expressions)
+            InternalRowComparableWrapper(projectedRow, projectedExpressions)
           }.distinct.map(_.row)
           case _ => basePartitioning.partitionValues
         }
     }
-    basePartitioning.copy(expressions = expressions, numPartitions = newPartValues.length,
+    basePartitioning.copy(expressions = projectedExpressions, numPartitions = newPartValues.length,
       partitionValues = newPartValues)
   }

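
A simplified model of the projection fix above, with hypothetical names rather than Spark's API: joinKeyPositions index into the base partitioning's expressions, so the projection must be applied to the base expression list, while the projected expressions are only used to compare the resulting rows. Indexing an already projected (shorter) list is what triggered the ArrayIndexOutOfBoundsException.

object ProjectionSketch {
  // Stand-in for a grouped partition value over the full set of clustering keys.
  final case class PartValue(values: Seq[Any])

  // positions refer to the base expressions, not to the projected subset.
  def project(positions: Seq[Int], row: PartValue): PartValue =
    PartValue(positions.map(row.values))

  def main(args: Array[String]): Unit = {
    // clustering keys: (identity(customer_name), bucket(4, customer_id))
    val base = PartValue(Seq("aaa", 2))
    // the join key is only customer_id, i.e. position 1 of the base expressions
    println(project(Seq(1), base)) // PartValue(List(2))
  }
}
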
@@ -124,7 +124,7 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {
       Seq(TransformExpression(BucketFunction, Seq(attr("ts")), Some(32))))

     // Has exactly one partition.
-    val partitionValues = Seq(31).map(v => InternalRow.fromSeq(Seq(v)))
+    val partitionValues = Seq(0).map(v => InternalRow.fromSeq(Seq(v)))
     checkQueryPlan(df, distribution,
       physical.KeyGroupedPartitioning(distribution.clustering, 1, partitionValues, partitionValues))
   }
@@ -2798,8 +2798,6 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {
   }

   test("SPARK-54439: KeyGroupedPartitioning with transform and join key size mismatch") {
-    // Do not use `bucket()` in "one side partition" tests as its implementation in
-    // `InMemoryBaseTable` conflicts with `BucketFunction`
     val items_partitions = Array(years("arrive_time"))
     createTable(items, itemsColumns, items_partitions)

@@ -2823,4 +2821,42 @@
       checkAnswer(df, Seq(Row(1, "aa", 40.0, 42.0)))
     }
   }
+
+  test("SPARK-55411: Fix ArrayIndexOutOfBoundsException when join keys " +
+    "are less than cluster keys") {
+    withSQLConf(
+      SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "false",
+      SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> "true",
+      SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true",
+      SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> "false",
+      SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key -> "true") {
+
+      val customers_partitions = Array(identity("customer_name"), bucket(4, "customer_id"))
+      createTable(customers, customersColumns, customers_partitions)
+      sql(s"INSERT INTO testcat.ns.$customers VALUES " +
+        s"('aaa', 10, 1), ('bbb', 20, 2), ('ccc', 30, 3)")
+
+      createTable(orders, ordersColumns, Array.empty)
+      sql(s"INSERT INTO testcat.ns.$orders VALUES " +
+        s"(100.0, 1), (200.0, 1), (150.0, 2), (250.0, 2), (350.0, 2), (400.50, 3)")
+
+      val df = sql(
+        s"""${selectWithMergeJoinHint("c", "o")}
+           |customer_name, customer_age, order_amount
+           |FROM testcat.ns.$customers c JOIN testcat.ns.$orders o
+           |ON c.customer_id = o.customer_id ORDER BY c.customer_id, order_amount
+           |""".stripMargin)
+
+      val shuffles = collectShuffles(df.queryExecution.executedPlan)
+      assert(shuffles.length == 1)
+
+      checkAnswer(df, Seq(
+        Row("aaa", 10, 100.0),
+        Row("aaa", 10, 200.0),
+        Row("bbb", 20, 150.0),
+        Row("bbb", 20, 250.0),
+        Row("bbb", 20, 350.0),
+        Row("ccc", 30, 400.50)))
+    }
+  }
 }
@@ -41,7 +41,7 @@ class MetadataColumnSuite extends DatasourceV2SQLBase {
     val dfQuery = spark.table(tbl).select("id", "data", "index", "_partition")

     Seq(sqlQuery, dfQuery).foreach { query =>
-      checkAnswer(query, Seq(Row(1, "a", 0, "3/1"), Row(2, "b", 0, "0/2"), Row(3, "c", 0, "1/3")))
+      checkAnswer(query, Seq(Row(1, "a", 0, "1/1"), Row(2, "b", 0, "2/2"), Row(3, "c", 0, "3/3")))
     }
   }
 }
@@ -55,7 +55,7 @@ class MetadataColumnSuite extends DatasourceV2SQLBase {
     val dfQuery = spark.table(tbl).select("index", "data", "_partition")

     Seq(sqlQuery, dfQuery).foreach { query =>
-      checkAnswer(query, Seq(Row(3, "c", "1/3"), Row(2, "b", "0/2"), Row(1, "a", "3/1")))
+      checkAnswer(query, Seq(Row(3, "c", "3/3"), Row(2, "b", "2/2"), Row(1, "a", "1/1")))
     }
   }
 }
@@ -124,7 +124,7 @@ class MetadataColumnSuite extends DatasourceV2SQLBase {

     checkAnswer(
       dfQuery,
-      Seq(Row(1, "a", 0, "3/1"), Row(2, "b", 0, "0/2"), Row(3, "c", 0, "1/3"))
+      Seq(Row(1, "a", 0, "1/1"), Row(2, "b", 0, "2/2"), Row(3, "c", 0, "3/3"))
     )
   }
 }
@@ -134,7 +134,7 @@ class MetadataColumnSuite extends DatasourceV2SQLBase {
     prepareTable()
     checkAnswer(
       spark.table(tbl).select("id", "data").select("index", "_partition"),
-      Seq(Row(0, "3/1"), Row(0, "0/2"), Row(0, "1/3"))
+      Seq(Row(0, "1/1"), Row(0, "2/2"), Row(0, "3/3"))
     )
   }
 }
@@ -159,7 +159,7 @@ class MetadataColumnSuite extends DatasourceV2SQLBase {
     val dfQuery = spark.table(tbl).where("id > 1").select("id", "data", "index", "_partition")

     Seq(sqlQuery, dfQuery).foreach { query =>
-      checkAnswer(query, Seq(Row(2, "b", 0, "0/2"), Row(3, "c", 0, "1/3")))
+      checkAnswer(query, Seq(Row(2, "b", 0, "2/2"), Row(3, "c", 0, "3/3")))
     }
   }
 }
@@ -171,7 +171,7 @@ class MetadataColumnSuite extends DatasourceV2SQLBase {
     val dfQuery = spark.table(tbl).orderBy("id").select("id", "data", "index", "_partition")

     Seq(sqlQuery, dfQuery).foreach { query =>
-      checkAnswer(query, Seq(Row(1, "a", 0, "3/1"), Row(2, "b", 0, "0/2"), Row(3, "c", 0, "1/3")))
+      checkAnswer(query, Seq(Row(1, "a", 0, "1/1"), Row(2, "b", 0, "2/2"), Row(3, "c", 0, "3/3")))
     }
   }
 }
@@ -185,7 +185,7 @@ class MetadataColumnSuite extends DatasourceV2SQLBase {
       .select("id", "data", "index", "_partition")

     Seq(sqlQuery, dfQuery).foreach { query =>
-      checkAnswer(query, Seq(Row(1, "a", 0, "3/1"), Row(2, "b", 0, "0/2"), Row(3, "c", 0, "1/3")))
+      checkAnswer(query, Seq(Row(1, "a", 0, "1/1"), Row(2, "b", 0, "2/2"), Row(3, "c", 0, "3/3")))
     }
   }
 }
@@ -200,7 +200,7 @@ class MetadataColumnSuite extends DatasourceV2SQLBase {
       s"$sbq.id", s"$sbq.data", s"$sbq.index", s"$sbq._partition")

     Seq(sqlQuery, dfQuery).foreach { query =>
-      checkAnswer(query, Seq(Row(1, "a", 0, "3/1"), Row(2, "b", 0, "0/2"), Row(3, "c", 0, "1/3")))
+      checkAnswer(query, Seq(Row(1, "a", 0, "1/1"), Row(2, "b", 0, "2/2"), Row(3, "c", 0, "3/3")))
     }

     assertThrows[AnalysisException] {
@@ -84,14 +84,15 @@ object UnboundBucketFunction extends UnboundFunction {
   override def name(): String = "bucket"
 }

+// the result should be consistent with BucketTransform defined in InMemoryBaseTable.scala
 object BucketFunction extends ScalarFunction[Int] with ReducibleFunction[Int, Int] {
   override def inputTypes(): Array[DataType] = Array(IntegerType, LongType)
   override def resultType(): DataType = IntegerType
   override def name(): String = "bucket"
   override def canonicalName(): String = name()
   override def toString: String = name()
   override def produceResult(input: InternalRow): Int = {
-    (input.getLong(1) % input.getInt(0)).toInt
+    Math.floorMod(input.getLong(1), input.getInt(0))
   }

   override def reducer(
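
A hedged consistency check, assuming a single LongType bucket column and hypothetical helper names: the bucket id produced by the rewritten in-memory transform must match what produceResult returns for the same value, otherwise storage-partitioned joins would group rows under mismatched partition values.

object BucketConsistencySketch {
  // Mirrors the LongType branch of the rewritten BucketTransform handling
  // (fold of a single value, then Math.floorMod).
  def transformBucket(value: Long, numBuckets: Int): Int =
    Math.floorMod(value, numBuckets)

  // Mirrors BucketFunction.produceResult for an input row of (numBuckets, value).
  def functionBucket(value: Long, numBuckets: Int): Int =
    Math.floorMod(value, numBuckets)

  def main(args: Array[String]): Unit = {
    assert((-5L to 5L).forall(v => transformBucket(v, 4) == functionBucket(v, 4)))
    println("transform and function agree on every bucket id")
  }
}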