Skip to content

Commit 4b51402

Browse files
timsaucer and claude committed
feat: inline encoding for Python aggregate UDFs
Aggregate UDFs no longer need worker-side pre-registration. The codec serializes the Python accumulator factory + state schema into the wire format and the receiver reconstructs the UDF from bytes alone.

New `PythonFunctionAggregateUDF` named struct (in `crates/core/src/udaf.rs`) holds `accumulator: Py<PyAny>` plus signature, return type, and state fields directly. It provides a full `AggregateUDFImpl` impl mirroring upstream `SimpleAggregateUDF`: `as_any`, `name`, `signature`, `return_type`, `accumulator`, `state_fields`. `accumulator()` lazily instantiates a fresh accumulator per partition via the new `instantiate_accumulator()` helper.

`PyAggregateUDF::new` now constructs `PythonFunctionAggregateUDF` directly via `AggregateUDF::new_from_impl(...)` instead of routing through `create_udaf(...)` + `to_rust_accumulator(...)`. The closure-based factory path is gone; the Python state stays addressable.

Codec wiring:

* `crates/core/src/codec.rs` adds `try_encode_python_agg_udf` / `try_decode_python_agg_udf` plus the `DFPYUDA1` magic prefix. Tuple shape: `(name, accumulator, input_schema_bytes, return_schema_bytes, state_schema_bytes, volatility_str)`.
* `PythonLogicalCodec.try_encode_udaf` / `try_decode_udaf` and the matching `PythonPhysicalCodec` methods consult the helpers first and fall back to `inner` for non-Python aggregate UDFs.

Test coverage in `test_pickle_expr.py::TestAggregateUDFCodec` mirrors the scalar / window UDF cases. 1094 root tests pass (up from 1088, plus 3 new UDAF cases and 3 new UDWF cases from the prior commit).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 89d119f commit 4b51402

3 files changed

Lines changed: 367 additions & 18 deletions

File tree

crates/core/src/codec.rs

Lines changed: 166 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,8 +85,8 @@ use datafusion::datasource::TableProvider;
8585
use datafusion::datasource::file_format::FileFormatFactory;
8686
use datafusion::execution::TaskContext;
8787
use datafusion::logical_expr::{
88-
AggregateUDF, Extension, LogicalPlan, ScalarUDF, ScalarUDFImpl, TypeSignature, WindowUDF,
89-
WindowUDFImpl,
88+
AggregateUDF, AggregateUDFImpl, Extension, LogicalPlan, ScalarUDF, ScalarUDFImpl,
89+
TypeSignature, WindowUDF, WindowUDFImpl,
9090
};
9191
use datafusion::physical_expr::PhysicalExpr;
9292
use datafusion::physical_plan::ExecutionPlan;
@@ -95,6 +95,7 @@ use datafusion_proto::physical_plan::{DefaultPhysicalExtensionCodec, PhysicalExt
9595
use pyo3::prelude::*;
9696
use pyo3::types::{PyBytes, PyTuple};
9797

98+
use crate::udaf::PythonFunctionAggregateUDF;
9899
use crate::udf::PythonFunctionScalarUDF;
99100
use crate::udwf::MultiColumnWindowUDF;
100101

@@ -104,6 +105,11 @@ use crate::udwf::MultiColumnWindowUDF;
104105
/// the encoder and decoder cannot drift.
105106
pub(crate) const PY_SCALAR_UDF_MAGIC: &[u8] = b"DFPYUDF1";
106107

108+
/// Wire-format prefix for an inlined Python aggregate UDF
109+
/// (cloudpickled tuple of name, accumulator factory, input schema,
110+
/// return type, state types schema, volatility).
111+
// Written ahead of the cloudpickle payload by `try_encode_python_agg_udf`;
// a buffer that does not begin with it makes `try_decode_python_agg_udf`
// return `Ok(None)`.
pub(crate) const PY_AGG_UDF_MAGIC: &[u8] = b"DFPYUDA1";
112+
107113
/// Wire-format prefix for an inlined Python window UDF (cloudpickled
108114
/// tuple of name, evaluator factory, input schema, return type,
109115
/// volatility).
@@ -205,10 +211,16 @@ impl LogicalExtensionCodec for PythonLogicalCodec {
205211
}
206212

207213
fn try_encode_udaf(&self, node: &AggregateUDF, buf: &mut Vec<u8>) -> Result<()> {
214+
if try_encode_python_agg_udf(node, buf)? {
215+
return Ok(());
216+
}
208217
self.inner.try_encode_udaf(node, buf)
209218
}
210219

211220
fn try_decode_udaf(&self, name: &str, buf: &[u8]) -> Result<Arc<AggregateUDF>> {
221+
if let Some(udaf) = try_decode_python_agg_udf(buf)? {
222+
return Ok(udaf);
223+
}
212224
self.inner.try_decode_udaf(name, buf)
213225
}
214226

@@ -301,10 +313,16 @@ impl PhysicalExtensionCodec for PythonPhysicalCodec {
301313
}
302314

303315
fn try_encode_udaf(&self, node: &AggregateUDF, buf: &mut Vec<u8>) -> Result<()> {
316+
if try_encode_python_agg_udf(node, buf)? {
317+
return Ok(());
318+
}
304319
self.inner.try_encode_udaf(node, buf)
305320
}
306321

307322
fn try_decode_udaf(&self, name: &str, buf: &[u8]) -> Result<Arc<AggregateUDF>> {
323+
if let Some(udaf) = try_decode_python_agg_udf(buf)? {
324+
return Ok(udaf);
325+
}
308326
self.inner.try_decode_udaf(name, buf)
309327
}
310328

@@ -613,3 +631,149 @@ fn decode_python_window_udf(py: Python<'_>, payload: &[u8]) -> PyResult<MultiCol
613631
volatility,
614632
))
615633
}
634+
635+
// =============================================================================
636+
// Shared Python aggregate UDF encode / decode helpers
637+
//
638+
// Cloudpickle tuple shape: `(name, accumulator_factory, input_schema_bytes,
639+
// return_type_bytes, state_schema_bytes, volatility_str)`. The accumulator
640+
// factory is the Python callable that produces a new accumulator instance
641+
// per partition.
642+
// =============================================================================
643+
644+
pub(crate) fn try_encode_python_agg_udf(node: &AggregateUDF, buf: &mut Vec<u8>) -> Result<bool> {
645+
let Some(py_udf) = node
646+
.inner()
647+
.as_any()
648+
.downcast_ref::<PythonFunctionAggregateUDF>()
649+
else {
650+
return Ok(false);
651+
};
652+
653+
Python::attach(|py| -> Result<bool> {
654+
let bytes = encode_python_agg_udf(py, py_udf)
655+
.map_err(|e| datafusion::error::DataFusionError::External(Box::new(e)))?;
656+
buf.extend_from_slice(PY_AGG_UDF_MAGIC);
657+
buf.extend_from_slice(&bytes);
658+
Ok(true)
659+
})
660+
}
661+
662+
pub(crate) fn try_decode_python_agg_udf(buf: &[u8]) -> Result<Option<Arc<AggregateUDF>>> {
663+
if buf.is_empty() || !buf.starts_with(PY_AGG_UDF_MAGIC) {
664+
return Ok(None);
665+
}
666+
let payload = &buf[PY_AGG_UDF_MAGIC.len()..];
667+
668+
Python::attach(|py| -> Result<Option<Arc<AggregateUDF>>> {
669+
let udf = decode_python_agg_udf(py, payload)
670+
.map_err(|e| datafusion::error::DataFusionError::External(Box::new(e)))?;
671+
Ok(Some(Arc::new(AggregateUDF::new_from_impl(udf))))
672+
})
673+
}
674+
675+
fn encode_python_agg_udf(py: Python<'_>, udf: &PythonFunctionAggregateUDF) -> PyResult<Vec<u8>> {
676+
let cloudpickle = py.import("cloudpickle")?;
677+
678+
let signature = AggregateUDFImpl::signature(udf);
679+
let input_dtypes: Vec<arrow::datatypes::DataType> = match &signature.type_signature {
680+
TypeSignature::Exact(types) => types.clone(),
681+
other => {
682+
return Err(pyo3::exceptions::PyValueError::new_err(format!(
683+
"PythonFunctionAggregateUDF expected Signature::Exact, got {other:?}"
684+
)));
685+
}
686+
};
687+
let input_fields: Vec<Field> = input_dtypes
688+
.into_iter()
689+
.enumerate()
690+
.map(|(i, dt)| Field::new(format!("arg_{i}"), dt, true))
691+
.collect();
692+
let input_schema_bytes = schema_to_ipc_bytes(&Schema::new(input_fields))
693+
.map_err(|e| pyo3::exceptions::PyValueError::new_err(format!("{e}")))?;
694+
695+
let return_schema = Schema::new(vec![Field::new("result", udf.return_type().clone(), true)]);
696+
let return_schema_bytes = schema_to_ipc_bytes(&return_schema)
697+
.map_err(|e| pyo3::exceptions::PyValueError::new_err(format!("{e}")))?;
698+
699+
let state_fields: Vec<Field> = udf
700+
.state_fields_ref()
701+
.iter()
702+
.map(|f| f.as_ref().clone())
703+
.collect();
704+
let state_schema_bytes = schema_to_ipc_bytes(&Schema::new(state_fields))
705+
.map_err(|e| pyo3::exceptions::PyValueError::new_err(format!("{e}")))?;
706+
707+
let volatility = format!("{:?}", signature.volatility).to_lowercase();
708+
709+
let payload = PyTuple::new(
710+
py,
711+
[
712+
AggregateUDFImpl::name(udf).into_pyobject(py)?.into_any(),
713+
udf.accumulator().bind(py).clone().into_any(),
714+
PyBytes::new(py, &input_schema_bytes).into_any(),
715+
PyBytes::new(py, &return_schema_bytes).into_any(),
716+
PyBytes::new(py, &state_schema_bytes).into_any(),
717+
volatility.into_pyobject(py)?.into_any(),
718+
],
719+
)?;
720+
721+
let blob = cloudpickle.call_method1("dumps", (payload,))?;
722+
blob.extract::<Vec<u8>>()
723+
}
724+
725+
fn decode_python_agg_udf(py: Python<'_>, payload: &[u8]) -> PyResult<PythonFunctionAggregateUDF> {
726+
let cloudpickle = py.import("cloudpickle")?;
727+
728+
let tuple = cloudpickle
729+
.call_method1("loads", (PyBytes::new(py, payload),))?
730+
.cast_into::<PyTuple>()?;
731+
732+
let name: String = tuple.get_item(0)?.extract()?;
733+
let accumulator: Py<PyAny> = tuple.get_item(1)?.unbind();
734+
let input_schema_bytes: Vec<u8> = tuple.get_item(2)?.extract()?;
735+
let return_schema_bytes: Vec<u8> = tuple.get_item(3)?.extract()?;
736+
let state_schema_bytes: Vec<u8> = tuple.get_item(4)?.extract()?;
737+
let volatility_str: String = tuple.get_item(5)?.extract()?;
738+
739+
let input_schema = schema_from_ipc_bytes(&input_schema_bytes)
740+
.map_err(|e| pyo3::exceptions::PyValueError::new_err(format!("{e}")))?;
741+
let input_types: Vec<arrow::datatypes::DataType> = input_schema
742+
.fields()
743+
.iter()
744+
.map(|f| f.data_type().clone())
745+
.collect();
746+
747+
let return_schema = schema_from_ipc_bytes(&return_schema_bytes)
748+
.map_err(|e| pyo3::exceptions::PyValueError::new_err(format!("{e}")))?;
749+
let return_type = return_schema
750+
.fields()
751+
.first()
752+
.ok_or_else(|| {
753+
pyo3::exceptions::PyValueError::new_err(
754+
"PythonFunctionAggregateUDF return schema must contain exactly one field",
755+
)
756+
})?
757+
.data_type()
758+
.clone();
759+
760+
let state_schema = schema_from_ipc_bytes(&state_schema_bytes)
761+
.map_err(|e| pyo3::exceptions::PyValueError::new_err(format!("{e}")))?;
762+
let state_types: Vec<arrow::datatypes::DataType> = state_schema
763+
.fields()
764+
.iter()
765+
.map(|f| f.data_type().clone())
766+
.collect();
767+
768+
let volatility = datafusion_python_util::parse_volatility(&volatility_str)
769+
.map_err(|e| pyo3::exceptions::PyValueError::new_err(format!("{e}")))?;
770+
771+
Ok(PythonFunctionAggregateUDF::from_parts(
772+
name,
773+
accumulator,
774+
input_types,
775+
return_type,
776+
state_types,
777+
volatility,
778+
))
779+
}

0 commit comments

Comments
 (0)