Skip to content

Commit a261445

Browse files
committed
Null checks
1 parent 23d39d5 commit a261445

2 files changed

Lines changed: 246 additions & 23 deletions

File tree

native/spark-expr/src/array_funcs/arrays_zip.rs

Lines changed: 220 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,22 @@
1616
// under the License.
1717

1818
use arrow::array::RecordBatch;
19-
use arrow::array::{Array, ArrayRef, StringArray};
19+
use arrow::array::{
20+
new_null_array, Array, ArrayRef, Capacities, ListArray, MutableArrayData, StructArray,
21+
};
22+
use arrow::buffer::{NullBuffer, OffsetBuffer};
2023
use arrow::datatypes::DataType::{FixedSizeList, LargeList, List, Null};
2124
use arrow::datatypes::Schema;
2225
use arrow::datatypes::{DataType, Field, Fields};
26+
use datafusion::common::cast::{as_fixed_size_list_array, as_large_list_array, as_list_array};
2327
use datafusion::common::{exec_err, Result, ScalarValue};
2428
use datafusion::logical_expr::ColumnarValue;
2529
use datafusion::physical_expr::PhysicalExpr;
26-
use datafusion::functions_nested::arrays_zip::arrays_zip_inner;
2730
use std::any::Any;
2831
use std::fmt::{Display, Formatter};
2932
use std::sync::Arc;
33+
// use datafusion::functions_nested::utils::make_scalar_function;
34+
// use datafusion::functions_nested::arrays_zip::arrays_zip_inner;
3035

3136
#[derive(Debug, Eq, Hash, PartialEq)]
3237
pub struct SparkArraysZipFunc {
@@ -81,7 +86,7 @@ impl PhysicalExpr for SparkArraysZipFunc {
8186
}
8287

8388
/// Reports whether this expression can produce nulls.
///
/// Always answers `true`, independent of the input schema.
fn nullable(&self, _input_schema: &Schema) -> Result<bool> {
    let may_produce_null = true;
    Ok(may_produce_null)
}
8691

8792
fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
@@ -91,7 +96,36 @@ impl PhysicalExpr for SparkArraysZipFunc {
9196
.map(|e| e.evaluate(batch))
9297
.collect::<datafusion::common::Result<Vec<_>>>()?;
9398

94-
let len = values
99+
make_scalar_function(|arr| arrays_zip_inner(arr, self.names.clone()))(&*values)
100+
}
101+
102+
/// Returns references to the argument expressions held in `self.values`, in order.
fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
    let mut child_refs = Vec::with_capacity(self.values.len());
    child_refs.extend(self.values.iter());
    child_refs
}
105+
106+
fn with_new_children(
107+
self: Arc<Self>,
108+
children: Vec<Arc<dyn PhysicalExpr>>,
109+
) -> Result<Arc<dyn PhysicalExpr>> {
110+
Ok(Arc::new(SparkArraysZipFunc::new(
111+
children.clone(),
112+
self.names.clone(),
113+
)))
114+
}
115+
116+
fn fmt_sql(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
117+
Display::fmt(self, f)
118+
}
119+
}
120+
121+
pub fn make_scalar_function<F>(inner: F) -> impl Fn(&[ColumnarValue]) -> Result<ColumnarValue>
122+
where
123+
F: Fn(&[ArrayRef]) -> Result<ArrayRef>,
124+
{
125+
move |args: &[ColumnarValue]| {
126+
// first, identify if any of the arguments is an Array. If yes, store its `len`,
127+
// as any scalar will need to be converted to an array of len `len`.
128+
let len = args
95129
.iter()
96130
.fold(Option::<usize>::None, |acc, arg| match arg {
97131
ColumnarValue::Scalar(_) => acc,
@@ -100,11 +134,9 @@ impl PhysicalExpr for SparkArraysZipFunc {
100134

101135
let is_scalar = len.is_none();
102136

103-
let arrays = ColumnarValue::values_to_arrays(&values)?;
104-
let names = vec![Arc::new(StringArray::from(self.names.clone())) as ArrayRef];
137+
let args = ColumnarValue::values_to_arrays(args)?;
105138

106-
// TODO: replace this with DF's function
107-
let result = arrays_zip_inner(&arrays, &names);
139+
let result = (inner)(&args);
108140

109141
if is_scalar {
110142
// If all inputs are scalar, keeps output as scalar
@@ -114,22 +146,190 @@ impl PhysicalExpr for SparkArraysZipFunc {
114146
result.map(ColumnarValue::Array)
115147
}
116148
}
149+
}
117150

118-
fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
119-
self.values.iter().collect()
151+
struct ListColumnView {
152+
/// The flat values array backing this list column.
153+
values: ArrayRef,
154+
/// Pre-computed per-row start offsets (length = num_rows + 1).
155+
offsets: Vec<usize>,
156+
/// Pre-computed null bitmap: true means the row is null.
157+
is_null: Vec<bool>,
158+
}
159+
160+
pub fn arrays_zip_inner(args: &[ArrayRef], names: Vec<String>) -> Result<ArrayRef> {
161+
if args.is_empty() {
162+
return exec_err!("arrays_zip requires at least one argument");
120163
}
121164

122-
fn with_new_children(
123-
self: Arc<Self>,
124-
children: Vec<Arc<dyn PhysicalExpr>>,
125-
) -> Result<Arc<dyn PhysicalExpr>> {
126-
Ok(Arc::new(SparkArraysZipFunc::new(
127-
children.clone(),
128-
self.names.clone(),
129-
)))
165+
let num_rows = args[0].len();
166+
167+
// Build a type-erased ListColumnView for each argument.
168+
// None means the argument is Null-typed (all nulls, no backing data).
169+
let mut views: Vec<Option<ListColumnView>> = Vec::with_capacity(args.len());
170+
let mut element_types: Vec<DataType> = Vec::with_capacity(args.len());
171+
172+
for (i, arg) in args.iter().enumerate() {
173+
match arg.data_type() {
174+
List(field) => {
175+
let arr = as_list_array(arg)?;
176+
let raw_offsets = arr.value_offsets();
177+
let offsets: Vec<usize> = raw_offsets.iter().map(|&o| o as usize).collect();
178+
let is_null = (0..num_rows).map(|row| arr.is_null(row)).collect();
179+
element_types.push(field.data_type().clone());
180+
views.push(Some(ListColumnView {
181+
values: Arc::clone(arr.values()),
182+
offsets,
183+
is_null,
184+
}));
185+
}
186+
LargeList(field) => {
187+
let arr = as_large_list_array(arg)?;
188+
let raw_offsets = arr.value_offsets();
189+
let offsets: Vec<usize> = raw_offsets.iter().map(|&o| o as usize).collect();
190+
let is_null = (0..num_rows).map(|row| arr.is_null(row)).collect();
191+
element_types.push(field.data_type().clone());
192+
views.push(Some(ListColumnView {
193+
values: Arc::clone(arr.values()),
194+
offsets,
195+
is_null,
196+
}));
197+
}
198+
FixedSizeList(field, size) => {
199+
let arr = as_fixed_size_list_array(arg)?;
200+
let size = *size as usize;
201+
let offsets: Vec<usize> = (0..=num_rows).map(|row| row * size).collect();
202+
let is_null = (0..num_rows).map(|row| arr.is_null(row)).collect();
203+
element_types.push(field.data_type().clone());
204+
views.push(Some(ListColumnView {
205+
values: Arc::clone(arr.values()),
206+
offsets,
207+
is_null,
208+
}));
209+
}
210+
Null => {
211+
element_types.push(Null);
212+
views.push(None);
213+
}
214+
dt => {
215+
return exec_err!("arrays_zip argument {i} expected list type, got {dt}");
216+
}
217+
}
130218
}
131219

132-
fn fmt_sql(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
133-
Display::fmt(self, f)
220+
// Collect per-column values data for MutableArrayData builders.
221+
let values_data: Vec<_> = views
222+
.iter()
223+
.map(|v| v.as_ref().map(|view| view.values.to_data()))
224+
.collect();
225+
226+
let struct_fields: Fields = element_types
227+
.iter()
228+
.enumerate()
229+
.map(|(i, dt)| Field::new(names[i].to_string(), dt.clone(), true))
230+
.collect::<Vec<_>>()
231+
.into();
232+
233+
// Create a MutableArrayData builder per column. For None (Null-typed)
234+
// args we only need extend_nulls, so we track them separately.
235+
let mut builders: Vec<Option<MutableArrayData>> = values_data
236+
.iter()
237+
.map(|vd| {
238+
vd.as_ref().map(|data| {
239+
MutableArrayData::with_capacities(vec![data], true, Capacities::Array(0))
240+
})
241+
})
242+
.collect();
243+
244+
let mut offsets: Vec<i32> = Vec::with_capacity(num_rows + 1);
245+
offsets.push(0);
246+
let mut null_mask: Vec<bool> = Vec::with_capacity(num_rows);
247+
let mut total_values: usize = 0;
248+
249+
// Process each row: compute per-array lengths, then copy values
250+
// and pad shorter arrays with NULLs.
251+
for row_idx in 0..num_rows {
252+
let mut max_len: usize = 0;
253+
let mut all_null = true;
254+
255+
for view in views.iter().flatten() {
256+
if !view.is_null[row_idx] {
257+
all_null = false;
258+
let len = view.offsets[row_idx + 1] - view.offsets[row_idx];
259+
max_len = max_len.max(len);
260+
}
261+
}
262+
263+
if all_null {
264+
null_mask.push(true);
265+
offsets.push(*offsets.last().unwrap());
266+
continue;
267+
}
268+
null_mask.push(false);
269+
270+
// Extend each column builder for this row.
271+
for (col_idx, view) in views.iter().enumerate() {
272+
match view {
273+
Some(v) if !v.is_null[row_idx] => {
274+
let start = v.offsets[row_idx];
275+
let end = v.offsets[row_idx + 1];
276+
let len = end - start;
277+
let builder = builders[col_idx].as_mut().unwrap();
278+
builder.extend(0, start, end);
279+
if len < max_len {
280+
builder.extend_nulls(max_len - len);
281+
}
282+
}
283+
_ => {
284+
// Null list entry or None (Null-typed) arg — all nulls.
285+
if let Some(builder) = builders[col_idx].as_mut() {
286+
builder.extend_nulls(max_len);
287+
}
288+
}
289+
}
290+
}
291+
292+
total_values += max_len;
293+
let last = *offsets.last().unwrap();
294+
offsets.push(last + max_len as i32);
134295
}
296+
297+
// Assemble struct columns from builders.
298+
let struct_columns: Vec<ArrayRef> = builders
299+
.into_iter()
300+
.zip(element_types.iter())
301+
.map(|(builder, elem_type)| match builder {
302+
Some(b) => arrow::array::make_array(b.freeze()),
303+
None => new_null_array(
304+
if elem_type.is_null() {
305+
&Null
306+
} else {
307+
elem_type
308+
},
309+
total_values,
310+
),
311+
})
312+
.collect();
313+
314+
let struct_array = StructArray::try_new(struct_fields, struct_columns, None)?;
315+
316+
let null_buffer = if null_mask.iter().any(|&v| v) {
317+
Some(NullBuffer::from(
318+
null_mask.iter().map(|v| !v).collect::<Vec<bool>>(),
319+
))
320+
} else {
321+
None
322+
};
323+
324+
let result = ListArray::try_new(
325+
Arc::new(Field::new_list_field(
326+
struct_array.data_type().clone(),
327+
true,
328+
)),
329+
OffsetBuffer::new(offsets.into()),
330+
Arc::new(struct_array),
331+
null_buffer,
332+
)?;
333+
334+
Ok(Arc::new(result))
135335
}

spark/src/main/scala/org/apache/comet/serde/arrays.scala

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -695,13 +695,36 @@ object CometArraysZip extends CometExpressionSerde[ArraysZip] {
695695
val exprChildren = expr.children.map(exprToProtoInternal(_, inputs, binding))
696696
val names = expr.names.map(_.eval(EmptyRow))
697697

698-
if (exprChildren.forall(_.isDefined)) {
699-
val builder = ExprOuterClass.ArraysZip
698+
val isNotNullExpr = expr.children
699+
.map(child =>
700+
createUnaryExpr(
701+
expr,
702+
child,
703+
inputs,
704+
binding,
705+
(builder, unaryExpr) => builder.setIsNotNull(unaryExpr)))
706+
val nullLiteralProto = exprToProto(Literal(null, BooleanType), Seq.empty)
707+
708+
if (exprChildren.forall(_.isDefined) && isNotNullExpr.forall(expr =>
709+
expr.isDefined) && nullLiteralProto.isDefined) {
710+
val arraysZip: ExprOuterClass.ArraysZip = ExprOuterClass.ArraysZip
700711
.newBuilder()
701712
.addAllValues(exprChildren.map(_.get).asJava)
702713
.addAllNames(names.map(_.toString).asJava)
714+
.build()
715+
716+
val caseWhenExpr = ExprOuterClass.CaseWhen
717+
.newBuilder()
718+
.addAllWhen(isNotNullExpr.map(_.get).asJava)
719+
.addThen(ExprOuterClass.Expr.newBuilder().setArraysZip(arraysZip).build())
720+
.setElseExpr(nullLiteralProto.get)
721+
.build()
722+
Some(
723+
ExprOuterClass.Expr
724+
.newBuilder()
725+
.setCaseWhen(caseWhenExpr)
726+
.build())
703727

704-
Some(ExprOuterClass.Expr.newBuilder().setArraysZip(builder).build())
705728
} else {
706729
withInfo(expr, "unsupported arguments for ArraysZip", expr.children ++ expr.names: _*)
707730
None

0 commit comments

Comments
 (0)