Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
// LocalRelation and does not trigger many rules.
Batch("LocalRelation early", fixedPoint,
ConvertToLocalRelation,
FoldInnerJoinWithOneRowRelation,
PropagateEmptyRelation,
// PropagateEmptyRelation can change the nullability of an attribute from nullable to
// non-nullable when an empty relation child of a Union is removed
Expand Down Expand Up @@ -261,6 +262,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
ReassignLambdaVariableID),
Batch("LocalRelation", fixedPoint,
ConvertToLocalRelation,
FoldInnerJoinWithOneRowRelation,
PropagateEmptyRelation,
// PropagateEmptyRelation can change the nullability of an attribute from nullable to
// non-nullable when an empty relation child of a Union is removed
Expand Down Expand Up @@ -2659,11 +2661,90 @@ object ConvertToLocalRelation extends Rule[LogicalPlan] {
val predicate = Predicate.create(condition, output)
predicate.initialize(0)
LocalRelation(output, data.filter(row => predicate.eval(row)), isStreaming, stream)

// SPARK-57039: Inner join with a single-row LocalRelation (e.g. INNER JOIN VALUES (...)) can be
// folded into a Project that materializes the row's columns as constants on the other side,
// optionally followed by a Filter for the join condition. This eliminates an otherwise
// unavoidable BroadcastNestedLoopJoin/BroadcastHashJoin for a 1-row build side.
// Fills the TODO at DecorrelateInnerQuery.scala (search: "more general rule to optimize
// join with OneRowRelation").
case Join(LocalRelation(lOut, lData, false, _), right, Inner, condition, JoinHint.NONE)
if lData.length == 1 && !condition.exists(hasUnevaluableExpr) &&
!isSingleRowLeaf(right) =>
foldSingleRowJoin(lOut, lData.head, leftIsSingleRow = true, right, condition)

case Join(left, LocalRelation(rOut, rData, false, _), Inner, condition, JoinHint.NONE)
if rData.length == 1 && !condition.exists(hasUnevaluableExpr) &&
!isSingleRowLeaf(left) =>
foldSingleRowJoin(rOut, rData.head, leftIsSingleRow = false, left, condition)
}

/**
 * Returns true when any node of `expr` (including `expr` itself) is [[Unevaluable]].
 * Attribute references are explicitly excluded from the check, so a condition that only
 * refers to columns is not reported.
 */
def hasUnevaluableExpr(expr: Expression): Boolean = expr.exists {
  case _: AttributeReference => false
  case _: Unevaluable => true
  case _ => false
}

// SPARK-57039: guard for the "both sides are single-row" case. If an Inner join between two
// statically-single-row inputs were folded, the cartesian product would no longer be visible
// to CheckCartesianProducts and the spark.sql.crossJoin.enabled=false guardrail would be
// bypassed silently, so such joins are left in place for the cartesian check.
//
// Despite the name, the test is based on maxRows rather than on matching a leaf node. The
// optimizer is iterative: at the moment the rule fires, one side may already be a bare
// LocalRelation while the other is still wrapped, e.g. Project(LocalRelation(1 row)).
// maxRows propagates through Project/Filter/Limit/etc., so the answer does not depend on
// which side happens to be reduced first. The SPARK-33100 CliSuite scenario
// (t1(1 row) JOIN t2(1 row)) exercised exactly this ordering problem.
def isSingleRowLeaf(plan: LogicalPlan): Boolean = plan.maxRows.exists(_ == 1L)

/**
 * Replaces an Inner join between `other` and a single-row `LocalRelation` with a
 * `Project` over `other`, topped by a `Filter` when the join carried a condition.
 *
 * Every column of the single-row side becomes an `Alias` over a `Literal` built with
 * `Literal.create` from the corresponding field of `row`. Each alias reuses the name and
 * `exprId` of the attribute it replaces, so references above the join still resolve.
 *
 * @param singleRowOutput output attributes of the single-row relation
 * @param row the only row of that relation
 * @param leftIsSingleRow true if the single-row relation was the left join child; decides
 *                        whether the literal columns come before or after `other.output`,
 *                        which keeps the join's `left.output ++ right.output` ordering
 * @param other the join child that is kept
 * @param condition optional join condition, re-applied as a `Filter` over the projection
 */
private def foldSingleRowJoin(
    singleRowOutput: Seq[Attribute],
    row: org.apache.spark.sql.catalyst.InternalRow,
    leftIsSingleRow: Boolean,
    other: LogicalPlan,
    condition: Option[Expression]): LogicalPlan = {
  val constantColumns = for ((column, ordinal) <- singleRowOutput.zipWithIndex) yield {
    val value = row.get(ordinal, column.dataType)
    Alias(Literal.create(value, column.dataType), column.name)(column.exprId)
  }
  val projection =
    if (leftIsSingleRow) {
      Project(constantColumns ++ other.output, other)
    } else {
      Project(other.output ++ constantColumns, other)
    }
  condition match {
    case Some(joinCondition) => Filter(joinCondition, projection)
    case None => projection
  }
}
}

/**
 * SPARK-57039: Rewrites `other INNER JOIN OneRowRelation [ON condition]` (on either side) as
 * `other`, or as `Filter(condition, other)` when the join has a condition.
 *
 * `OneRowRelation` is a leaf with a single row and no columns (e.g. what is left of a
 * `SELECT <consts>` once the projected literals have been folded away). Joining against it
 * with an Inner join does not change the rows of the other side, so the join node can be
 * dropped; a join condition, if any, is kept as a Filter.
 *
 * The rule does not fire when:
 *  - the join is not `Inner` or carries a join hint,
 *  - the condition contains an unevaluable expression
 *    (see [[ConvertToLocalRelation.hasUnevaluableExpr]]), or
 *  - the kept side is itself single-row (see [[ConvertToLocalRelation.isSingleRowLeaf]]).
 *
 * This is the `OneRowRelation` counterpart of the single-row `LocalRelation` handling in
 * [[ConvertToLocalRelation]]. It is a separate rule because the tree-pattern pruning differs:
 * `OneRowRelation` does not register a `LOCAL_RELATION` pattern.
 */
object FoldInnerJoinWithOneRowRelation extends Rule[LogicalPlan] {
  override def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
    _.containsPattern(INNER_LIKE_JOIN), ruleId) {
    case Join(kept, _: OneRowRelation, Inner, joinCondition, JoinHint.NONE)
        if canFold(kept, joinCondition) =>
      withCondition(kept, joinCondition)

    case Join(_: OneRowRelation, kept, Inner, joinCondition, JoinHint.NONE)
        if canFold(kept, joinCondition) =>
      withCondition(kept, joinCondition)
  }

  // True when the condition is free of unevaluable expressions and the kept side is not
  // single-row itself.
  private def canFold(kept: LogicalPlan, joinCondition: Option[Expression]): Boolean = {
    joinCondition.forall(c => !ConvertToLocalRelation.hasUnevaluableExpr(c)) &&
      !ConvertToLocalRelation.isSingleRowLeaf(kept)
  }

  // Returns `kept` unchanged, or wrapped in a Filter when a join condition exists.
  private def withCondition(kept: LogicalPlan, joinCondition: Option[Expression]): LogicalPlan =
    joinCondition.fold(kept)(Filter(_, kept))
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ object RuleIdCollection {
"org.apache.spark.sql.catalyst.optimizer.ConstantFolding" ::
"org.apache.spark.sql.catalyst.optimizer.ConstantPropagation" ::
"org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation" ::
"org.apache.spark.sql.catalyst.optimizer.FoldInnerJoinWithOneRowRelation" ::
"org.apache.spark.sql.catalyst.optimizer.CostBasedJoinReorder" ::
"org.apache.spark.sql.catalyst.optimizer.DecimalAggregates" ::
"org.apache.spark.sql.catalyst.optimizer.EliminateAggregateFilter" ::
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,15 @@

package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{Expression, GenericInternalRow, LessThan, Literal, UnaryExpression}
import org.apache.spark.sql.catalyst.expressions.{Alias, Expression, GenericInternalRow, LessThan, Literal, UnaryExpression}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.plans.{Inner, LeftOuter, PlanTest}
import org.apache.spark.sql.catalyst.plans.logical.{Join, LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.types.{DataType, StructType}

Expand Down Expand Up @@ -87,6 +88,119 @@ class ConvertToLocalRelationSuite extends PlanTest {

comparePlans(optimized, correctAnswer)
}

// ---- SPARK-57039: Inner Join + single-row LocalRelation -> Project [Filter] ----

// Relation declared without any rows; used as the non-constant join side.
private val tbl = LocalRelation($"a".int, $"b".int)

// Relation holding exactly one row (1, true) for columns (c1, c2).
private val singleRow = LocalRelation(
  LocalRelation($"c1".int, $"c2".boolean).output,
  Seq(InternalRow(1, true)))

// Same column shape as `singleRow`, but with two rows.
private val multiRow = LocalRelation(
  LocalRelation($"c1".int, $"c2".boolean).output,
  Seq(InternalRow(1, true), InternalRow(2, false)))

test("SPARK-57039: InnerJoin with single-row LocalRelation on right -> Project") {
  // Expected shape: tbl's columns followed by the single row's values as literal aliases
  // that reuse the exprIds of the single-row relation's attributes.
  val c1 = singleRow.output.head
  val c2 = singleRow.output.last
  val c1Const = Alias(Literal.create(1, c1.dataType), c1.name)(c1.exprId)
  val c2Const = Alias(Literal.create(true, c2.dataType), c2.name)(c2.exprId)
  val expected = Optimize.execute(tbl.select($"a", $"b", c1Const, c2Const).analyze)

  val actual = Optimize.execute(tbl.join(singleRow, Inner, None).analyze)
  comparePlans(actual, expected)
}

test("SPARK-57039: InnerJoin with single-row LocalRelation on left -> Project") {
  // Expected shape: the single row's values as literal aliases first (left side of the
  // join), then tbl's columns.
  val c1 = singleRow.output.head
  val c2 = singleRow.output.last
  val c1Const = Alias(Literal.create(1, c1.dataType), c1.name)(c1.exprId)
  val c2Const = Alias(Literal.create(true, c2.dataType), c2.name)(c2.exprId)
  val expected = Optimize.execute(tbl.select(c1Const, c2Const, $"a", $"b").analyze)

  val actual = Optimize.execute(singleRow.join(tbl, Inner, None).analyze)
  comparePlans(actual, expected)
}

test("SPARK-57039: InnerJoin with single-row LocalRelation + condition -> Project + Filter") {
  // With a join condition, the projection is expected to be wrapped in a Filter that
  // carries the same condition.
  val c1 = singleRow.output.head
  val c2 = singleRow.output.last
  val c1Const = Alias(Literal.create(1, c1.dataType), c1.name)(c1.exprId)
  val c2Const = Alias(Literal.create(true, c2.dataType), c2.name)(c2.exprId)
  val expected = Optimize.execute(
    tbl.select($"a", $"b", c1Const, c2Const)
      .where($"a" > UnresolvedAttribute("c1"))
      .analyze)

  val joined = tbl.join(singleRow, Inner, Some($"a" > UnresolvedAttribute("c1"))).analyze
  comparePlans(Optimize.execute(joined), expected)
}

test("SPARK-57039: do NOT fold non-Inner join") {
  // The fold only matches Inner joins, so a LeftOuter join must come back unchanged.
  val analyzed = tbl.join(singleRow, LeftOuter, None).analyze
  comparePlans(Optimize.execute(analyzed), analyzed)
}

test("SPARK-57039: do NOT fold multi-row LocalRelation") {
  // A relation with two rows does not qualify for the single-row fold.
  val analyzed = tbl.join(multiRow, Inner, None).analyze
  comparePlans(Optimize.execute(analyzed), analyzed)
}

test("SPARK-57039: preserve exprId across single-row fold") {
  // Both attributes of the single-row side must still be present, by exprId, in the output
  // of the optimized plan.
  val analyzed = tbl.join(singleRow, Inner, None).analyze
  val survivingIds = Optimize.execute(analyzed).output.map(_.exprId).toSet
  singleRow.output.take(2).foreach { attr =>
    assert(survivingIds.contains(attr.exprId))
  }
}

test("SPARK-57039: fold actually removes Join and preserves singleRow exprIds in literals") {
  val plan = tbl.join(singleRow, Inner, None).analyze
  val optimized = Optimize.execute(plan)
  // Strong assertion #1: fold actually happened - no Join survives.
  // The failure message includes the optimized plan so a regression can be diagnosed from
  // the test log alone.
  assert(optimized.collectFirst { case _: Join => () }.isEmpty,
    s"Expected Join to be folded away, got:\n$optimized")
  // Strong assertion #2: output schema width = tbl + singleRow columns (4 cols total).
  // The failure message includes the actual output attributes.
  assert(optimized.output.length == 4,
    s"Expected 4-column output after fold, got: ${optimized.output}")
  // Strong assertion #3: singleRow exprIds preserved through fold.
  val outIds = optimized.output.map(_.exprId).toSet
  assert(outIds.contains(singleRow.output(0).exprId),
    "singleRow col 0 exprId should survive fold via Alias-preserved exprId")
  assert(outIds.contains(singleRow.output(1).exprId),
    "singleRow col 1 exprId should survive fold via Alias-preserved exprId")
}

test("SPARK-57039: do NOT fold when both sides are single-row LocalRelation (cartesian)") {
  val leftRel = LocalRelation.fromExternalRows(
    Seq(Symbol("a").int, Symbol("b").int), Seq(Row(1, 2)))
  val rightRel = LocalRelation.fromExternalRows(
    Seq(Symbol("c").int, Symbol("d").int), Seq(Row(3, 4)))
  val optimized = Optimize.execute(leftRel.join(rightRel, Inner, None).analyze)
  // The Join has to stay in the plan so that CheckCartesianProducts can still see it.
  val remainingJoins = optimized.collect { case j: Join => j }
  assert(remainingJoins.nonEmpty,
    "Join should be preserved when both sides are single-row")
}

test("SPARK-57039: do NOT fold when other side is Project over single-row LocalRelation") {
  // Regression scenario from SPARK-33100 CliSuite:
  // "CREATE TEMPORARY VIEW t AS SELECT * FROM VALUES(...)" parses to
  // Project(LocalRelation(1 row)). While optimizing, one join side may already be a bare
  // LocalRelation while the other is still wrapped in a Project. A guard that only
  // recognizes single-row leaves would miss that and fold the join, hiding a 1x1 cartesian.
  val leftRel = LocalRelation.fromExternalRows(
    Seq(Symbol("a").int, Symbol("b").int), Seq(Row(1, 2)))
  val rightRel = LocalRelation.fromExternalRows(
    Seq(Symbol("c").int, Symbol("d").int), Seq(Row(3, 4)))
  // The Project wrapper keeps the right side from matching as a bare LocalRelation.
  val wrappedRight = rightRel.select(Symbol("c"), Symbol("d"))
  val optimized = Optimize.execute(leftRel.join(wrappedRight, Inner, None).analyze)
  val remainingJoins = optimized.collect { case j: Join => j }
  assert(remainingJoins.nonEmpty,
    "Join must be preserved when other side is Project(single-row LR)")
}
}


Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.{Inner, LeftOuter, PlanTest}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, OneRowRelation}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

/**
 * Tests for [[FoldInnerJoinWithOneRowRelation]], run in isolation in a single fixed-point
 * batch.
 */
class FoldInnerJoinWithOneRowRelationSuite extends PlanTest {

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Seq(
      Batch("FoldInnerJoinWithOneRowRelation", FixedPoint(10),
        FoldInnerJoinWithOneRowRelation))
  }

  private val tbl = LocalRelation($"a".int, $"b".int)

  /** Analyzes `query`, optimizes it, and compares the result with the analyzed `expected`. */
  private def assertOptimizesTo(query: LogicalPlan, expected: LogicalPlan): Unit = {
    val optimized = Optimize.execute(query.analyze)
    comparePlans(optimized, expected.analyze)
  }

  /** Analyzes `query` and checks that optimizing it leaves the analyzed plan untouched. */
  private def assertNotFolded(query: LogicalPlan): Unit = {
    val analyzed = query.analyze
    comparePlans(Optimize.execute(analyzed), analyzed)
  }

  test("SPARK-57039: fold Inner join with OneRowRelation on the right (no condition)") {
    assertOptimizesTo(tbl.join(OneRowRelation(), Inner, None), tbl)
  }

  test("SPARK-57039: fold Inner join with OneRowRelation on the left (no condition)") {
    assertOptimizesTo(OneRowRelation().join(tbl, Inner, None), tbl)
  }

  test("SPARK-57039: fold Inner join with OneRowRelation and a condition into a Filter") {
    assertOptimizesTo(
      tbl.join(OneRowRelation(), Inner, Some($"a" > 5)),
      tbl.where($"a" > 5))
  }

  test("SPARK-57039: do NOT fold non-Inner join with OneRowRelation") {
    assertNotFolded(tbl.join(OneRowRelation(), LeftOuter, None))
  }

  test("SPARK-57039: fold join with ArrayType column on the kept side") {
    val arrTbl = LocalRelation(Symbol("a").array(IntegerType), Symbol("b").int)
    assertOptimizesTo(arrTbl.join(OneRowRelation(), Inner, None), arrTbl)
  }

  test("SPARK-57039: fold join with MapType column on the kept side") {
    val mapTbl = LocalRelation(Symbol("m").map(StringType, IntegerType), Symbol("b").int)
    assertOptimizesTo(mapTbl.join(OneRowRelation(), Inner, None), mapTbl)
  }

  test("SPARK-57039: fold join with StructType column on the kept side") {
    import org.apache.spark.sql.catalyst.expressions.AttributeReference
    val structType =
      StructType(StructField("x", IntegerType) :: StructField("y", StringType) :: Nil)
    val structTbl = LocalRelation(AttributeReference("s", structType)(), Symbol("b").int)
    assertOptimizesTo(structTbl.join(OneRowRelation(), Inner, None), structTbl)
  }

  test("SPARK-57039: nested fold - Join(Join(tbl, OneRow), OneRow) collapses to tbl") {
    val nested = tbl
      .join(OneRowRelation(), Inner, None)
      .join(OneRowRelation(), Inner, None)
    assertOptimizesTo(nested, tbl)
  }

  test("SPARK-57039: do NOT fold when condition contains an unevaluable expression") {
    import org.apache.spark.sql.catalyst.expressions.ScalarSubquery
    // The scalar subquery inside the join condition is what must block the fold here.
    val subquery = ScalarSubquery(tbl.select(Symbol("a")).analyze)
    assertNotFolded(tbl.join(OneRowRelation(), Inner, Some(Symbol("a") === subquery)))
  }

  test("SPARK-57039: do NOT fold when both sides are OneRowRelation (cartesian product)") {
    assertNotFolded(OneRowRelation().join(OneRowRelation(), Inner, None))
  }

  test("SPARK-57039: do NOT fold when OneRowRelation joins a single-row LocalRelation") {
    val singleRow = LocalRelation.fromExternalRows(
      Seq(Symbol("x").int, Symbol("y").string), Seq(Row(1, "a")))
    assertNotFolded(singleRow.join(OneRowRelation(), Inner, None))
  }
}