Skip to content
Open
Original file line number Diff line number Diff line change
Expand Up @@ -36,19 +36,19 @@ object FlattenSequentialStreamingUnion extends Rule[LogicalPlan] {
}

/**
* Validates SequentialStreamingUnion constraints:
* Validates SequentialStreamingUnion constraints during analysis:
* - All children must be streaming relations
* - No nested SequentialStreamingUnions (should be flattened first)
* - No stateful operations in any child subtrees
*
* Note: Minimum 2 children is enforced by the resolved property, not explicit validation.
* Note: Nesting validation happens after optimization (see
* ValidateSequentialStreamingUnionNesting).
*/
object ValidateSequentialStreamingUnion extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = {
plan.foreach {
case su: SequentialStreamingUnion =>
validateAllStreaming(su)
validateNoNesting(su)
validateNoStatefulDescendants(su)
case _ =>
}
Expand All @@ -62,14 +62,6 @@ object ValidateSequentialStreamingUnion extends Rule[LogicalPlan] {
}
}

/**
 * Rejects plans where a direct child subtree of the given union still contains
 * another SequentialStreamingUnion, raising the dedicated nesting error.
 */
private def validateNoNesting(su: SequentialStreamingUnion): Unit = {
  val hasNestedUnion = su.children.exists(_.containsPattern(SEQUENTIAL_STREAMING_UNION))
  if (hasNestedUnion) {
    throw QueryCompilationErrors.nestedSequentialStreamingUnionError()
  }
}

private def validateNoStatefulDescendants(su: SequentialStreamingUnion): Unit = {
su.children.foreach { child =>
if (child.exists(UnsupportedOperationChecker.isStatefulOperation)) {
Expand All @@ -78,3 +70,25 @@ object ValidateSequentialStreamingUnion extends Rule[LogicalPlan] {
}
}
}

/**
 * Validates that SequentialStreamingUnion nodes have no nesting after optimization.
 * This runs as a post-optimization check to ensure the optimizer has properly flattened
 * all nested SequentialStreamingUnions (including those wrapped in stateless operations).
 *
 * Runs after CombineUnions has flattened nested unions.
 */
object ValidateSequentialStreamingUnionNesting extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan = {
    plan.foreach {
      case su: SequentialStreamingUnion =>
        // Any child subtree that still carries the pattern means flattening
        // failed; report it instead of letting a nested union reach execution.
        if (su.children.exists(_.containsPattern(SEQUENTIAL_STREAMING_UNION))) {
          throw QueryCompilationErrors.nestedSequentialStreamingUnionError()
        }
      case _ =>
    }
    plan
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,11 @@ abstract class Optimizer(catalogManager: CatalogManager)
RemoveNoopOperators),
// This batch must be executed after the `RewriteSubquery` batch, which creates joins.
Batch("NormalizeFloatingNumbers", Once, NormalizeFloatingNumbers),
Batch("ReplaceUpdateFieldsExpression", Once, ReplaceUpdateFieldsExpression)))
Batch("ReplaceUpdateFieldsExpression", Once, ReplaceUpdateFieldsExpression),
// Validate that nested SequentialStreamingUnions have been properly flattened.
// Must run after CombineUnions in the "Union" batch.
Batch("Validate SequentialStreamingUnion", Once,
ValidateSequentialStreamingUnionNesting)))

// remove any batches with no rules. this may happen when subclasses do not add optional rules.
batches.filter(_.rules.nonEmpty)
Expand Down Expand Up @@ -1018,7 +1022,9 @@ object PushProjectionThroughUnion extends Rule[LogicalPlan] {
project.outputSet.size == project.projectList.size
}

def pushProjectionThroughUnion(projectList: Seq[NamedExpression], u: Union): Seq[LogicalPlan] = {
def pushProjectionThroughUnion(
projectList: Seq[NamedExpression],
u: UnionBase): Seq[LogicalPlan] = {
val newFirstChild = Project(projectList, u.children.head)
val newOtherChildren = u.children.tail.map { child =>
val rewrites = buildRewrites(u.children.head, child)
Expand All @@ -1028,13 +1034,15 @@ object PushProjectionThroughUnion extends Rule[LogicalPlan] {
}

def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
_.containsAllPatterns(UNION, PROJECT)) {
t => t.containsPattern(PROJECT) &&
t.containsAnyPattern(UNION, SEQUENTIAL_STREAMING_UNION)) {

// Push down deterministic projection through UNION ALL
case project @ Project(projectList, u: Union)
// Push down deterministic projection through Union or SequentialStreamingUnion.
// This is safe because it preserves child ordering.
case project @ Project(projectList, SequentialOrSimpleUnion(u))
if projectList.forall(_.deterministic) && u.children.nonEmpty &&
canPushProjectionThroughUnion(project) =>
u.copy(children = pushProjectionThroughUnion(projectList, u))
SequentialOrSimpleUnion.withNewChildren(u, pushProjectionThroughUnion(projectList, u))
}
}

Expand Down Expand Up @@ -1814,20 +1822,30 @@ object CombineUnions extends Rule[LogicalPlan] {
import PushProjectionThroughUnion.{canPushProjectionThroughUnion, pushProjectionThroughUnion}

def apply(plan: LogicalPlan): LogicalPlan = plan.transformDownWithPruning(
_.containsAnyPattern(UNION, DISTINCT_LIKE), ruleId) {
case u: Union => flattenUnion(u, false)
case Distinct(u: Union) => Distinct(flattenUnion(u, true))
_.containsAnyPattern(UNION, SEQUENTIAL_STREAMING_UNION, DISTINCT_LIKE), ruleId) {
// Flatten Union or SequentialStreamingUnion.
// This is safe because flattening preserves child ordering.
case SequentialOrSimpleUnion(u) => flattenUnion(u, false)
case Distinct(SequentialOrSimpleUnion(u)) => Distinct(flattenUnion(u, true))
// Only handle distinct-like 'Deduplicate', where the keys == output
case Deduplicate(keys: Seq[Attribute], u: Union) if AttributeSet(keys) == u.outputSet =>
case Deduplicate(keys: Seq[Attribute], SequentialOrSimpleUnion(u))
if AttributeSet(keys) == u.outputSet =>
Deduplicate(keys, flattenUnion(u, true))
case DeduplicateWithinWatermark(keys: Seq[Attribute], u: Union)
case DeduplicateWithinWatermark(keys: Seq[Attribute], SequentialOrSimpleUnion(u))
if AttributeSet(keys) == u.outputSet =>
DeduplicateWithinWatermark(keys, flattenUnion(u, true))
}

private def flattenUnion(union: Union, flattenDistinct: Boolean): Union = {
val topByName = union.byName
val topAllowMissingCol = union.allowMissingCol
private def flattenUnion(union: UnionBase, flattenDistinct: Boolean): UnionBase = {
val topByName = SequentialOrSimpleUnion.byName(union)
val topAllowMissingCol = SequentialOrSimpleUnion.allowMissingCol(union)

// Helper to check if a union can be merged with the top-level union
def canMerge(u: UnionBase): Boolean = {
SequentialOrSimpleUnion.isSameType(union, u) &&
SequentialOrSimpleUnion.byName(u) == topByName &&
SequentialOrSimpleUnion.allowMissingCol(u) == topAllowMissingCol
}

val stack = mutable.Stack[LogicalPlan](union)
val flattened = mutable.ArrayBuffer.empty[LogicalPlan]
Expand All @@ -1843,38 +1861,35 @@ object CombineUnions extends Rule[LogicalPlan] {
!p2.projectList.exists(SubqueryExpression.hasCorrelatedSubquery) =>
val newProjectList = buildCleanedProjectList(p1.projectList, p2.projectList)
stack.pushAll(Seq(p2.copy(projectList = newProjectList)))
case Distinct(Union(children, byName, allowMissingCol))
if flattenDistinct && byName == topByName && allowMissingCol == topAllowMissingCol =>
stack.pushAll(children.reverse)
case Distinct(SequentialOrSimpleUnion(u)) if flattenDistinct && canMerge(u) =>
stack.pushAll(u.children.reverse)
// Only handle distinct-like 'Deduplicate', where the keys == output
case Deduplicate(keys: Seq[Attribute], u: Union)
if flattenDistinct && u.byName == topByName &&
u.allowMissingCol == topAllowMissingCol && AttributeSet(keys) == u.outputSet =>
case Deduplicate(keys: Seq[Attribute], SequentialOrSimpleUnion(u))
if flattenDistinct && canMerge(u) && AttributeSet(keys) == u.outputSet =>
stack.pushAll(u.children.reverse)
case Union(children, byName, allowMissingCol)
if byName == topByName && allowMissingCol == topAllowMissingCol =>
stack.pushAll(children.reverse)
// Push down projection through Union and then push pushed plan to Stack if
case SequentialOrSimpleUnion(u) if canMerge(u) =>
stack.pushAll(u.children.reverse)
// Push down projection through union and then push pushed plan to Stack if
// there is a Project.
case project @ Project(projectList, Distinct(u @ Union(children, byName, allowMissingCol)))
if projectList.forall(_.deterministic) && children.nonEmpty &&
flattenDistinct && byName == topByName && allowMissingCol == topAllowMissingCol &&
case project @ Project(projectList, Distinct(SequentialOrSimpleUnion(u)))
if projectList.forall(_.deterministic) && u.children.nonEmpty &&
flattenDistinct && canMerge(u) &&
canPushProjectionThroughUnion(project) =>
stack.pushAll(pushProjectionThroughUnion(projectList, u).reverse)
case project @ Project(projectList, Deduplicate(keys: Seq[Attribute], u: Union))
if projectList.forall(_.deterministic) && flattenDistinct && u.byName == topByName &&
u.allowMissingCol == topAllowMissingCol && AttributeSet(keys) == u.outputSet &&
canPushProjectionThroughUnion(project) =>
case project @ Project(
projectList, Deduplicate(keys: Seq[Attribute], SequentialOrSimpleUnion(u)))
if projectList.forall(_.deterministic) && flattenDistinct && canMerge(u) &&
AttributeSet(keys) == u.outputSet && canPushProjectionThroughUnion(project) =>
stack.pushAll(pushProjectionThroughUnion(projectList, u).reverse)
case project @ Project(projectList, u @ Union(children, byName, allowMissingCol))
if projectList.forall(_.deterministic) && children.nonEmpty && byName == topByName &&
allowMissingCol == topAllowMissingCol && canPushProjectionThroughUnion(project) =>
case project @ Project(projectList, SequentialOrSimpleUnion(u))
if projectList.forall(_.deterministic) && u.children.nonEmpty &&
canMerge(u) && canPushProjectionThroughUnion(project) =>
stack.pushAll(pushProjectionThroughUnion(projectList, u).reverse)
case child =>
flattened += child
}
}
union.copy(children = flattened.toSeq)
SequentialOrSimpleUnion.withNewChildren(union, flattened.toSeq)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ import org.apache.spark.sql.catalyst.trees.TreePattern._
* 2. Second child begins processing
* 3. And so on...
*
* IMPORTANT: Child ordering IS semantically significant. Children are processed sequentially
* in the exact order specified. Optimizer rules must preserve this ordering.
*
* Requirements:
* - Minimum 2 children required
* - All children must be streaming sources
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -566,9 +566,62 @@ abstract class UnionBase extends LogicalPlan {
}
}

/**
 * Extractor and helper methods for Union and SequentialStreamingUnion.
 * Does not match other UnionBase subtypes like UnionLoop.
 *
 * The per-type helpers (byName, allowMissingCol, withNewChildren) must only be
 * invoked on unions accepted by `unapply`. Passing any other UnionBase subtype
 * is a programming error and now fails with a descriptive exception instead of
 * an opaque MatchError.
 */
object SequentialOrSimpleUnion {

  /**
   * Extractor that matches Union and SequentialStreamingUnion for optimizer rules.
   */
  def unapply(plan: LogicalPlan): Option[UnionBase] = plan match {
    case u: Union => Some(u)
    case u: SequentialStreamingUnion => Some(u)
    case _ => None
  }

  /**
   * Returns true if both unions are the same concrete type.
   * Used during flattening to ensure Union and SequentialStreamingUnion are not merged.
   */
  def isSameType(u1: UnionBase, u2: UnionBase): Boolean = (u1, u2) match {
    case (_: Union, _: Union) => true
    case (_: SequentialStreamingUnion, _: SequentialStreamingUnion) => true
    case _ => false
  }

  /**
   * Extracts the byName flag from a Union or SequentialStreamingUnion.
   */
  def byName(u: UnionBase): Boolean = u match {
    case union: Union => union.byName
    case ssu: SequentialStreamingUnion => ssu.byName
    case other => throw unsupported(other)
  }

  /**
   * Extracts the allowMissingCol flag from a Union or SequentialStreamingUnion.
   */
  def allowMissingCol(u: UnionBase): Boolean = u match {
    case union: Union => union.allowMissingCol
    case ssu: SequentialStreamingUnion => ssu.allowMissingCol
    case other => throw unsupported(other)
  }

  /**
   * Creates a new union of the same concrete type with the specified children.
   */
  def withNewChildren(u: UnionBase, newChildren: Seq[LogicalPlan]): UnionBase = u match {
    case union: Union => union.copy(children = newChildren)
    case ssu: SequentialStreamingUnion => ssu.copy(children = newChildren)
    case other => throw unsupported(other)
  }

  // The matches above were previously non-exhaustive and failed with an
  // uninformative MatchError for unsupported UnionBase subtypes (e.g. UnionLoop).
  private def unsupported(u: UnionBase): IllegalArgumentException =
    new IllegalArgumentException(
      s"SequentialOrSimpleUnion does not support ${u.getClass.getSimpleName}")
}

/**
* Logical plan for unioning multiple plans, without a distinct. This is UNION ALL in SQL.
*
* NOTE: Child ordering is NOT semantically significant. Children are processed in parallel
* and their order does not affect the result. This allows Union-specific optimizations to
* reorder children (e.g., for performance), unlike SequentialStreamingUnion where order matters.
*
* @param byName Whether to resolve columns in the children by column names.
* @param allowMissingCol Allows missing columns in children query plans. If it is true,
* this function allows a different set of column names between two Datasets.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ class SequentialStreamingUnionAnalysisSuite extends AnalysisTest with DataTypeEr
parameters = Map("operator" -> "SequentialStreamingUnion"))
}

test("ValidateSequentialStreamingUnion - rejects directly nested unions") {
test("ValidateSequentialStreamingUnionNesting - rejects directly nested unions") {
val streamingRelation1 = testRelation1.copy(isStreaming = true)
val streamingRelation2 = testRelation2.copy(isStreaming = true)
val streamingRelation3 = testRelation3.copy(isStreaming = true)
Expand All @@ -146,16 +146,17 @@ class SequentialStreamingUnionAnalysisSuite extends AnalysisTest with DataTypeEr
val innerUnion = SequentialStreamingUnion(streamingRelation1, streamingRelation2)
val outerUnion = SequentialStreamingUnion(innerUnion, streamingRelation3)

// Note: This validation now runs AFTER optimizer flattening
checkError(
exception = intercept[AnalysisException] {
ValidateSequentialStreamingUnion(outerUnion)
ValidateSequentialStreamingUnionNesting(outerUnion)
},
condition = "NESTED_SEQUENTIAL_STREAMING_UNION",
parameters = Map(
"hint" -> "Use chained followedBy calls instead: df1.followedBy(df2).followedBy(df3)"))
}

test("ValidateSequentialStreamingUnion - rejects nested unions through other operators") {
test("ValidateSequentialStreamingUnionNesting - rejects nested unions through other operators") {
val streamingRelation1 = testRelation1.copy(isStreaming = true)
val streamingRelation2 = testRelation2.copy(isStreaming = true)
val streamingRelation3 = testRelation3.copy(isStreaming = true)
Expand All @@ -166,9 +167,10 @@ class SequentialStreamingUnionAnalysisSuite extends AnalysisTest with DataTypeEr
val projectOverUnion = Project(Seq($"a", $"b"), innerUnion)
val outerUnion = SequentialStreamingUnion(projectOverUnion, streamingRelation3)

// Note: This validation now runs AFTER optimizer flattening
checkError(
exception = intercept[AnalysisException] {
ValidateSequentialStreamingUnion(outerUnion)
ValidateSequentialStreamingUnionNesting(outerUnion)
},
condition = "NESTED_SEQUENTIAL_STREAMING_UNION",
parameters = Map(
Expand Down
Loading