Skip to content

Commit 7d69f80

Browse files
timsaucerclaude
andcommitted
feat: wire rex_type and rex_call_operands for new Expr variants
Map HigherOrderFunction and Lambda to RexType::Call; LambdaVariable to RexType::Reference. In rex_call_operands return the args for HigherOrderFunction, the body for Lambda, and self for LambdaVariable (mirroring Column). In rex_call_operator return the underlying UDF name for HigherOrderFunction and the literal "lambda" for Lambda. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 75f7f7b commit 7d69f80

1 file changed

Lines changed: 17 additions & 18 deletions

File tree

crates/core/src/expr.rs

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@ use datafusion::arrow::datatypes::{DataType, Field};
2323
use datafusion::arrow::pyarrow::PyArrowType;
2424
use datafusion::functions::core::expr_ext::FieldAccessor;
2525
use datafusion::logical_expr::expr::{
26-
AggregateFunction, AggregateFunctionParams, FieldMetadata, InList, InSubquery, ScalarFunction,
27-
SetComparison, WindowFunction,
26+
AggregateFunction, AggregateFunctionParams, FieldMetadata, HigherOrderFunction, InList,
27+
InSubquery, Lambda, ScalarFunction, SetComparison, WindowFunction,
2828
};
2929
use datafusion::logical_expr::utils::exprlist_to_fields;
3030
use datafusion::logical_expr::{
@@ -398,17 +398,15 @@ impl PyExpr {
398398
| Expr::OuterReferenceColumn(_, _)
399399
| Expr::Unnest(_)
400400
| Expr::IsNotUnknown(_)
401-
| Expr::SetComparison(_) => RexType::Call,
401+
| Expr::SetComparison(_)
402+
| Expr::HigherOrderFunction(..)
403+
| Expr::Lambda(..) => RexType::Call,
404+
Expr::LambdaVariable(..) => RexType::Reference,
402405
Expr::ScalarSubquery(..) => RexType::ScalarSubquery,
403406
#[allow(deprecated)]
404407
Expr::Wildcard { .. } => {
405408
return Err(py_unsupported_variant_err("Expr::Wildcard is unsupported"));
406409
}
407-
Expr::HigherOrderFunction(..) | Expr::Lambda(..) | Expr::LambdaVariable(..) => {
408-
return Err(py_unsupported_variant_err(
409-
"Expr::HigherOrderFunction / Lambda / LambdaVariable is unsupported",
410-
));
411-
}
412410
})
413411
}
414412

@@ -435,9 +433,10 @@ impl PyExpr {
435433
pub fn rex_call_operands(&self) -> PyResult<Vec<PyExpr>> {
436434
match &self.expr {
437435
// Expr variants that are themselves the operand to return
438-
Expr::Column(..) | Expr::ScalarVariable(..) | Expr::Literal(..) => {
439-
Ok(vec![PyExpr::from(self.expr.clone())])
440-
}
436+
Expr::Column(..)
437+
| Expr::ScalarVariable(..)
438+
| Expr::Literal(..)
439+
| Expr::LambdaVariable(..) => Ok(vec![PyExpr::from(self.expr.clone())]),
441440

442441
Expr::Alias(alias) => Ok(vec![PyExpr::from(*alias.expr.clone())]),
443442

@@ -464,13 +463,15 @@ impl PyExpr {
464463
params: AggregateFunctionParams { args, .. },
465464
..
466465
})
467-
| Expr::ScalarFunction(ScalarFunction { args, .. }) => {
466+
| Expr::ScalarFunction(ScalarFunction { args, .. })
467+
| Expr::HigherOrderFunction(HigherOrderFunction { args, .. }) => {
468468
Ok(args.iter().map(|arg| PyExpr::from(arg.clone())).collect())
469469
}
470470
Expr::WindowFunction(boxed_window_fn) => {
471471
let args = &boxed_window_fn.params.args;
472472
Ok(args.iter().map(|arg| PyExpr::from(arg.clone())).collect())
473473
}
474+
Expr::Lambda(Lambda { body, .. }) => Ok(vec![PyExpr::from(*body.clone())]),
474475

475476
// Expr(s) that require more specific processing
476477
Expr::Case(Case {
@@ -548,12 +549,6 @@ impl PyExpr {
548549
Expr::Wildcard { .. } => {
549550
Err(py_unsupported_variant_err("Expr::Wildcard is unsupported"))
550551
}
551-
552-
Expr::HigherOrderFunction(..) | Expr::Lambda(..) | Expr::LambdaVariable(..) => {
553-
Err(py_unsupported_variant_err(
554-
"Expr::HigherOrderFunction / Lambda / LambdaVariable is unsupported",
555-
))
556-
}
557552
}
558553
}
559554

@@ -566,6 +561,10 @@ impl PyExpr {
566561
right: _,
567562
}) => format!("{op}"),
568563
Expr::ScalarFunction(ScalarFunction { func, args: _ }) => func.name().to_string(),
564+
Expr::HigherOrderFunction(HigherOrderFunction { func, args: _ }) => {
565+
func.name().to_string()
566+
}
567+
Expr::Lambda(..) => "lambda".to_string(),
569568
Expr::Cast { .. } => "cast".to_string(),
570569
Expr::Between { .. } => "between".to_string(),
571570
Expr::Case { .. } => "case".to_string(),

0 commit comments

Comments
 (0)