Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,10 @@ object AttributeSet {
* `AttributeReference("a"...) == AttributeReference("b", ...)`. This tactic leads to broken tests,
* and also makes doing transformations hard (we always try to keep older trees instead of new ones
* when the transformation was a no-op).
*
* Iteration via [[iterator]], [[foreach]], or [[Iterable]]-derived combinators (`flatMap`, etc.)
* visits elements in insertion order. Note: [[toSeq]] is an explicit exception -- it sorts by
* `(name, exprId.id)` for stable codegen output, see SPARK-18394.
*/
class AttributeSet private (private val baseSet: mutable.LinkedHashSet[AttributeEquals])
extends Iterable[Attribute] with Serializable {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -578,13 +578,13 @@ case class KeyedPartitioning(
c.areAllClusterKeysMatched(expressions)
} else {
// We'll need to find leaf attributes from the partition expressions first.
lazy val attributes = expressions.flatMap(_.collectLeaves())
lazy val attributes = AttributeSet.fromAttributeSets(expressions.map(_.references))

if (SQLConf.get.v2BucketingAllowKeysSubsetOfPartitionKeys) {
// check that operation keys (required clustering keys)
// overlap with partition keys (KeyedPartitioning attributes)
requiredClustering.exists(x => attributes.exists(_.semanticEquals(x))) &&
expressions.forall(_.collectLeaves().size == 1)
expressions.forall(_.references.size == 1)
} else if (isNarrowed && !isGrouped) {
// A narrowed, non-grouped partitioning carries the same skew risk as using a subset of
// partition keys for a join: GroupPartitionsExec will merge partitions that held
Expand Down Expand Up @@ -1218,9 +1218,9 @@ case class KeyedShuffleSpec(
distKeyToPos.getOrElseUpdate(distKey.canonicalized, mutable.BitSet.empty).add(distKeyPos)
}
partitioning.expressions.map { e =>
val leaves = e.collectLeaves()
assert(leaves.size == 1, s"Expected exactly one child from $e, but found ${leaves.size}")
distKeyToPos.getOrElse(leaves.head.canonicalized, mutable.BitSet.empty)
val refs = e.references
assert(refs.size == 1, s"Expected exactly one child from $e, but found ${refs.size}")
distKeyToPos.getOrElse(refs.head.canonicalized, mutable.BitSet.empty)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,9 @@ case class EnsureRequirements(
// Find any KeyedPartitioning that satisfies via groupedSatisfies.
val satisfyingKeyedPartitioning =
groupedSatisfies.orElse(nonGroupedSatisfiesWhenGrouped).get
val attrs = satisfyingKeyedPartitioning.expressions.flatMap(_.collectLeaves())
.map(_.asInstanceOf[Attribute])
// The single-column invariant in KeyedPartitioning.supportsExpressions guarantees
// one attribute per partition expression.
val attrs = satisfyingKeyedPartitioning.expressions.flatMap(_.references)
val keyRowOrdering = RowOrdering.create(o.ordering, attrs)
val keyOrdering = keyRowOrdering.on((t: InternalRowComparableWrapper) => t.row)
if (satisfyingKeyedPartitioning.partitionKeys.sliding(2).forall {
Expand Down Expand Up @@ -409,12 +410,16 @@ case class EnsureRequirements(
.orElse(reorderJoinKeysRecursively(
leftKeys, rightKeys, leftPartitioning, None))
case (Some(KeyedPartitioning(clustering, _, _, _)), _) =>
val leafExprs = clustering.flatMap(_.collectLeaves())
// The single-column invariant in KeyedPartitioning.supportsExpressions guarantees one
// attribute per partition expression.
val leafExprs = clustering.flatMap(_.references)
reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, leafExprs, leftKeys)
.orElse(reorderJoinKeysRecursively(
leftKeys, rightKeys, None, rightPartitioning))
case (_, Some(KeyedPartitioning(clustering, _, _, _))) =>
val leafExprs = clustering.flatMap(_.collectLeaves())
// The single-column invariant in KeyedPartitioning.supportsExpressions guarantees one
// attribute per partition expression.
val leafExprs = clustering.flatMap(_.references)
reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, leafExprs, rightKeys)
.orElse(reorderJoinKeysRecursively(
leftKeys, rightKeys, leftPartitioning, None))
Expand Down Expand Up @@ -777,7 +782,9 @@ case class EnsureRequirements(
partitioning: Partitioning,
distribution: ClusteredDistribution): Option[KeyedShuffleSpec] = {
def tryCreate(partitioning: KeyedPartitioning): Option[KeyedShuffleSpec] = {
val attributes = partitioning.expressions.flatMap(_.collectLeaves())
// The single-column invariant in KeyedPartitioning.supportsExpressions guarantees one
// attribute per partition expression.
val attributes = partitioning.expressions.flatMap(_.references)
val clustering = distribution.clustering

val satisfies = if (SQLConf.get.getConf(SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,10 @@ import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

class EnsureRequirementsSuite extends SharedSparkSession {
private val exprA = Literal(1)
private val exprB = Literal(2)
private val exprC = Literal(3)
private val exprD = Literal(4)
private val exprA = AttributeReference("a", IntegerType)()
private val exprB = AttributeReference("b", IntegerType)()
private val exprC = AttributeReference("c", IntegerType)()
private val exprD = AttributeReference("d", IntegerType)()

private val EnsureRequirements = new EnsureRequirements()

Expand Down