Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
213 changes: 210 additions & 3 deletions datafusion/physical-expr/src/expressions/in_list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,11 +99,18 @@ impl StaticFilter for ArrayStaticFilter {
));
}

// Unwrap dictionary-encoded needles when the value type matches
// in_array, evaluating against the dictionary values and mapping
// back via keys.
downcast_dictionary_array! {
v => {
let values_contains = self.contains(v.values().as_ref(), negated)?;
let result = take(&values_contains, v.keys(), None)?;
return Ok(downcast_array(result.as_ref()))
// Only unwrap when the haystack (in_array) type matches
// the dictionary value type
if v.values().data_type() == self.in_array.data_type() {
let values_contains = self.contains(v.values().as_ref(), negated)?;
let result = take(&values_contains, v.keys(), None)?;
return Ok(downcast_array(result.as_ref()));
}
}
_ => {}
}
Expand Down Expand Up @@ -3878,4 +3885,204 @@ mod tests {
);
Ok(())
}

// -----------------------------------------------------------------------
// Tests for try_new_from_array: evaluates `needle IN in_array`.
//
// This exercises the code path used by HashJoin dynamic filter pushdown,
// where in_array is built directly from the join's build-side arrays.
// Unlike try_new (used by SQL IN expressions), which always produces a
// non-Dictionary in_array because evaluate_list() flattens Dictionary
// scalars, try_new_from_array passes the array directly and can produce
// a Dictionary in_array.
// -----------------------------------------------------------------------

/// Wraps `array` as the values of a `DictionaryArray` whose `Int32` keys are
/// `0, 1, .., array.len() - 1`, so dictionary row `i` refers to value `i`.
fn wrap_in_dict(array: ArrayRef) -> ArrayRef {
    let value_count = array.len() as i32;
    let identity_keys: Vec<i32> = (0..value_count).collect();
    let dictionary = DictionaryArray::new(Int32Array::from(identity_keys), array);
    Arc::new(dictionary)
}

/// Evaluates `needle IN in_array` through `InListExpr::try_new_from_array`.
///
/// This is the construction path used by HashJoin dynamic filter pushdown,
/// as opposed to the SQL literal `IN` path, which goes through `try_new`.
///
/// `needle` becomes the only column (named "a") of a one-column record
/// batch; the expression is evaluated against that batch and the result is
/// returned as a `BooleanArray` with one entry per needle row.
fn eval_in_list_from_array(
    needle: ArrayRef,
    in_array: ArrayRef,
) -> Result<BooleanArray> {
    // The schema's single field takes its data type from the needle itself.
    let needle_field = Field::new("a", needle.data_type().clone(), false);
    let schema = Arc::new(Schema::new(vec![needle_field]));

    let column = col("a", &schema)?;
    let in_list: Arc<dyn PhysicalExpr> =
        Arc::new(InListExpr::try_new_from_array(column, in_array, false)?);

    let batch = RecordBatch::try_new(schema, vec![needle])?;
    let row_count = batch.num_rows();
    let evaluated = in_list.evaluate(&batch)?.into_array(row_count)?;
    Ok(as_boolean_array(&evaluated).clone())
}

#[test]
fn test_in_list_from_array_type_combinations() -> Result<()> {
    use arrow::compute::cast;

    // Every case below is built so that needle[0] and needle[2] are present
    // in in_array and needle[1] is not; all cases share this expectation.
    let expected = BooleanArray::from(vec![Some(true), Some(false), Some(true)]);

    // Base arrays that are cast to each primitive target type.
    let base_in = Arc::new(Int64Array::from(vec![1i64, 2, 3])) as ArrayRef;
    let base_needle = Arc::new(Int64Array::from(vec![1i64, 4, 2])) as ArrayRef;

    // Test all specializations in instantiate_static_filter.
    // A fixed-size array is enough here: the list is only iterated by reference.
    let primitive_types = [
        DataType::Int8,
        DataType::Int16,
        DataType::Int32,
        DataType::Int64,
        DataType::UInt8,
        DataType::UInt16,
        DataType::UInt32,
        DataType::UInt64,
        DataType::Float32,
        DataType::Float64,
    ];

    for dt in &primitive_types {
        let in_array = cast(&base_in, dt)?;
        let needle = cast(&base_needle, dt)?;

        // T in_array, T needle
        assert_eq!(
            expected,
            eval_in_list_from_array(Arc::clone(&needle), Arc::clone(&in_array))?,
            "same-type failed for {dt:?}"
        );

        // T in_array, Dict(Int32, T) needle
        assert_eq!(
            expected,
            eval_in_list_from_array(wrap_in_dict(needle), in_array)?,
            "dict-needle failed for {dt:?}"
        );
    }

    // Utf8 (falls through to ArrayStaticFilter)
    let utf8_in = Arc::new(StringArray::from(vec!["a", "b", "c"])) as ArrayRef;
    let utf8_needle = Arc::new(StringArray::from(vec!["a", "d", "b"])) as ArrayRef;

    // Utf8 in_array, Utf8 needle
    assert_eq!(
        expected,
        eval_in_list_from_array(Arc::clone(&utf8_needle), Arc::clone(&utf8_in))?,
        "Utf8 in_array, Utf8 needle failed"
    );

    // Utf8 in_array, Dict(Utf8) needle
    assert_eq!(
        expected,
        eval_in_list_from_array(
            wrap_in_dict(Arc::clone(&utf8_needle)),
            Arc::clone(&utf8_in),
        )?,
        "Utf8 in_array, Dict(Utf8) needle failed"
    );

    // Dict(Utf8) in_array, Dict(Utf8) needle: the #20937 bug
    assert_eq!(
        expected,
        eval_in_list_from_array(
            wrap_in_dict(Arc::clone(&utf8_needle)),
            wrap_in_dict(Arc::clone(&utf8_in)),
        )?,
        "Dict(Utf8) in_array, Dict(Utf8) needle failed"
    );

    // Struct in_array, Struct needle: multi-column join
    let struct_fields = Fields::from(vec![
        Field::new("c0", DataType::Utf8, true),
        Field::new("c1", DataType::Int64, true),
    ]);
    // Pairs each declared field with its column, in declaration order.
    let make_struct = |c0: ArrayRef, c1: ArrayRef| -> ArrayRef {
        let pairs: Vec<(FieldRef, ArrayRef)> =
            struct_fields.iter().cloned().zip([c0, c1]).collect();
        Arc::new(StructArray::from(pairs))
    };
    assert_eq!(
        expected,
        eval_in_list_from_array(
            make_struct(
                Arc::clone(&utf8_needle),
                Arc::new(Int64Array::from(vec![1, 4, 2])),
            ),
            make_struct(
                Arc::clone(&utf8_in),
                Arc::new(Int64Array::from(vec![1, 2, 3])),
            ),
        )?,
        "Struct in_array, Struct needle failed"
    );

    // Struct with Dict fields: multi-column Dict join
    let dict_struct_fields = Fields::from(vec![
        Field::new(
            "c0",
            DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
            true,
        ),
        Field::new("c1", DataType::Int64, true),
    ]);
    let make_dict_struct = |c0: ArrayRef, c1: ArrayRef| -> ArrayRef {
        let pairs: Vec<(FieldRef, ArrayRef)> =
            dict_struct_fields.iter().cloned().zip([c0, c1]).collect();
        Arc::new(StructArray::from(pairs))
    };
    assert_eq!(
        expected,
        eval_in_list_from_array(
            make_dict_struct(
                wrap_in_dict(Arc::clone(&utf8_needle)),
                Arc::new(Int64Array::from(vec![1, 4, 2])),
            ),
            make_dict_struct(
                wrap_in_dict(Arc::clone(&utf8_in)),
                Arc::new(Int64Array::from(vec![1, 2, 3])),
            ),
        )?,
        "Struct with Dict field in_array, Struct with Dict field needle failed"
    );

    Ok(())
}

#[test]
fn test_in_list_from_array_type_mismatch_errors() -> Result<()> {
    // Each case is (needle, in_array, fragment the error text must contain).
    let cases: [(ArrayRef, ArrayRef, &str); 3] = [
        // Utf8 needle, Dict(Utf8) in_array
        (
            Arc::new(StringArray::from(vec!["a", "d", "b"])) as ArrayRef,
            wrap_in_dict(Arc::new(StringArray::from(vec!["a", "b", "c"]))),
            "Can't compare arrays of different types",
        ),
        // Dict(Utf8) needle, Int64 in_array: specialized Int64StaticFilter
        // rejects the Utf8 dictionary values at construction time
        (
            wrap_in_dict(Arc::new(StringArray::from(vec!["a", "d", "b"]))),
            Arc::new(Int64Array::from(vec![1, 2, 3])) as ArrayRef,
            "Failed to downcast",
        ),
        // Dict(Int64) needle, Dict(Utf8) in_array: both Dict but different
        // value types, make_comparator rejects the comparison
        (
            wrap_in_dict(Arc::new(Int64Array::from(vec![1, 4, 2]))),
            wrap_in_dict(Arc::new(StringArray::from(vec!["a", "b", "c"]))),
            "Can't compare arrays of different types",
        ),
    ];

    for (needle, in_array, fragment) in cases {
        // Every combination must fail; the rendered error is what is checked.
        let err = eval_in_list_from_array(needle, in_array)
            .unwrap_err()
            .to_string();
        assert!(err.contains(fragment), "{err}");
    }

    Ok(())
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -918,13 +918,18 @@ CREATE EXTERNAL TABLE dict_filter_bug
STORED AS PARQUET
LOCATION 'test_files/scratch/parquet_filter_pushdown/dict_filter_bug.parquet';

query error Can't compare arrays of different types
query TR
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

SELECT t.tag1, t.value
FROM dict_filter_bug t
JOIN (VALUES ('A'), ('B')) AS v(c1)
ON t.tag1 = v.c1
ORDER BY t.tag1, t.value
LIMIT 4;
----
A 0
A 26
A 52
A 78

# Cleanup
statement ok
Expand Down
Loading