Skip to content
Draft
84 changes: 32 additions & 52 deletions datafusion/common/src/nested_struct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ use std::{collections::HashSet, sync::Arc};
///
/// ## Field Matching Strategy
/// - **By Name**: Source struct fields are matched to target fields by name (case-sensitive)
/// - **By Position**: When there is no name overlap and the field counts match, fields are cast by index
/// - **No Positional Mapping**: Structs with no overlapping field names are rejected
/// - **Type Adaptation**: When a matching field is found, it is recursively cast to the target field's type
/// - **Missing Fields**: Target fields not present in the source are filled with null values
/// - **Extra Fields**: Source fields not present in the target are ignored
Expand Down Expand Up @@ -67,24 +67,16 @@ fn cast_struct_column(
if let Some(source_struct) = source_col.as_any().downcast_ref::<StructArray>() {
let source_fields = source_struct.fields();
validate_struct_compatibility(source_fields, target_fields)?;
let has_overlap = has_one_of_more_common_fields(source_fields, target_fields);

let mut fields: Vec<Arc<Field>> = Vec::with_capacity(target_fields.len());
let mut arrays: Vec<ArrayRef> = Vec::with_capacity(target_fields.len());
let num_rows = source_col.len();

// Iterate target fields and pick source child either by name (when fields overlap)
// or by position (when there is no name overlap).
for (index, target_child_field) in target_fields.iter().enumerate() {
// Iterate target fields and pick source child by name when present.
for target_child_field in target_fields.iter() {
fields.push(Arc::clone(target_child_field));

// Determine the source child column: by name when overlapping names exist,
// otherwise by position.
let source_child_opt: Option<&ArrayRef> = if has_overlap {
source_struct.column_by_name(target_child_field.name())
} else {
Some(source_struct.column(index))
};
let source_child_opt =
source_struct.column_by_name(target_child_field.name());

match source_child_opt {
Some(source_child_col) => {
Expand Down Expand Up @@ -230,20 +222,11 @@ pub fn validate_struct_compatibility(
) -> Result<()> {
let has_overlap = has_one_of_more_common_fields(source_fields, target_fields);
if !has_overlap {
if source_fields.len() != target_fields.len() {
return _plan_err!(
"Cannot cast struct with {} fields to {} fields without name overlap; positional mapping is ambiguous",
source_fields.len(),
target_fields.len()
);
}

for (source_field, target_field) in source_fields.iter().zip(target_fields.iter())
{
validate_field_compatibility(source_field, target_field)?;
}

return Ok(());
return _plan_err!(
"Cannot cast struct with {} fields to {} fields because there is no field name overlap",
source_fields.len(),
target_fields.len()
);
}

// Check compatibility for each target field
Expand Down Expand Up @@ -323,7 +306,11 @@ fn validate_field_compatibility(
Ok(())
}

fn has_one_of_more_common_fields(
/// Check if two field lists have at least one common field by name.
///
/// This is useful for validating struct compatibility when casting between structs,
/// ensuring that source and target fields have overlapping names.
pub fn has_one_of_more_common_fields(
source_fields: &[FieldRef],
target_fields: &[FieldRef],
) -> bool {
Expand Down Expand Up @@ -546,7 +533,7 @@ mod tests {
}

#[test]
fn test_validate_struct_compatibility_positional_no_overlap_mismatch_len() {
fn test_validate_struct_compatibility_no_overlap_mismatch_len() {
let source_fields = vec![
arc_field("left", DataType::Int32),
arc_field("right", DataType::Int32),
Expand All @@ -556,7 +543,7 @@ mod tests {
let result = validate_struct_compatibility(&source_fields, &target_fields);
assert!(result.is_err());
let error_msg = result.unwrap_err().to_string();
assert!(error_msg.contains("positional mapping is ambiguous"));
assert_contains!(error_msg, "no field name overlap");
}

#[test]
Expand Down Expand Up @@ -665,21 +652,21 @@ mod tests {
}

#[test]
fn test_validate_struct_compatibility_positional_with_type_mismatch() {
// Source struct: {left: Struct} - nested struct
let source_fields =
vec![arc_struct_field("left", vec![field("x", DataType::Int32)])];
fn test_validate_struct_compatibility_no_overlap_equal_len() {
let source_fields = vec![
arc_field("left", DataType::Int32),
arc_field("right", DataType::Utf8),
];

// Target struct: {alpha: Int32} (no name overlap, incompatible type at position 0)
let target_fields = vec![arc_field("alpha", DataType::Int32)];
let target_fields = vec![
arc_field("alpha", DataType::Int32),
arc_field("beta", DataType::Utf8),
];

let result = validate_struct_compatibility(&source_fields, &target_fields);
assert!(result.is_err());
let error_msg = result.unwrap_err().to_string();
assert_contains!(
error_msg,
"Cannot cast struct field 'alpha' from type Struct(\"x\": Int32) to type Int32"
);
assert_contains!(error_msg, "no field name overlap");
}

#[test]
Expand Down Expand Up @@ -948,7 +935,7 @@ mod tests {
}

#[test]
fn test_cast_struct_positional_when_no_overlap() {
fn test_cast_struct_no_overlap_rejected() {
let first = Arc::new(Int32Array::from(vec![Some(10), Some(20)])) as ArrayRef;
let second =
Arc::new(StringArray::from(vec![Some("alpha"), Some("beta")])) as ArrayRef;
Expand All @@ -964,17 +951,10 @@ mod tests {
vec![field("a", DataType::Int64), field("b", DataType::Utf8)],
);

let result =
cast_column(&source_col, &target_field, &DEFAULT_CAST_OPTIONS).unwrap();
let struct_array = result.as_any().downcast_ref::<StructArray>().unwrap();

let a_col = get_column_as!(&struct_array, "a", Int64Array);
assert_eq!(a_col.value(0), 10);
assert_eq!(a_col.value(1), 20);

let b_col = get_column_as!(&struct_array, "b", StringArray);
assert_eq!(b_col.value(0), "alpha");
assert_eq!(b_col.value(1), "beta");
let result = cast_column(&source_col, &target_field, &DEFAULT_CAST_OPTIONS);
assert!(result.is_err());
let error_msg = result.unwrap_err().to_string();
assert_contains!(error_msg, "no field name overlap");
}

#[test]
Expand Down
17 changes: 10 additions & 7 deletions datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ use std::ops::Not;
use std::sync::Arc;

use datafusion_common::config::ConfigOptions;
use datafusion_common::nested_struct::has_one_of_more_common_fields;
use datafusion_common::{
DFSchema, DataFusionError, Result, ScalarValue, exec_datafusion_err, internal_err,
};
Expand Down Expand Up @@ -657,6 +658,11 @@ impl ConstEvaluator {
return false;
}

// Skip const-folding when there is no field name overlap
if !has_one_of_more_common_fields(&source_fields, target_fields) {
return false;
}

// Don't const-fold struct casts with empty (0-row) literals
// The simplifier uses a 1-row input batch, which causes dimension mismatches
// when evaluating 0-row struct literals
Expand Down Expand Up @@ -5220,7 +5226,7 @@ mod tests {
#[test]
fn test_struct_cast_different_names_same_count() {
// Test struct cast with same field count but different names
// Field count matches; simplification should succeed
// Field count matches; simplification should be skipped because names do not overlap

let source_fields = Fields::from(vec![
Arc::new(Field::new("a", DataType::Int32, true)),
Expand All @@ -5237,14 +5243,11 @@ mod tests {
let simplifier =
ExprSimplifier::new(SimplifyContext::default().with_schema(test_schema()));

// The cast should be simplified since field counts match
// The cast should remain unchanged because there is no name overlap
let result = simplifier.simplify(expr.clone()).unwrap();
// Struct casts with same field count are const-folded to literals
assert!(matches!(result, Expr::Literal(_, _)));
// Ensure the simplifier made a change (not identical to original)
assert_ne!(
assert_eq!(
result, expr,
"Struct cast with different names but same field count should be simplified"
"Struct cast with different names but same field count should not be const-folded"
);
}

Expand Down
8 changes: 4 additions & 4 deletions datafusion/sqllogictest/test_files/joins.slt
Original file line number Diff line number Diff line change
Expand Up @@ -57,15 +57,15 @@ statement ok
CREATE TABLE join_t3(s3 struct<id INT>)
AS VALUES
(NULL),
(struct(1)),
(struct(2));
({id: 1}),
({id: 2});

statement ok
CREATE TABLE join_t4(s4 struct<id INT>)
AS VALUES
(NULL),
(struct(2)),
(struct(3));
({id: 2}),
({id: 3});

# Left semi anti join

Expand Down
56 changes: 29 additions & 27 deletions datafusion/sqllogictest/test_files/struct.slt
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@ CREATE TABLE struct_values (
s1 struct<INT>,
s2 struct<a INT,b VARCHAR>
) AS VALUES
(struct(1), struct(1, 'string1')),
(struct(2), struct(2, 'string2')),
(struct(3), struct(3, 'string3'))
(struct(1), struct(1 AS a, 'string1' AS b)),
(struct(2), struct(2 AS a, 'string2' AS b)),
(struct(3), struct(3 AS a, 'string3' AS b))
;

query ??
Expand Down Expand Up @@ -397,7 +397,8 @@ drop view complex_view;

# struct with different keys r1 and r2 is not valid
statement ok
create table t(a struct<r1 varchar, c int>, b struct<r2 varchar, c float>) as values (struct('red', 1), struct('blue', 2.3));
create table t(a struct<r1 varchar, c int>, b struct<r2 varchar, c float>) as values
(struct('red' AS r1, 1 AS c), struct('blue' AS r2, 2.3 AS c));

# Expect same keys for struct type but got mismatched pair r1,c and r2,c
query error
Expand All @@ -408,7 +409,8 @@ drop table t;

# struct with the same key
statement ok
create table t(a struct<r varchar, c int>, b struct<r varchar, c float>) as values (struct('red', 1), struct('blue', 2.3));
create table t(a struct<r varchar, c int>, b struct<r varchar, c float>) as values
(struct('red' AS r, 1 AS c), struct('blue' AS r, 2.3 AS c));

query T
select arrow_typeof([a, b]) from t;
Expand Down Expand Up @@ -442,18 +444,18 @@ CREATE TABLE struct_values (
s1 struct(a int, b varchar),
s2 struct(a int, b varchar)
) AS VALUES
(row(1, 'red'), row(1, 'string1')),
(row(2, 'blue'), row(2, 'string2')),
(row(3, 'green'), row(3, 'string3'))
({a: 1, b: 'red'}, {a: 1, b: 'string1'}),
({a: 2, b: 'blue'}, {a: 2, b: 'string2'}),
({a: 3, b: 'green'}, {a: 3, b: 'string3'})
;

statement ok
drop table struct_values;

statement ok
create table t (c1 struct(r varchar, b int), c2 struct(r varchar, b float)) as values (
row('red', 2),
row('blue', 2.3)
{r: 'red', b: 2},
{r: 'blue', b: 2.3}
);

query ??
Expand Down Expand Up @@ -501,9 +503,9 @@ CREATE TABLE t (
s1 struct(a int, b varchar),
s2 struct(a float, b varchar)
) AS VALUES
(row(1, 'red'), row(1.1, 'string1')),
(row(2, 'blue'), row(2.2, 'string2')),
(row(3, 'green'), row(33.2, 'string3'))
({a: 1, b: 'red'}, {a: 1.1, b: 'string1'}),
({a: 2, b: 'blue'}, {a: 2.2, b: 'string2'}),
({a: 3, b: 'green'}, {a: 33.2, b: 'string3'})
;

query ?
Expand All @@ -528,9 +530,9 @@ CREATE TABLE t (
s1 struct(a int, b varchar),
s2 struct(a float, b varchar)
) AS VALUES
(row(1, 'red'), row(1.1, 'string1')),
(null, row(2.2, 'string2')),
(row(3, 'green'), row(33.2, 'string3'))
({a: 1, b: 'red'}, {a: 1.1, b: 'string1'}),
(null, {a: 2.2, b: 'string2'}),
({a: 3, b: 'green'}, {a: 33.2, b: 'string3'})
;

query ?
Expand All @@ -553,8 +555,8 @@ drop table t;
# row() with incorrect order - row() is positional, not name-based
statement error DataFusion error: Optimizer rule 'simplify_expressions' failed[\s\S]*Arrow error: Cast error: Cannot cast string 'blue' to value of Float32 type
create table t(a struct(r varchar, c int), b struct(r varchar, c float)) as values
(row('red', 1), row(2.3, 'blue')),
(row('purple', 1), row('green', 2.3));
({r: 'red', c: 1}, {r: 2.3, c: 'blue'}),
({r: 'purple', c: 1}, {r: 'green', c: 2.3});


##################################
Expand All @@ -568,7 +570,7 @@ select [{r: 'a', c: 1}, {r: 'b', c: 2}];


statement ok
create table t(a struct(r varchar, c int), b struct(r varchar, c float)) as values (row('a', 1), row('b', 2.3));
create table t(a struct(r varchar, c int), b struct(r varchar, c float)) as values ({r: 'a', c: 1}, {r: 'b', c: 2.3});

query T
select arrow_typeof([a, b]) from t;
Expand All @@ -580,7 +582,7 @@ drop table t;


statement ok
create table t(a struct(r varchar, c int, g float), b struct(r varchar, c float, g int)) as values (row('a', 1, 2.3), row('b', 2.3, 2));
create table t(a struct(r varchar, c int, g float), b struct(r varchar, c float, g int)) as values ({r: 'a', c: 1, g: 2.3}, {r: 'b', c: 2.3, g: 2});

# type of each column should not coerced but preserve as it is
query T
Expand All @@ -602,7 +604,7 @@ drop table t;
# This tests accessing struct fields using the subscript notation with string literals

statement ok
create table test (struct_field struct(substruct int)) as values (struct(1));
create table test (struct_field struct(substruct int)) as values ({substruct: 1});

query ??
select *
Expand All @@ -615,7 +617,7 @@ statement ok
DROP TABLE test;

statement ok
create table test (struct_field struct(substruct struct(subsubstruct int))) as values (struct(struct(1)));
create table test (struct_field struct(substruct struct(subsubstruct int))) as values ({substruct: {subsubstruct: 1}});

query ??
select *
Expand Down Expand Up @@ -823,9 +825,9 @@ SELECT CAST({b: 3, a: 4} AS STRUCT(a BIGINT, b INT));
----
{a: 4, b: 3}

# Test positional casting when there is no name overlap
# Test casting with explicit field names
query ?
SELECT CAST(struct(1, 'x') AS STRUCT(a INT, b VARCHAR));
SELECT CAST({a: 1, b: 'x'} AS STRUCT(a INT, b VARCHAR));
----
{a: 1, b: x}

Expand Down Expand Up @@ -859,9 +861,9 @@ statement ok
CREATE TABLE struct_reorder_test (
data STRUCT(b INT, a VARCHAR)
) AS VALUES
(struct(100, 'first')),
(struct(200, 'second')),
(struct(300, 'third'))
({b: 100, a: 'first'}),
({b: 200, a: 'second'}),
({b: 300, a: 'third'})
;

query ?
Expand Down
Loading