Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
357 changes: 354 additions & 3 deletions datafusion/physical-expr/src/expressions/in_list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,11 +99,18 @@ impl StaticFilter for ArrayStaticFilter {
));
}

// Unwrap dictionary-encoded needles when the value type matches
// in_array, evaluating against the dictionary values and mapping
// back via keys.
downcast_dictionary_array! {
v => {
let values_contains = self.contains(v.values().as_ref(), negated)?;
let result = take(&values_contains, v.keys(), None)?;
return Ok(downcast_array(result.as_ref()))
// Only unwrap when the haystack (in_array) type matches
// the dictionary value type
if v.values().data_type() == self.in_array.data_type() {
let values_contains = self.contains(v.values().as_ref(), negated)?;
let result = take(&values_contains, v.keys(), None)?;
return Ok(downcast_array(result.as_ref()));
}
}
_ => {}
}
Expand Down Expand Up @@ -3724,4 +3731,348 @@ mod tests {
assert_eq!(result, &BooleanArray::from(vec![true, false, false]));
Ok(())
}
/// Checks `a IN (b, c)` when the first list column already equals the
/// needle on every row.
///
/// The test is intended to cover the short-circuit path: once `b` has
/// matched every row, the remaining item `c` cannot change the answer, so
/// skipping it must still yield `true` for all rows.
#[test]
fn test_in_list_with_columns_short_circuit() -> Result<()> {
    // Three non-nullable Int32 columns: the needle `a` and list items `b`, `c`.
    let fields: Vec<Field> = ["a", "b", "c"]
        .iter()
        .map(|name| Field::new(*name, DataType::Int32, false))
        .collect();
    let schema = Schema::new(fields);

    // `b` mirrors `a` row-for-row; `c` never matches any row of `a`.
    let needle_values = Int32Array::from(vec![1, 2, 3]);
    let matching_values = Int32Array::from(vec![1, 2, 3]);
    let other_values = Int32Array::from(vec![99, 99, 99]);
    let batch = RecordBatch::try_new(
        Arc::new(schema.clone()),
        vec![
            Arc::new(needle_values),
            Arc::new(matching_values),
            Arc::new(other_values),
        ],
    )?;

    let expr = make_in_list_with_columns(
        col("a", &schema)?,
        vec![col("b", &schema)?, col("c", &schema)?],
        false,
    );

    let evaluated = expr.evaluate(&batch)?.into_array(batch.num_rows())?;
    let expected = BooleanArray::from(vec![true, true, true]);
    assert_eq!(as_boolean_array(&evaluated), &expected);
    Ok(())
}

/// Checks three-valued logic for `a IN (b, c)` when the needle column
/// contains a NULL.
///
/// Even though `b` matches every non-null row of `a`, the NULL row must stay
/// NULL in the result; any short-circuit must not turn it into `true`.
#[test]
fn test_in_list_with_columns_short_circuit_with_nulls() -> Result<()> {
    // Only the needle column `a` is declared nullable.
    let schema = Schema::new(
        [("a", true), ("b", false), ("c", false)]
            .iter()
            .map(|(name, nullable)| Field::new(*name, DataType::Int32, *nullable))
            .collect::<Vec<_>>(),
    );

    let needle = Int32Array::from(vec![Some(1), None, Some(3)]);
    // Equal to `a` on each of its non-null rows.
    let first_item = Int32Array::from(vec![1, 2, 3]);
    let second_item = Int32Array::from(vec![99, 99, 99]);
    let batch = RecordBatch::try_new(
        Arc::new(schema.clone()),
        vec![
            Arc::new(needle),
            Arc::new(first_item),
            Arc::new(second_item),
        ],
    )?;

    let expr = make_in_list_with_columns(
        col("a", &schema)?,
        vec![col("b", &schema)?, col("c", &schema)?],
        false,
    );
    let evaluated = expr.evaluate(&batch)?.into_array(batch.num_rows())?;

    // row 0: 1    IN (1, 99) -> true
    // row 1: NULL IN (2, 99) -> NULL
    // row 2: 3    IN (3, 99) -> true
    let expected = BooleanArray::from(vec![Some(true), None, Some(true)]);
    assert_eq!(as_boolean_array(&evaluated), &expected);
    Ok(())
}

/// Covers `IN` and `NOT IN` over struct-typed column references.
///
/// Per the original author's note, nested types are expected to take the
/// `make_comparator` + `collect_bool` fallback because `arrow_eq` does not
/// support them.
#[test]
fn test_in_list_with_columns_struct() -> Result<()> {
    let struct_fields = Fields::from(vec![
        Field::new("x", DataType::Int32, false),
        Field::new("y", DataType::Utf8, false),
    ]);
    let struct_dt = DataType::Struct(struct_fields.clone());

    // Only the needle column `a` is nullable.
    let schema = Schema::new(vec![
        Field::new("a", struct_dt.clone(), true),
        Field::new("b", struct_dt.clone(), false),
        Field::new("c", struct_dt, false),
    ]);

    // Builds one `{x, y}` struct column from its child values and an
    // optional per-row validity mask.
    let struct_column = |xs: Vec<i32>, ys: Vec<&str>, validity: Option<Vec<bool>>| {
        Arc::new(StructArray::new(
            struct_fields.clone(),
            vec![
                Arc::new(Int32Array::from(xs)),
                Arc::new(StringArray::from(ys)),
            ],
            validity.map(Into::into),
        ))
    };

    // a: [{1,"a"}, {2,"b"}, NULL,    {4,"d"}]  (row 2 is masked out)
    // b: [{1,"a"}, {9,"z"}, {3,"c"}, {4,"d"}]
    // c: [{9,"z"}, {2,"b"}, {9,"z"}, {9,"z"}]
    let a = struct_column(
        vec![1, 2, 3, 4],
        vec!["a", "b", "c", "d"],
        Some(vec![true, true, false, true]),
    );
    let b = struct_column(vec![1, 9, 3, 4], vec!["a", "z", "c", "d"], None);
    let c = struct_column(vec![9, 2, 9, 9], vec!["z", "b", "z", "z"], None);

    let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![a, b, c])?;

    // Rows 0 and 3 match `b`, row 1 matches `c`, and row 2 has a NULL needle,
    // so IN yields [true, true, NULL, true] and NOT IN yields the negation
    // with the NULL preserved.
    for (negated, expected) in [
        (false, vec![Some(true), Some(true), None, Some(true)]),
        (true, vec![Some(false), Some(false), None, Some(false)]),
    ] {
        let expr = make_in_list_with_columns(
            col("a", &schema)?,
            vec![col("b", &schema)?, col("c", &schema)?],
            negated,
        );
        let evaluated = expr.evaluate(&batch)?.into_array(batch.num_rows())?;
        assert_eq!(
            as_boolean_array(&evaluated),
            &BooleanArray::from(expected)
        );
    }
    Ok(())
}

// -----------------------------------------------------------------------
// Tests for try_new_from_array: evaluates `needle IN in_array`.
//
// This exercises the code path used by HashJoin dynamic filter pushdown,
// where in_array is built directly from the join's build-side arrays.
// Unlike try_new (used by SQL IN expressions), which always produces a
// non-Dictionary in_array because evaluate_list() flattens Dictionary
// scalars, try_new_from_array passes the array directly and can produce
// a Dictionary in_array.
// -----------------------------------------------------------------------

fn wrap_in_dict(array: ArrayRef) -> ArrayRef {
let keys = Int32Array::from((0..array.len() as i32).collect::<Vec<_>>());
Arc::new(DictionaryArray::new(keys, array))
}

/// Evaluates `needle IN in_array` (non-negated) through
/// `InListExpr::try_new_from_array` and returns the boolean result.
///
/// The needle becomes the single column `a` of a one-column batch. Per the
/// surrounding notes this is the path HashJoin dynamic filter pushdown uses,
/// as opposed to the SQL literal `IN` path that goes through `try_new`.
fn eval_in_list_from_array(
    needle: ArrayRef,
    in_array: ArrayRef,
) -> Result<BooleanArray> {
    let needle_field = Field::new("a", needle.data_type().clone(), false);
    let schema = Schema::new(vec![needle_field]);

    let in_list: Arc<dyn PhysicalExpr> = Arc::new(InListExpr::try_new_from_array(
        col("a", &schema)?,
        in_array,
        false,
    )?);

    let batch = RecordBatch::try_new(Arc::new(schema), vec![needle])?;
    let row_count = batch.num_rows();
    let evaluated = in_list.evaluate(&batch)?.into_array(row_count)?;
    Ok(as_boolean_array(&evaluated).clone())
}

/// Runs `needle IN in_array` via `try_new_from_array` across primitive,
/// Utf8, dictionary-encoded and struct inputs, expecting the same
/// `[true, false, true]` answer from every combination.
#[test]
fn test_in_list_from_array_type_combinations() -> Result<()> {
    use arrow::compute::cast;

    // In every case rows 0 and 2 of the needle are present in in_array and
    // row 1 is not.
    let expected = BooleanArray::from(vec![Some(true), Some(false), Some(true)]);

    // Int64 prototypes that are cast to each primitive type under test.
    let base_in: ArrayRef = Arc::new(Int64Array::from(vec![1i64, 2, 3]));
    let base_needle: ArrayRef = Arc::new(Int64Array::from(vec![1i64, 4, 2]));

    // Primitive types, each checked with a plain and a dictionary needle.
    for dt in [
        DataType::Int8,
        DataType::Int16,
        DataType::Int32,
        DataType::Int64,
        DataType::UInt8,
        DataType::UInt16,
        DataType::UInt32,
        DataType::UInt64,
        DataType::Float32,
        DataType::Float64,
    ] {
        let in_array = cast(&base_in, &dt)?;
        let needle = cast(&base_needle, &dt)?;

        // T in_array, T needle
        let same_type =
            eval_in_list_from_array(Arc::clone(&needle), Arc::clone(&in_array))?;
        assert_eq!(expected, same_type, "same-type failed for {dt:?}");

        // T in_array, Dict(Int32, T) needle
        let dict_needle = eval_in_list_from_array(wrap_in_dict(needle), in_array)?;
        assert_eq!(expected, dict_needle, "dict-needle failed for {dt:?}");
    }

    // Utf8 inputs (noted by the original author as falling through to
    // ArrayStaticFilter).
    let utf8_in: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c"]));
    let utf8_needle: ArrayRef = Arc::new(StringArray::from(vec!["a", "d", "b"]));

    // Utf8 in_array, Utf8 needle
    let plain_utf8 =
        eval_in_list_from_array(Arc::clone(&utf8_needle), Arc::clone(&utf8_in))?;
    assert_eq!(expected, plain_utf8);

    // Utf8 in_array, Dict(Utf8) needle
    let dict_needle_utf8 = eval_in_list_from_array(
        wrap_in_dict(Arc::clone(&utf8_needle)),
        Arc::clone(&utf8_in),
    )?;
    assert_eq!(expected, dict_needle_utf8);

    // Dict(Utf8) in_array, Dict(Utf8) needle: the #20937 bug
    let dict_both_utf8 = eval_in_list_from_array(
        wrap_in_dict(Arc::clone(&utf8_needle)),
        wrap_in_dict(Arc::clone(&utf8_in)),
    )?;
    assert_eq!(expected, dict_both_utf8);

    // Pairs the two child arrays with the given fields, in order, and
    // assembles them into a struct array.
    let build_struct = |fields: &Fields, c0: ArrayRef, c1: ArrayRef| -> ArrayRef {
        let pairs: Vec<(FieldRef, ArrayRef)> =
            fields.iter().cloned().zip([c0, c1]).collect();
        Arc::new(StructArray::from(pairs))
    };

    // Struct in_array, Struct needle: multi-column join
    let struct_fields = Fields::from(vec![
        Field::new("c0", DataType::Utf8, true),
        Field::new("c1", DataType::Int64, true),
    ]);
    let struct_needle = build_struct(
        &struct_fields,
        Arc::clone(&utf8_needle),
        Arc::new(Int64Array::from(vec![1, 4, 2])),
    );
    let struct_in = build_struct(
        &struct_fields,
        Arc::clone(&utf8_in),
        Arc::new(Int64Array::from(vec![1, 2, 3])),
    );
    assert_eq!(expected, eval_in_list_from_array(struct_needle, struct_in)?);

    // Struct with Dict fields: multi-column Dict join
    let dict_struct_fields = Fields::from(vec![
        Field::new(
            "c0",
            DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
            true,
        ),
        Field::new("c1", DataType::Int64, true),
    ]);
    let dict_struct_needle = build_struct(
        &dict_struct_fields,
        wrap_in_dict(Arc::clone(&utf8_needle)),
        Arc::new(Int64Array::from(vec![1, 4, 2])),
    );
    let dict_struct_in = build_struct(
        &dict_struct_fields,
        wrap_in_dict(Arc::clone(&utf8_in)),
        Arc::new(Int64Array::from(vec![1, 2, 3])),
    );
    assert_eq!(
        expected,
        eval_in_list_from_array(dict_struct_needle, dict_struct_in)?
    );

    Ok(())
}

/// Checks that incompatible needle / in_array type pairings surface an error
/// with the expected message instead of producing a result.
#[test]
fn test_in_list_from_array_type_mismatch_errors() -> Result<()> {
    // Runs the evaluation, requires it to fail, and returns the rendered
    // error message.
    let error_of = |needle: ArrayRef, in_array: ArrayRef| -> String {
        eval_in_list_from_array(needle, in_array)
            .unwrap_err()
            .to_string()
    };

    // Utf8 needle, Dict(Utf8) in_array
    let plain_needle: ArrayRef = Arc::new(StringArray::from(vec!["a", "d", "b"]));
    let dict_in = wrap_in_dict(Arc::new(StringArray::from(vec!["a", "b", "c"])));
    let err = error_of(plain_needle, dict_in);
    assert!(
        err.contains("Can't compare arrays of different types"),
        "{err}"
    );

    // Dict(Utf8) needle, Int64 in_array: the Utf8 dictionary values are
    // expected to be rejected with a downcast failure (presumably by the
    // specialized Int64 filter; verify against the filter implementation).
    let dict_utf8_needle = wrap_in_dict(Arc::new(StringArray::from(vec!["a", "d", "b"])));
    let int_in: ArrayRef = Arc::new(Int64Array::from(vec![1, 2, 3]));
    let err = error_of(dict_utf8_needle, int_in);
    assert!(err.contains("Failed to downcast"), "{err}");

    // Dict(Int64) needle, Dict(Utf8) in_array: both are dictionaries but the
    // value types differ, so the comparison must be rejected.
    let dict_int_needle = wrap_in_dict(Arc::new(Int64Array::from(vec![1, 4, 2])));
    let dict_utf8_in = wrap_in_dict(Arc::new(StringArray::from(vec!["a", "b", "c"])));
    let err = error_of(dict_int_needle, dict_utf8_in);
    assert!(
        err.contains("Can't compare arrays of different types"),
        "{err}"
    );
    Ok(())
}
}
Loading
Loading