@@ -33,6 +33,7 @@ import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.SQLContext
 import org.apache.spark.sql.auron.AuronConverters.ForceNativeExecutionWrapperBase
 import org.apache.spark.sql.auron.NativeConverters.NativeExprWrapperBase
+import org.apache.spark.sql.auron.join.JoinBuildSides.{JoinBuildLeft, JoinBuildRight, JoinBuildSide}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.catalog.CatalogTable
 import org.apache.spark.sql.catalyst.expressions.Attribute
@@ -96,6 +97,7 @@ import org.apache.spark.sql.execution.auron.plan.NativeWindowExec
 import org.apache.spark.sql.execution.auron.shuffle.{AuronBlockStoreShuffleReaderBase, AuronRssShuffleManagerBase, RssPartitionWriterBase}
 import org.apache.spark.sql.execution.datasources.PartitionedFile
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeLike, ReusedExchangeExec}
+import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, ShuffledHashJoinExec}
 import org.apache.spark.sql.execution.joins.auron.plan.NativeBroadcastJoinExec
 import org.apache.spark.sql.execution.joins.auron.plan.NativeShuffledHashJoinExecProvider
 import org.apache.spark.sql.execution.joins.auron.plan.NativeSortMergeJoinExecProvider
@@ -229,7 +231,7 @@ class ShimsImpl extends Shims with Logging {
       leftKeys: Seq[Expression],
       rightKeys: Seq[Expression],
       joinType: JoinType,
-      broadcastSide: BroadcastSide,
+      broadcastSide: JoinBuildSide,
       isNullAwareAntiJoin: Boolean): NativeBroadcastJoinBase =
     NativeBroadcastJoinExec(
       left,
@@ -262,7 +264,7 @@ class ShimsImpl extends Shims with Logging {
       leftKeys: Seq[Expression],
       rightKeys: Seq[Expression],
       joinType: JoinType,
-      buildSide: BuildSide,
+      buildSide: JoinBuildSide,
       isSkewJoin: Boolean): SparkPlan =
     NativeShuffledHashJoinExecProvider.provide(
       left,
@@ -1037,6 +1039,42 @@ class ShimsImpl extends Shims with Logging {
   override def getAdaptiveInputPlan(exec: AdaptiveSparkPlanExec): SparkPlan = {
     exec.inputPlan
   }
+
+  private def convertJoinBuildSide(
+      exec: SparkPlan,
+      isBuildLeft: Any => Boolean): JoinBuildSide = {
+    exec match {
+      case shj: ShuffledHashJoinExec =>
+        if (isBuildLeft(shj.buildSide)) JoinBuildLeft else JoinBuildRight
+      case bhj: BroadcastHashJoinExec =>
+        if (isBuildLeft(bhj.buildSide)) JoinBuildLeft else JoinBuildRight
+      case bnlj: BroadcastNestedLoopJoinExec =>
+        if (isBuildLeft(bnlj.buildSide)) JoinBuildLeft else JoinBuildRight
+      case other => throw new IllegalArgumentException(s"Unsupported SparkPlan type: $other")
+    }
+  }
+
+  @sparkver("3.0")
+  override def getJoinBuildSide(exec: SparkPlan): JoinBuildSide = {
+    import org.apache.spark.sql.execution.joins.BuildLeft
+    convertJoinBuildSide(
+      exec,
+      isBuildLeft = {
+        case BuildLeft => true
+        case _ => false
+      })
+  }
+
+  @sparkver("3.1 / 3.2 / 3.3 / 3.4 / 3.5")
+  override def getJoinBuildSide(exec: SparkPlan): JoinBuildSide = {
+    import org.apache.spark.sql.catalyst.optimizer.BuildLeft
+    convertJoinBuildSide(
+      exec,
+      isBuildLeft = {
+        case BuildLeft => true
+        case _ => false
+      })
+  }
 }
 
 case class ForceNativeExecutionWrapper(override val child: SparkPlan)
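Note on the hunk above: it introduces the version-neutral JoinBuildSide and a shared convertJoinBuildSide helper. Two @sparkver overloads of getJoinBuildSide are needed because, as the two imports show, Spark 3.0 keeps BuildLeft in org.apache.spark.sql.execution.joins while 3.1+ has it in org.apache.spark.sql.catalyst.optimizer; the isBuildLeft: Any => Boolean parameter keeps the helper agnostic to both. The JoinBuildSides definition itself is not part of this diff; a minimal sketch of what it presumably looks like, given how it is imported and matched on:

```scala
package org.apache.spark.sql.auron.join

// Hypothetical sketch only -- the real definition is not shown in this PR.
// A sealed ADT makes every match on JoinBuildSide exhaustiveness-checked.
object JoinBuildSides {
  sealed trait JoinBuildSide
  case object JoinBuildLeft extends JoinBuildSide
  case object JoinBuildRight extends JoinBuildSide
}
```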
@@ -16,13 +16,11 @@
  */
 package org.apache.spark.sql.execution.joins.auron.plan
 
+import org.apache.spark.sql.auron.join.JoinBuildSides.{JoinBuildLeft, JoinBuildRight, JoinBuildSide}
 import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.catalyst.plans.JoinType
 import org.apache.spark.sql.catalyst.plans.physical.Partitioning
 import org.apache.spark.sql.execution.SparkPlan
-import org.apache.spark.sql.execution.auron.plan.BroadcastLeft
-import org.apache.spark.sql.execution.auron.plan.BroadcastRight
-import org.apache.spark.sql.execution.auron.plan.BroadcastSide
 import org.apache.spark.sql.execution.auron.plan.NativeBroadcastJoinBase
 import org.apache.spark.sql.execution.joins.HashJoin
 
@@ -35,7 +33,7 @@ case class NativeBroadcastJoinExec(
     override val leftKeys: Seq[Expression],
     override val rightKeys: Seq[Expression],
     override val joinType: JoinType,
-    broadcastSide: BroadcastSide,
+    broadcastSide: JoinBuildSide,
     isNullAwareAntiJoin: Boolean)
     extends NativeBroadcastJoinBase(
       left,
@@ -53,14 +51,14 @@ case class NativeBroadcastJoinExec(
   @sparkver("3.1 / 3.2 / 3.3 / 3.4 / 3.5")
   override def buildSide: org.apache.spark.sql.catalyst.optimizer.BuildSide =
     broadcastSide match {
-      case BroadcastLeft => org.apache.spark.sql.catalyst.optimizer.BuildLeft
-      case BroadcastRight => org.apache.spark.sql.catalyst.optimizer.BuildRight
+      case JoinBuildLeft => org.apache.spark.sql.catalyst.optimizer.BuildLeft
+      case JoinBuildRight => org.apache.spark.sql.catalyst.optimizer.BuildRight
     }
 
   @sparkver("3.0")
   override val buildSide: org.apache.spark.sql.execution.joins.BuildSide = broadcastSide match {
-    case BroadcastLeft => org.apache.spark.sql.execution.joins.BuildLeft
-    case BroadcastRight => org.apache.spark.sql.execution.joins.BuildRight
+    case JoinBuildLeft => org.apache.spark.sql.execution.joins.BuildLeft
+    case JoinBuildRight => org.apache.spark.sql.execution.joins.BuildRight
   }
 
   @sparkver("3.1 / 3.2 / 3.3 / 3.4 / 3.5")
@@ -71,9 +69,9 @@ case class NativeBroadcastJoinExec(
 
     def mode = HashedRelationBroadcastMode(buildBoundKeys, isNullAware = false)
     broadcastSide match {
-      case BroadcastLeft =>
+      case JoinBuildLeft =>
         BroadcastDistribution(mode) :: UnspecifiedDistribution :: Nil
-      case BroadcastRight =>
+      case JoinBuildRight =>
         UnspecifiedDistribution :: BroadcastDistribution(mode) :: Nil
     }
   }
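The exec now carries the engine-neutral JoinBuildSide in its constructor and converts back to Spark's native BuildSide only inside the version-gated overrides above. That mapping, restated as a standalone function using the Spark 3.1+ types and assuming the sealed ADT sketched earlier:

```scala
import org.apache.spark.sql.auron.join.JoinBuildSides.{JoinBuildLeft, JoinBuildRight, JoinBuildSide}
import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide}

// Total by construction: a sealed JoinBuildSide has exactly these two cases,
// so the compiler flags any missing branch.
def toSparkBuildSide(side: JoinBuildSide): BuildSide = side match {
  case JoinBuildLeft => BuildLeft
  case JoinBuildRight => BuildRight
}
```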
@@ -16,10 +16,10 @@
  */
 package org.apache.spark.sql.execution.joins.auron.plan
 
+import org.apache.spark.sql.auron.join.JoinBuildSides.JoinBuildSide
 import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.catalyst.plans.JoinType
 import org.apache.spark.sql.execution.SparkPlan
-import org.apache.spark.sql.execution.auron.plan.BuildSide
 import org.apache.spark.sql.execution.auron.plan.NativeShuffledHashJoinBase
 import org.apache.spark.sql.execution.joins.HashJoin
 
@@ -34,7 +34,7 @@ case object NativeShuffledHashJoinExecProvider {
       leftKeys: Seq[Expression],
       rightKeys: Seq[Expression],
       joinType: JoinType,
-      buildSide: BuildSide,
+      buildSide: JoinBuildSide,
       isSkewJoin: Boolean): NativeShuffledHashJoinBase = {
 
     import org.apache.spark.rdd.RDD
@@ -47,7 +47,7 @@ case object NativeShuffledHashJoinExecProvider {
         override val leftKeys: Seq[Expression],
         override val rightKeys: Seq[Expression],
         override val joinType: JoinType,
-        buildSide: BuildSide,
+        buildSide: JoinBuildSide,
         skewJoin: Boolean)
         extends NativeShuffledHashJoinBase(left, right, leftKeys, rightKeys, joinType, buildSide)
         with org.apache.spark.sql.execution.joins.ShuffledJoin {
@@ -87,12 +87,11 @@ case object NativeShuffledHashJoinExecProvider {
       leftKeys: Seq[Expression],
       rightKeys: Seq[Expression],
       joinType: JoinType,
-      buildSide: BuildSide,
+      buildSide: JoinBuildSide,
       isSkewJoin: Boolean): NativeShuffledHashJoinBase = {
 
+    import org.apache.spark.sql.auron.join.JoinBuildSides.{JoinBuildLeft, JoinBuildRight}
     import org.apache.spark.sql.catalyst.expressions.SortOrder
-    import org.apache.spark.sql.execution.auron.plan.BuildLeft
-    import org.apache.spark.sql.execution.auron.plan.BuildRight
     import org.apache.spark.sql.execution.joins.ShuffledHashJoinExec
 
     case class NativeShuffledHashJoinExec(
@@ -101,7 +100,7 @@ case object NativeShuffledHashJoinExecProvider {
         leftKeys: Seq[Expression],
         rightKeys: Seq[Expression],
         joinType: JoinType,
-        buildSide: BuildSide)
+        buildSide: JoinBuildSide)
         extends NativeShuffledHashJoinBase(left, right, leftKeys, rightKeys, joinType, buildSide)
         with org.apache.spark.sql.execution.joins.ShuffledJoin {
 
@@ -112,8 +111,8 @@ case object NativeShuffledHashJoinExecProvider {
 
       override def outputOrdering: Seq[SortOrder] = {
         val sparkBuildSide = buildSide match {
-          case BuildLeft => org.apache.spark.sql.catalyst.optimizer.BuildLeft
-          case BuildRight => org.apache.spark.sql.catalyst.optimizer.BuildRight
+          case JoinBuildLeft => org.apache.spark.sql.catalyst.optimizer.BuildLeft
+          case JoinBuildRight => org.apache.spark.sql.catalyst.optimizer.BuildRight
         }
         val shj =
           ShuffledHashJoinExec(leftKeys, rightKeys, joinType, sparkBuildSide, None, left, right)
@@ -135,12 +134,11 @@ case object NativeShuffledHashJoinExecProvider {
       leftKeys: Seq[Expression],
       rightKeys: Seq[Expression],
       joinType: JoinType,
-      buildSide: BuildSide,
+      buildSide: JoinBuildSide,
       isSkewJoin: Boolean): NativeShuffledHashJoinBase = {
 
+    import org.apache.spark.sql.auron.join.JoinBuildSides.{JoinBuildLeft, JoinBuildRight}
     import org.apache.spark.sql.catalyst.expressions.Attribute
-    import org.apache.spark.sql.execution.auron.plan.BuildLeft
-    import org.apache.spark.sql.execution.auron.plan.BuildRight
     import org.apache.spark.sql.execution.joins.ShuffledHashJoinExec
 
     case class NativeShuffledHashJoinExec(
@@ -149,7 +147,7 @@ case object NativeShuffledHashJoinExecProvider {
         leftKeys: Seq[Expression],
         rightKeys: Seq[Expression],
         joinType: JoinType,
-        buildSide: BuildSide)
+        buildSide: JoinBuildSide)
         extends NativeShuffledHashJoinBase(
           left,
           right,
@@ -160,8 +158,8 @@ case object NativeShuffledHashJoinExecProvider {
 
       private def shj: ShuffledHashJoinExec = {
         val sparkBuildSide = buildSide match {
-          case BuildLeft => org.apache.spark.sql.execution.joins.BuildLeft
-          case BuildRight => org.apache.spark.sql.execution.joins.BuildRight
+          case JoinBuildLeft => org.apache.spark.sql.execution.joins.BuildLeft
+          case JoinBuildRight => org.apache.spark.sql.execution.joins.BuildRight
         }
         ShuffledHashJoinExec(leftKeys, rightKeys, joinType, sparkBuildSide, None, left, right)
       }
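A design note on the hunks above: the version-specific NativeShuffledHashJoinExec classes do not re-derive planner metadata. They construct the equivalent vanilla ShuffledHashJoinExec with the same arguments and delegate questions such as outputOrdering to it, so the native operator inherits Spark's semantics for that version. A self-contained analogue of the delegation pattern (all names here are illustrative, none are from the PR):

```scala
// Illustrative only: reuse a reference implementation purely for its derived
// metadata instead of duplicating the derivation logic.
trait JoinLike { def outputOrdering: Seq[String] }

final case class VanillaJoin(keys: Seq[String]) extends JoinLike {
  // Stand-in for Spark's real ordering derivation.
  def outputOrdering: Seq[String] = keys.sorted
}

final case class NativeJoin(keys: Seq[String]) {
  // Construct the vanilla operator on the fly and delegate to it.
  def outputOrdering: Seq[String] = VanillaJoin(keys).outputOrdering
}
```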
@@ -18,13 +18,13 @@ package org.apache.spark.sql.auron
 
 import org.apache.commons.lang3.reflect.MethodUtils
 import org.apache.spark.internal.Logging
+import org.apache.spark.sql.auron.join.JoinBuildSides.JoinBuildSide
 import org.apache.spark.sql.catalyst.trees.TreeNodeTag
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
 import org.apache.spark.sql.execution.aggregate.HashAggregateExec
 import org.apache.spark.sql.execution.aggregate.ObjectHashAggregateExec
 import org.apache.spark.sql.execution.aggregate.SortAggregateExec
-import org.apache.spark.sql.execution.auron.plan.BuildSide
 import org.apache.spark.sql.execution.command.DataWritingCommandExec
 import org.apache.spark.sql.execution.exchange.BroadcastExchangeExec
 import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
@@ -44,7 +44,7 @@ object AuronConvertStrategy extends Logging {
   val neverConvertReasonTag: TreeNodeTag[String] = TreeNodeTag("auron.never.convert.reason")
   val childOrderingRequiredTag: TreeNodeTag[Boolean] = TreeNodeTag(
     "auron.child.ordering.required")
-  val joinSmallerSideTag: TreeNodeTag[BuildSide] = TreeNodeTag("auron.join.smallerSide")
+  val joinSmallerSideTag: TreeNodeTag[JoinBuildSide] = TreeNodeTag("auron.join.smallerSide")
 
   def apply(exec: SparkPlan): Unit = {
     exec.foreach(_.setTagValue(convertibleTag, true))
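With this hunk, joinSmallerSideTag carries the version-neutral JoinBuildSide instead of the removed internal BuildSide. Because TreeNodeTag is typed, every reader and writer of the tag agrees on the payload type at compile time. A minimal usage sketch: the two helper functions are hypothetical, but setTagValue/getTagValue are Spark's own TreeNode API.

```scala
import org.apache.spark.sql.auron.AuronConvertStrategy
import org.apache.spark.sql.auron.join.JoinBuildSides.JoinBuildSide
import org.apache.spark.sql.execution.SparkPlan

// Hypothetical helpers around the tag declared above.
def tagSmallerSide(plan: SparkPlan, side: JoinBuildSide): Unit =
  plan.setTagValue(AuronConvertStrategy.joinSmallerSideTag, side)

def smallerSide(plan: SparkPlan): Option[JoinBuildSide] =
  plan.getTagValue(AuronConvertStrategy.joinSmallerSideTag)
```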