Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 135 additions & 0 deletions datafusion/core/tests/fuzz_cases/join_fuzz.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ use datafusion::physical_plan::joins::{
};
use datafusion::prelude::{SessionConfig, SessionContext};
use datafusion_common::{NullEquality, ScalarValue};
use datafusion_execution::TaskContext;
use datafusion_execution::disk_manager::{DiskManagerBuilder, DiskManagerMode};
use datafusion_execution::runtime_env::RuntimeEnvBuilder;
use datafusion_physical_expr::PhysicalExprRef;
use datafusion_physical_expr::expressions::Literal;

Expand Down Expand Up @@ -1125,6 +1128,138 @@ impl JoinFuzzTestCase {
}
}

/// Fuzz test: compare SMJ (with spilling) against HJ (no spill) for filtered
/// outer joins under memory pressure. This exercises the deferred filtering +
/// spill read-back path that unit tests can't easily cover with random data.
#[tokio::test]
async fn test_filtered_join_spill_fuzz() {
    // Outer join types are the interesting cases here: unmatched rows must
    // survive the deferred-filter + spill read-back path intact.
    let join_types = [JoinType::Left, JoinType::Right, JoinType::Full];

    // Tiny 4 KiB memory budget (100% of it available to this query) forces the
    // SMJ to spill; spill files land in the OS temp directory.
    let runtime_spill = RuntimeEnvBuilder::new()
        .with_memory_limit(4096, 1.0)
        .with_disk_manager_builder(
            DiskManagerBuilder::default().with_mode(DiskManagerMode::OsTmpDirectory),
        )
        .build_arc()
        .unwrap();

    for join_type in &join_types {
        // Vary which side carries extra (non-join) columns so both symmetric
        // and asymmetric schemas are covered.
        for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] {
            let input1 = make_staggered_batches_i32(1000, left_extra);
            let input2 = make_staggered_batches_i32(1000, right_extra);

            let schema1 = input1[0].schema();
            let schema2 = input2[0].schema();
            // Non-equi filter (left column < right column) layered on top of
            // the equi-join keys below.
            let filter = col_lt_col_filter(schema1.clone(), schema2.clone());

            // Equi-join keys: ("a", "a") and ("b", "b").
            let on = vec![
                (
                    Arc::new(Column::new_with_schema("a", &schema1).unwrap()) as _,
                    Arc::new(Column::new_with_schema("a", &schema2).unwrap()) as _,
                ),
                (
                    Arc::new(Column::new_with_schema("b", &schema1).unwrap()) as _,
                    Arc::new(Column::new_with_schema("b", &schema2).unwrap()) as _,
                ),
            ];

            // Small batch sizes stress batch-boundary handling in the join.
            for batch_size in [2, 49, 100] {
                let session_config = SessionConfig::new().with_batch_size(batch_size);

                // HJ baseline (no memory limit)
                let left_hj = MemorySourceConfig::try_new_exec(
                    std::slice::from_ref(&input1),
                    schema1.clone(),
                    None,
                )
                .unwrap();
                let right_hj = MemorySourceConfig::try_new_exec(
                    std::slice::from_ref(&input2),
                    schema2.clone(),
                    None,
                )
                .unwrap();
                let hj = Arc::new(
                    HashJoinExec::try_new(
                        left_hj,
                        right_hj,
                        on.clone(),
                        Some(filter.clone()),
                        join_type,
                        None,
                        PartitionMode::Partitioned,
                        NullEquality::NullEqualsNothing,
                        false,
                    )
                    .unwrap(),
                );
                // Default task context: the baseline runs without the
                // constrained runtime, so it never spills.
                let ctx_hj = SessionContext::new_with_config(session_config.clone());
                let hj_collected = collect(hj, ctx_hj.task_ctx()).await.unwrap();

                // SMJ with spilling
                let left_smj = MemorySourceConfig::try_new_exec(
                    std::slice::from_ref(&input1),
                    schema1.clone(),
                    None,
                )
                .unwrap();
                let right_smj = MemorySourceConfig::try_new_exec(
                    std::slice::from_ref(&input2),
                    schema2.clone(),
                    None,
                )
                .unwrap();
                let smj = Arc::new(
                    SortMergeJoinExec::try_new(
                        left_smj,
                        right_smj,
                        on.clone(),
                        Some(filter.clone()),
                        *join_type,
                        // One sort option per join key.
                        vec![SortOptions::default(); on.len()],
                        NullEquality::NullEqualsNothing,
                    )
                    .unwrap(),
                );
                // Run SMJ under the 4 KiB runtime built above so it is forced
                // down the spill path.
                let task_ctx_spill = Arc::new(
                    TaskContext::default()
                        .with_session_config(session_config)
                        .with_runtime(Arc::clone(&runtime_spill)),
                );
                let smj_collected = collect(smj, task_ctx_spill).await.unwrap();

                // Compare row counts first: cheap, and gives a clearer failure
                // message than a full content diff.
                let hj_rows: usize = hj_collected.iter().map(|b| b.num_rows()).sum();
                let smj_rows: usize = smj_collected.iter().map(|b| b.num_rows()).sum();

                assert_eq!(
                    hj_rows, smj_rows,
                    "Row count mismatch for {join_type:?} batch_size={batch_size} \
                    left_extra={left_extra} right_extra={right_extra}: \
                    HJ={hj_rows} SMJ={smj_rows}"
                );

                // Content check: format both outputs and compare the sorted
                // lines, since join output row order is not deterministic
                // across the two operators.
                if hj_rows > 0 {
                    let hj_fmt =
                        pretty_format_batches(&hj_collected).unwrap().to_string();
                    let smj_fmt =
                        pretty_format_batches(&smj_collected).unwrap().to_string();

                    let mut hj_sorted: Vec<&str> = hj_fmt.trim().lines().collect();
                    hj_sorted.sort_unstable();
                    let mut smj_sorted: Vec<&str> = smj_fmt.trim().lines().collect();
                    smj_sorted.sort_unstable();

                    assert_eq!(
                        hj_sorted, smj_sorted,
                        "Content mismatch for {join_type:?} batch_size={batch_size} \
                        left_extra={left_extra} right_extra={right_extra}"
                    );
                }
            }
        }
    }
}

/// Return randomly sized record batches with:
/// two sorted int32 columns 'a', 'b' ranged from 0..99 as join columns
/// two random int32 columns 'x', 'y' as other columns
Expand Down
2 changes: 1 addition & 1 deletion datafusion/physical-plan/src/joins/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ mod cross_join;
mod hash_join;
mod nested_loop_join;
mod piecewise_merge_join;
pub(crate) mod semi_anti_sort_merge_join;
pub(crate) mod semi_anti_mark_sort_merge_join;
mod sort_merge_join;
mod stream_join_utils;
mod symmetric_hash_join;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
// specific language governing permissions and limitations
// under the License.

//! Specialized Sort Merge Join stream for Semi/Anti joins.
//! Specialized Sort Merge Join stream for Semi/Anti/Mark joins.
//!
//! Used internally by `SortMergeJoinExec` for semi/anti join types.
//! Used internally by `SortMergeJoinExec` for semi/anti/mark join types.

pub(crate) mod stream;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,11 @@
// specific language governing permissions and limitations
// under the License.

//! Sort-merge join stream specialized for semi/anti joins.
//! Sort-merge join stream specialized for semi/anti/mark joins.
//!
//! Instantiated by [`SortMergeJoinExec`](crate::joins::sort_merge_join::SortMergeJoinExec)
//! when the join type is `LeftSemi`, `LeftAnti`, `RightSemi`, or `RightAnti`.
//! when the join type is `LeftSemi`, `LeftAnti`, `RightSemi`, `RightAnti`,
//! `LeftMark`, or `RightMark`.
//!
//! # Motivation
//!
Expand All @@ -36,7 +37,8 @@
//!
//! For `Left*` join types, left is outer and right is inner.
//! For `Right*` join types, right is outer and left is inner.
//! The output schema always equals the outer side's schema.
//! The output schema always equals the outer side's schema (for semi/anti)
//! or the outer side's schema plus a boolean mark column (for mark joins).
//!
//! # Algorithm
//!
Expand Down Expand Up @@ -75,6 +77,7 @@
//! On emit:
//! Semi → filter_record_batch(outer_batch, &matched)
//! Anti → filter_record_batch(outer_batch, &NOT(matched))
//! Mark → outer_batch + matched as boolean column
//! ```
//!
//! ## Batch boundaries
Expand Down Expand Up @@ -245,10 +248,8 @@ enum PendingBoundary {
Filtered { saved_keys: Vec<ArrayRef> },
}

pub(crate) struct SemiAntiSortMergeJoinStream {
// Decomposed from JoinType to avoid matching on 4 variants in hot paths.
// true for semi (emit matched), false for anti (emit unmatched).
is_semi: bool,
pub(crate) struct SemiAntiMarkSortMergeJoinStream {
join_type: JoinType,

// Input streams — in the nested-loop model that sort-merge join
// implements, "outer" is the driving loop and "inner" is probed for
Expand Down Expand Up @@ -330,7 +331,7 @@ pub(crate) struct SemiAntiSortMergeJoinStream {
batch_emitted: bool,
}

impl SemiAntiSortMergeJoinStream {
impl SemiAntiMarkSortMergeJoinStream {
#[expect(clippy::too_many_arguments)]
pub fn try_new(
schema: SchemaRef,
Expand All @@ -350,8 +351,22 @@ impl SemiAntiSortMergeJoinStream {
spill_manager: SpillManager,
runtime_env: Arc<datafusion_execution::runtime_env::RuntimeEnv>,
) -> Result<Self> {
let is_semi = matches!(join_type, JoinType::LeftSemi | JoinType::RightSemi);
let outer_is_left = matches!(join_type, JoinType::LeftSemi | JoinType::LeftAnti);
debug_assert!(
matches!(
join_type,
JoinType::LeftSemi
| JoinType::RightSemi
| JoinType::LeftAnti
| JoinType::RightAnti
| JoinType::LeftMark
| JoinType::RightMark
),
"SemiAntiMarkSortMergeJoinStream does not handle {join_type:?}"
);
let outer_is_left = matches!(
join_type,
JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark
);

let join_time = MetricBuilder::new(metrics).subset_time("join_time", partition);
let input_batches =
Expand All @@ -360,7 +375,7 @@ impl SemiAntiSortMergeJoinStream {
let baseline_metrics = BaselineMetrics::new(metrics, partition);

Ok(Self {
is_semi,
join_type,
outer,
inner,
outer_batch: None,
Expand Down Expand Up @@ -492,17 +507,37 @@ impl SemiAntiSortMergeJoinStream {

// finish() converts the bit-packed builder directly to a
// BooleanBuffer — no iteration or repacking needed.
let selection = BooleanArray::new(self.matched.finish(), None);

let selection = if self.is_semi {
selection
} else {
not(&selection)?
};

let filtered = filter_record_batch(batch, &selection)?;
if filtered.num_rows() > 0 {
self.coalescer.push_batch(filtered)?;
let matched_buf = self.matched.finish();

match self.join_type {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@comphead this one is for you. You suggested this in #20806 and this time I finally listened :)

JoinType::LeftMark | JoinType::RightMark => {
// Mark joins emit ALL outer rows with a boolean match column appended.
debug_assert_eq!(
self.schema.fields().len(),
batch.num_columns() + 1,
"Mark join output schema should be outer schema + 1 mark column"
);
let mark_col = Arc::new(BooleanArray::new(matched_buf, None)) as ArrayRef;
let mut columns = batch.columns().to_vec();
columns.push(mark_col);
let output = RecordBatch::try_new(Arc::clone(&self.schema), columns)?;
self.coalescer.push_batch(output)?;
}
JoinType::LeftSemi | JoinType::RightSemi => {
let selection = BooleanArray::new(matched_buf, None);
let filtered = filter_record_batch(batch, &selection)?;
if filtered.num_rows() > 0 {
self.coalescer.push_batch(filtered)?;
}
}
JoinType::LeftAnti | JoinType::RightAnti => {
let selection = not(&BooleanArray::new(matched_buf, None))?;
let filtered = filter_record_batch(batch, &selection)?;
if filtered.num_rows() > 0 {
self.coalescer.push_batch(filtered)?;
}
}
_ => unreachable!(),
}
Ok(())
}
Expand Down Expand Up @@ -1184,7 +1219,7 @@ fn keys_match(

/// Evaluate the join filter for one inner row against a slice of outer rows.
///
/// Free function (not a method on SemiAntiSortMergeJoinStream) so that Rust
/// Free function (not a method on SemiAntiMarkSortMergeJoinStream) so that Rust
/// can split the struct borrow in process_key_match_with_filter: the caller
/// holds &mut self.matched and &self.inner_key_buffer simultaneously, which
/// is impossible if this borrows all of &self.
Expand Down Expand Up @@ -1257,7 +1292,7 @@ fn evaluate_filter_for_inner_row(
}
}

impl Stream for SemiAntiSortMergeJoinStream {
impl Stream for SemiAntiMarkSortMergeJoinStream {
type Item = Result<RecordBatch>;

fn poll_next(
Expand All @@ -1269,7 +1304,7 @@ impl Stream for SemiAntiSortMergeJoinStream {
}
}

impl RecordBatchStream for SemiAntiSortMergeJoinStream {
impl RecordBatchStream for SemiAntiMarkSortMergeJoinStream {
fn schema(&self) -> SchemaRef {
Arc::clone(&self.schema)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};

use super::stream::SemiAntiSortMergeJoinStream;
use super::stream::SemiAntiMarkSortMergeJoinStream;
use crate::ExecutionPlan;
use crate::RecordBatchStream;
use crate::common;
Expand Down Expand Up @@ -149,8 +149,10 @@ impl RecordBatchStream for PendingStream {
}
}

/// Helper: collect all output from a SemiAntiSortMergeJoinStream.
async fn collect_stream(stream: SemiAntiSortMergeJoinStream) -> Result<Vec<RecordBatch>> {
/// Helper: collect all output from a SemiAntiMarkSortMergeJoinStream.
async fn collect_stream(
stream: SemiAntiMarkSortMergeJoinStream,
) -> Result<Vec<RecordBatch>> {
common::collect(Box::pin(stream)).await
}

Expand Down Expand Up @@ -259,7 +261,7 @@ async fn filter_buffer_pending_loses_inner_rows() -> Result<()> {
let inner_schema = inner.schema();
let (reservation, peak_mem_used, spill_manager, runtime_env) =
test_stream_resources(inner_schema, &metrics);
let stream = SemiAntiSortMergeJoinStream::try_new(
let stream = SemiAntiMarkSortMergeJoinStream::try_new(
left_schema, // output schema = outer schema for semi
vec![SortOptions::default()],
NullEquality::NullEqualsNothing,
Expand Down Expand Up @@ -359,7 +361,7 @@ async fn no_filter_boundary_pending_loses_outer_rows() -> Result<()> {
let inner_schema = inner.schema();
let (reservation, peak_mem_used, spill_manager, runtime_env) =
test_stream_resources(inner_schema, &metrics);
let stream = SemiAntiSortMergeJoinStream::try_new(
let stream = SemiAntiMarkSortMergeJoinStream::try_new(
left_schema,
vec![SortOptions::default()],
NullEquality::NullEqualsNothing,
Expand Down Expand Up @@ -473,7 +475,7 @@ async fn filtered_boundary_pending_outer_rows() -> Result<()> {
let inner_schema = inner.schema();
let (reservation, peak_mem_used, spill_manager, runtime_env) =
test_stream_resources(inner_schema, &metrics);
let stream = SemiAntiSortMergeJoinStream::try_new(
let stream = SemiAntiMarkSortMergeJoinStream::try_new(
left_schema,
vec![SortOptions::default()],
NullEquality::NullEqualsNothing,
Expand Down Expand Up @@ -756,7 +758,7 @@ async fn spill_filtered_boundary_loses_outer_rows() -> Result<()> {
Arc::clone(&right_schema),
);

let stream = SemiAntiSortMergeJoinStream::try_new(
let stream = SemiAntiMarkSortMergeJoinStream::try_new(
Arc::clone(&left_schema),
vec![SortOptions::default()],
NullEquality::NullEqualsNothing,
Expand Down
Loading
Loading