1 change: 1 addition & 0 deletions native-engine/auron-serde/proto/auron.proto
@@ -468,6 +468,7 @@ message BroadcastJoinExecNode {
JoinType join_type = 5;
JoinSide broadcast_side = 6;
string cached_build_hash_map_id = 7;
bool is_null_aware_anti_join = 8;
}

message RenameColumnsExecNode {
3 changes: 3 additions & 0 deletions native-engine/auron-serde/src/from_proto.rs
@@ -219,6 +219,7 @@ impl TryInto<Arc<dyn ExecutionPlan>> for &protobuf::PhysicalPlanNode {
.map_err(|_| proto_error("invalid BuildSide"))?,
false,
None,
false,
)?))
}
PhysicalPlanType::SortMergeJoin(sort_merge_join) => {
@@ -354,6 +355,7 @@ impl TryInto<Arc<dyn ExecutionPlan>> for &protobuf::PhysicalPlanNode {
.expect("invalid BroadcastSide");

let cached_build_hash_map_id = broadcast_join.cached_build_hash_map_id.clone();
let is_null_aware_anti_join = broadcast_join.is_null_aware_anti_join;

Ok(Arc::new(BroadcastJoinExec::try_new(
schema,
@@ -368,6 +370,7 @@
.map_err(|_| proto_error("invalid BroadcastSide"))?,
true,
Some(cached_build_hash_map_id),
is_null_aware_anti_join,
)?))
}
PhysicalPlanType::Union(union) => {
5 changes: 5 additions & 0 deletions native-engine/datafusion-ext-plans/src/broadcast_join_exec.rs
@@ -88,6 +88,7 @@ pub struct BroadcastJoinExec {
schema: SchemaRef,
is_built: bool, // true for BroadcastHashJoin, false for ShuffledHashJoin
cached_build_hash_map_id: Option<String>,
is_null_aware_anti_join: bool,
metrics: ExecutionPlanMetricsSet,
props: OnceCell<PlanProperties>,
}
@@ -102,6 +103,7 @@ impl BroadcastJoinExec {
broadcast_side: JoinSide,
is_built: bool,
cached_build_hash_map_id: Option<String>,
is_null_aware_anti_join: bool,
) -> Result<Self> {
Ok(Self {
left,
@@ -112,6 +114,7 @@
schema,
is_built,
cached_build_hash_map_id,
is_null_aware_anti_join,
metrics: ExecutionPlanMetricsSet::new(),
props: OnceCell::new(),
})
@@ -176,6 +179,7 @@ impl BroadcastJoinExec {
sort_options: vec![SortOptions::default(); self.on.len()],
projection,
key_data_types,
is_null_aware_anti_join: self.is_null_aware_anti_join,
})
}

@@ -279,6 +283,7 @@ impl ExecutionPlan for BroadcastJoinExec {
self.broadcast_side,
self.is_built,
None,
self.is_null_aware_anti_join,
)?))
}

@@ -193,7 +193,11 @@ impl<const P: JoinerParams> Joiner for SemiJoiner<P> {
.as_ref()
.map(|nb| nb.is_valid(row_idx))
.unwrap_or(true);
if P.mode == Anti && P.probe_is_join_side && !key_is_valid {
if P.mode == Anti
&& P.probe_is_join_side
&& !key_is_valid
&& self.join_params.is_null_aware_anti_join
{
probed_joined.set(row_idx, true);
continue;
}
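This hunk is the core of the fix: when the probe side is the anti join's output side, SemiJoiner used to mark any probe row with a NULL key as joined, so every anti join silently dropped NULL keys; the extra condition applies that rule only when the plan actually requests null-aware semantics. Below is a minimal sketch of the intended rule, using a hypothetical standalone helper rather than the engine's Joiner machinery (it also omits the second null-aware rule, under which any NULL key on the build side empties the result):

/// Hypothetical helper: should an anti join emit this probe row?
fn anti_join_emits_row(
    key: Option<i64>,              // None models a NULL join key
    build_keys: &[i64],            // non-NULL keys on the build side
    is_null_aware_anti_join: bool, // set by Spark's NOT IN rewrite
) -> bool {
    match key {
        // NULL matches nothing: a standard anti join keeps the row, while a
        // null-aware anti join drops it (NULL NOT IN (...) is never true).
        None => !is_null_aware_anti_join,
        // A non-NULL key is emitted only if it has no build-side match.
        Some(k) => !build_keys.contains(&k),
    }
}

fn main() {
    let build = [2i64, 3, 4];
    assert!(anti_join_emits_row(None, &build, false)); // standard: NULL kept
    assert!(!anti_join_emits_row(None, &build, true)); // null-aware: NULL dropped
    assert!(anti_join_emits_row(Some(1), &build, false)); // unmatched key kept
    assert!(anti_join_emits_row(Some(1), &build, true));
}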
1 change: 1 addition & 0 deletions native-engine/datafusion-ext-plans/src/joins/mod.rs
@@ -46,6 +46,7 @@ pub struct JoinParams {
pub sort_options: Vec<SortOptions>,
pub projection: JoinProjection,
pub batch_size: usize,
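// Set from Spark's BroadcastHashJoinExec.isNullAwareAntiJoin; when true,
// probe rows with NULL keys are treated as matched (NOT IN semantics).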
pub is_null_aware_anti_join: bool,
}

#[derive(Debug, Clone)]
20 changes: 5 additions & 15 deletions native-engine/datafusion-ext-plans/src/joins/test.rs
@@ -219,6 +219,7 @@ mod tests {
JoinSide::Right,
true,
None,
false,
)?)
}
BHJRightProbed => {
@@ -235,6 +236,7 @@
JoinSide::Left,
true,
None,
false,
)?)
}
SHJLeftProbed => Arc::new(BroadcastJoinExec::try_new(
@@ -246,6 +248,7 @@
JoinSide::Right,
false,
None,
false,
)?),
SHJRightProbed => Arc::new(BroadcastJoinExec::try_new(
schema,
@@ -256,6 +259,7 @@
JoinSide::Left,
false,
None,
false,
)?),
};
let columns = columns(&join.schema());
@@ -617,21 +621,7 @@
Arc::new(Column::new_with_schema("b1", &right.schema())?),
)];

for test_type in [BHJLeftProbed, SHJLeftProbed] {
let (_, batches) =
join_collect(test_type, left.clone(), right.clone(), on.clone(), LeftAnti).await?;
let expected = vec![
"+----+----+----+",
"| a1 | b1 | c1 |",
"+----+----+----+",
"| | 6 | 9 |",
"| 5 | 8 | 11 |",
"+----+----+----+",
];
assert_batches_sorted_eq!(expected, &batches);
}

for test_type in [SMJ, BHJRightProbed, SHJRightProbed] {
for test_type in ALL_TEST_TYPE {
let (_, batches) =
join_collect(test_type, left.clone(), right.clone(), on.clone(), LeftAnti).await?;
let expected = vec![
@@ -128,6 +128,7 @@ impl SortMergeJoinExec {
sort_options: self.sort_options.clone(),
projection,
batch_size: batch_size(),
is_null_aware_anti_join: false,
})
}

@@ -229,15 +229,17 @@ class ShimsImpl extends Shims with Logging {
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
joinType: JoinType,
broadcastSide: BroadcastSide): NativeBroadcastJoinBase =
broadcastSide: BroadcastSide,
isNullAwareAntiJoin: Boolean): NativeBroadcastJoinBase =
NativeBroadcastJoinExec(
left,
right,
outputPartitioning,
leftKeys,
rightKeys,
joinType,
broadcastSide)
broadcastSide,
isNullAwareAntiJoin)

override def createNativeSortMergeJoinExec(
left: SparkPlan,
@@ -35,15 +35,17 @@ case class NativeBroadcastJoinExec(
override val leftKeys: Seq[Expression],
override val rightKeys: Seq[Expression],
override val joinType: JoinType,
broadcastSide: BroadcastSide)
broadcastSide: BroadcastSide,
isNullAwareAntiJoin: Boolean)
extends NativeBroadcastJoinBase(
left,
right,
outputPartitioning,
leftKeys,
rightKeys,
joinType,
broadcastSide)
broadcastSide,
isNullAwareAntiJoin)
with HashJoin {

override val condition: Option[Expression] = None
@@ -581,4 +581,85 @@ class AuronQuerySuite extends AuronQueryTest with BaseAuronSQLSuite with AuronSQ
}
}
}

test("standard LEFT ANTI JOIN includes NULL keys") {
// This test verifies that a standard LEFT ANTI JOIN includes rows with NULL
// keys: NULL never matches anything, so those rows belong in the result.
withTable("left_table", "right_table") {
sql("""
|CREATE TABLE left_table using parquet AS
|SELECT * FROM VALUES
| (1, 2.0),
| (1, 2.0),
| (2, 1.0),
| (2, 1.0),
| (3, 3.0),
| (null, null),
| (null, 5.0),
| (6, null)
|AS t(a, b)
|""".stripMargin)

sql("""
|CREATE TABLE right_table using parquet AS
|SELECT * FROM VALUES
| (2, 3.0),
| (2, 3.0),
| (3, 2.0),
| (4, 1.0),
| (null, null),
| (null, 5.0),
| (6, null)
|AS t(c, d)
|""".stripMargin)

// Standard LEFT ANTI JOIN should include rows with NULL keys
// Expected: (1, 2.0), (1, 2.0), (null, null), (null, 5.0)
checkSparkAnswer(
"SELECT * FROM left_table LEFT ANTI JOIN right_table ON left_table.a = right_table.c")
}
}

test("left join with NOT IN subquery should filter NULL values") {
Contributor Author comment (the failing result this test guards against):

  == Results ==
  !== Correct Answer - 1 ==   == Spark Answer - 1 ==
   struct<cnt:bigint>         struct<cnt:bigint>
  ![9]                        [99999] (QueryTest.scala:243)

// This test verifies the fix for the NULL-handling bug in the anti join.
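// Spark 3.1+ can rewrite a single-column NOT IN (subquery) into a broadcast
// hash join with isNullAwareAntiJoin = true; in that mode probe rows whose
// key is NULL (here, t2 rows with no t1 match) must be dropped, so the
// correct count is 9 rather than 99999.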
withSQLConf("spark.sql.autoBroadcastJoinThreshold" -> "-1") {
val query =
"""
|WITH t2 AS (
| -- Large table: 100000 rows (0..99999)
| SELECT id AS loan_req_no
| FROM range(0, 100000)
|),
|t1 AS (
| -- Small table: 10 rows that can match t2
| SELECT * FROM VALUES
| (1, 'A'),
| (2, 'B'),
| (3, 'C'),
| (4, 'D'),
| (5, 'E'),
| (6, 'F'),
| (7, 'G'),
| (8, 'H'),
| (9, 'I'),
| (10,'J')
| AS t1(loan_req_no, partner_code)
|),
|blk AS (
| SELECT * FROM VALUES
| ('B'),
| ('Z')
| AS blk(code)
|)
|SELECT
| COUNT(*) AS cnt
|FROM t2
|LEFT JOIN t1
| ON t1.loan_req_no = t2.loan_req_no
|WHERE t1.partner_code NOT IN (SELECT code FROM blk)
|""".stripMargin

checkSparkAnswer(query)
}
}
}
@@ -664,16 +664,23 @@ object AuronConverters extends Logging {
}
}

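// BroadcastHashJoinExec.isNullAwareAntiJoin only exists on Spark 3.1+,
// so the Spark 3.0 shim always reports false.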
@sparkver("3.1 / 3.2 / 3.3 / 3.4 / 3.5")
def isNullAwareAntiJoin(exec: BroadcastHashJoinExec): Boolean = exec.isNullAwareAntiJoin

@sparkver("3.0")
def isNullAwareAntiJoin(exec: BroadcastHashJoinExec): Boolean = false

def convertBroadcastHashJoinExec(exec: BroadcastHashJoinExec): SparkPlan = {
try {
val (leftKeys, rightKeys, joinType, buildSide, condition, left, right) = (
val (leftKeys, rightKeys, joinType, buildSide, condition, left, right, naaj) = (
exec.leftKeys,
exec.rightKeys,
exec.joinType,
exec.buildSide,
exec.condition,
exec.left,
exec.right)
exec.right,
isNullAwareAntiJoin(exec))
logDebugPlanConversion(
exec,
Seq(
@@ -702,7 +709,8 @@
buildSide match {
case BuildLeft => BroadcastLeft
case BuildRight => BroadcastRight
})
},
naaj)

} catch {
case e @ (_: NotImplementedError | _: Exception) =>
@@ -744,7 +752,8 @@
buildSide match {
case BuildLeft => BroadcastLeft
case BuildRight => BroadcastRight
})
},
isNullAwareAntiJoin = false)

} catch {
case e @ (_: NotImplementedError | _: Exception) =>
@@ -86,7 +86,8 @@ abstract class Shims {
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
joinType: JoinType,
broadcastSide: BroadcastSide): NativeBroadcastJoinBase
broadcastSide: BroadcastSide,
isNullAwareAntiJoin: Boolean): NativeBroadcastJoinBase

def createNativeSortMergeJoinExec(
left: SparkPlan,
@@ -52,7 +52,8 @@ abstract class NativeBroadcastJoinBase(
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
joinType: JoinType,
broadcastSide: BroadcastSide)
broadcastSide: BroadcastSide,
isNullAwareAntiJoin: Boolean)
extends BinaryExecNode
with NativeSupports {

@@ -174,6 +175,7 @@
.setJoinType(nativeJoinType)
.setBroadcastSide(nativeBroadcastSide)
.setCachedBuildHashMapId(cachedBuildHashMapId)
.setIsNullAwareAntiJoin(isNullAwareAntiJoin)
.addAllOn(nativeJoinOn.asJava)

pb.PhysicalPlanNode.newBuilder().setBroadcastJoin(broadcastJoinExec).build()