Skip to content

Commit 89d119f

Browse files
timsaucerclaude
andcommitted
feat: inline encoding for Python window UDFs
Window UDFs no longer need worker-side pre-registration. The codec serializes the Python evaluator factory into the wire format and the receiver reconstructs the UDF from bytes alone, same as scalar UDFs. Refactor `MultiColumnWindowUDF` to store the Python evaluator callable directly (`evaluator: Py<PyAny>`) instead of a `PartitionEvaluatorFactory` closure. The factory closure was a boxed `Fn` that captured the Python state opaquely, with nothing for the codec to downcast back to. Now the named struct holds the `Py<PyAny>` and builds a partition evaluator inside `partition_evaluator()` on demand. `PyWindowUDF::new` constructs `MultiColumnWindowUDF` directly with the evaluator. `to_rust_partition_evaluator` is replaced by `instantiate_partition_evaluator`, called from the trait method. Codec wiring: * `crates/core/src/codec.rs` adds `try_encode_python_window_udf` / `try_decode_python_window_udf` plus the `DFPYUDW1` magic prefix. * `PythonLogicalCodec.try_encode_udwf` / `try_decode_udwf` and the matching `PythonPhysicalCodec` methods consult the helpers first and fall back to `inner` for non-Python window UDFs. Test coverage in `test_pickle_expr.py::TestWindowUDFCodec` mirrors the scalar UDF cases: self-contained blob, decode into fresh context, decode via pickle with no worker context. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent cc5ce7e commit 89d119f

3 files changed

Lines changed: 248 additions & 30 deletions

File tree

crates/core/src/codec.rs

Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ use datafusion::datasource::file_format::FileFormatFactory;
8686
use datafusion::execution::TaskContext;
8787
use datafusion::logical_expr::{
8888
AggregateUDF, Extension, LogicalPlan, ScalarUDF, ScalarUDFImpl, TypeSignature, WindowUDF,
89+
WindowUDFImpl,
8990
};
9091
use datafusion::physical_expr::PhysicalExpr;
9192
use datafusion::physical_plan::ExecutionPlan;
@@ -95,13 +96,19 @@ use pyo3::prelude::*;
9596
use pyo3::types::{PyBytes, PyTuple};
9697

9798
use crate::udf::PythonFunctionScalarUDF;
99+
use crate::udwf::MultiColumnWindowUDF;
98100

99101
/// Wire-format prefix that tags a `fun_definition` payload as an
100102
/// inlined Python scalar UDF (cloudpickled tuple of name, callable,
101103
/// input schema, return field, volatility). Defined once here so
102104
/// the encoder and decoder cannot drift.
103105
pub(crate) const PY_SCALAR_UDF_MAGIC: &[u8] = b"DFPYUDF1";
104106

107+
/// Wire-format prefix for an inlined Python window UDF (cloudpickled
108+
/// tuple of name, evaluator factory, input schema, return type,
109+
/// volatility).
110+
pub(crate) const PY_WINDOW_UDF_MAGIC: &[u8] = b"DFPYUDW1";
111+
105112
/// `LogicalExtensionCodec` parked on every `SessionContext`. Holds
106113
/// the Python-aware encoding hooks for logical-layer types
107114
/// (`LogicalPlan`, `Expr`) and delegates everything it does not
@@ -206,10 +213,16 @@ impl LogicalExtensionCodec for PythonLogicalCodec {
206213
}
207214

208215
fn try_encode_udwf(&self, node: &WindowUDF, buf: &mut Vec<u8>) -> Result<()> {
216+
if try_encode_python_window_udf(node, buf)? {
217+
return Ok(());
218+
}
209219
self.inner.try_encode_udwf(node, buf)
210220
}
211221

212222
fn try_decode_udwf(&self, name: &str, buf: &[u8]) -> Result<Arc<WindowUDF>> {
223+
if let Some(udwf) = try_decode_python_window_udf(buf)? {
224+
return Ok(udwf);
225+
}
213226
self.inner.try_decode_udwf(name, buf)
214227
}
215228
}
@@ -296,10 +309,16 @@ impl PhysicalExtensionCodec for PythonPhysicalCodec {
296309
}
297310

298311
fn try_encode_udwf(&self, node: &WindowUDF, buf: &mut Vec<u8>) -> Result<()> {
312+
if try_encode_python_window_udf(node, buf)? {
313+
return Ok(());
314+
}
299315
self.inner.try_encode_udwf(node, buf)
300316
}
301317

302318
fn try_decode_udwf(&self, name: &str, buf: &[u8]) -> Result<Arc<WindowUDF>> {
319+
if let Some(udwf) = try_decode_python_window_udf(buf)? {
320+
return Ok(udwf);
321+
}
303322
self.inner.try_decode_udwf(name, buf)
304323
}
305324
}
@@ -471,3 +490,126 @@ fn schema_from_ipc_bytes(bytes: &[u8]) -> arrow::error::Result<Schema> {
471490
let reader = StreamReader::try_new(std::io::Cursor::new(bytes), None)?;
472491
Ok(reader.schema().as_ref().clone())
473492
}
493+
494+
// =============================================================================
495+
// Shared Python window UDF encode / decode helpers
496+
//
497+
// Cloudpickle tuple shape: `(name, evaluator_factory, input_schema_bytes,
498+
// return_schema_bytes, volatility_str)`. The evaluator factory is the
499+
// Python callable that produces a new evaluator instance per partition.
500+
// =============================================================================
501+
502+
pub(crate) fn try_encode_python_window_udf(node: &WindowUDF, buf: &mut Vec<u8>) -> Result<bool> {
503+
let Some(py_udf) = node.inner().as_any().downcast_ref::<MultiColumnWindowUDF>() else {
504+
return Ok(false);
505+
};
506+
507+
Python::attach(|py| -> Result<bool> {
508+
let bytes = encode_python_window_udf(py, py_udf)
509+
.map_err(|e| datafusion::error::DataFusionError::External(Box::new(e)))?;
510+
buf.extend_from_slice(PY_WINDOW_UDF_MAGIC);
511+
buf.extend_from_slice(&bytes);
512+
Ok(true)
513+
})
514+
}
515+
516+
pub(crate) fn try_decode_python_window_udf(buf: &[u8]) -> Result<Option<Arc<WindowUDF>>> {
517+
if buf.is_empty() || !buf.starts_with(PY_WINDOW_UDF_MAGIC) {
518+
return Ok(None);
519+
}
520+
let payload = &buf[PY_WINDOW_UDF_MAGIC.len()..];
521+
522+
Python::attach(|py| -> Result<Option<Arc<WindowUDF>>> {
523+
let udf = decode_python_window_udf(py, payload)
524+
.map_err(|e| datafusion::error::DataFusionError::External(Box::new(e)))?;
525+
Ok(Some(Arc::new(WindowUDF::new_from_impl(udf))))
526+
})
527+
}
528+
529+
fn encode_python_window_udf(py: Python<'_>, udf: &MultiColumnWindowUDF) -> PyResult<Vec<u8>> {
530+
let cloudpickle = py.import("cloudpickle")?;
531+
532+
let signature = WindowUDFImpl::signature(udf);
533+
let input_dtypes: Vec<arrow::datatypes::DataType> = match &signature.type_signature {
534+
TypeSignature::Exact(types) => types.clone(),
535+
other => {
536+
return Err(pyo3::exceptions::PyValueError::new_err(format!(
537+
"MultiColumnWindowUDF expected Signature::Exact, got {other:?}"
538+
)));
539+
}
540+
};
541+
let input_fields: Vec<Field> = input_dtypes
542+
.into_iter()
543+
.enumerate()
544+
.map(|(i, dt)| Field::new(format!("arg_{i}"), dt, true))
545+
.collect();
546+
let input_schema = Schema::new(input_fields);
547+
let input_schema_bytes = schema_to_ipc_bytes(&input_schema)
548+
.map_err(|e| pyo3::exceptions::PyValueError::new_err(format!("{e}")))?;
549+
550+
let return_schema = Schema::new(vec![Field::new("result", udf.return_type().clone(), true)]);
551+
let return_schema_bytes = schema_to_ipc_bytes(&return_schema)
552+
.map_err(|e| pyo3::exceptions::PyValueError::new_err(format!("{e}")))?;
553+
554+
let volatility = format!("{:?}", signature.volatility).to_lowercase();
555+
556+
let payload = PyTuple::new(
557+
py,
558+
[
559+
WindowUDFImpl::name(udf).into_pyobject(py)?.into_any(),
560+
udf.evaluator().bind(py).clone().into_any(),
561+
PyBytes::new(py, &input_schema_bytes).into_any(),
562+
PyBytes::new(py, &return_schema_bytes).into_any(),
563+
volatility.into_pyobject(py)?.into_any(),
564+
],
565+
)?;
566+
567+
let blob = cloudpickle.call_method1("dumps", (payload,))?;
568+
blob.extract::<Vec<u8>>()
569+
}
570+
571+
fn decode_python_window_udf(py: Python<'_>, payload: &[u8]) -> PyResult<MultiColumnWindowUDF> {
572+
let cloudpickle = py.import("cloudpickle")?;
573+
574+
let tuple = cloudpickle
575+
.call_method1("loads", (PyBytes::new(py, payload),))?
576+
.cast_into::<PyTuple>()?;
577+
578+
let name: String = tuple.get_item(0)?.extract()?;
579+
let evaluator: Py<PyAny> = tuple.get_item(1)?.unbind();
580+
let input_schema_bytes: Vec<u8> = tuple.get_item(2)?.extract()?;
581+
let return_schema_bytes: Vec<u8> = tuple.get_item(3)?.extract()?;
582+
let volatility_str: String = tuple.get_item(4)?.extract()?;
583+
584+
let input_schema = schema_from_ipc_bytes(&input_schema_bytes)
585+
.map_err(|e| pyo3::exceptions::PyValueError::new_err(format!("{e}")))?;
586+
let input_types: Vec<arrow::datatypes::DataType> = input_schema
587+
.fields()
588+
.iter()
589+
.map(|f| f.data_type().clone())
590+
.collect();
591+
592+
let return_schema = schema_from_ipc_bytes(&return_schema_bytes)
593+
.map_err(|e| pyo3::exceptions::PyValueError::new_err(format!("{e}")))?;
594+
let return_type = return_schema
595+
.fields()
596+
.first()
597+
.ok_or_else(|| {
598+
pyo3::exceptions::PyValueError::new_err(
599+
"MultiColumnWindowUDF return schema must contain exactly one field",
600+
)
601+
})?
602+
.data_type()
603+
.clone();
604+
605+
let volatility = datafusion_python_util::parse_volatility(&volatility_str)
606+
.map_err(|e| pyo3::exceptions::PyValueError::new_err(format!("{e}")))?;
607+
608+
Ok(MultiColumnWindowUDF::from_parts(
609+
name,
610+
evaluator,
611+
input_types,
612+
return_type,
613+
volatility,
614+
))
615+
}

crates/core/src/udwf.rs

Lines changed: 63 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,9 @@ use datafusion::arrow::datatypes::DataType;
2525
use datafusion::arrow::pyarrow::{FromPyArrow, PyArrowType, ToPyArrow};
2626
use datafusion::error::{DataFusionError, Result};
2727
use datafusion::logical_expr::function::{PartitionEvaluatorArgs, WindowUDFFieldArgs};
28-
use datafusion::logical_expr::ptr_eq::PtrEq;
2928
use datafusion::logical_expr::window_state::WindowAggState;
3029
use datafusion::logical_expr::{
31-
PartitionEvaluator, PartitionEvaluatorFactory, Signature, Volatility, WindowUDF, WindowUDFImpl,
30+
PartitionEvaluator, Signature, Volatility, WindowUDF, WindowUDFImpl,
3231
};
3332
use datafusion::scalar::ScalarValue;
3433
use datafusion_ffi::udwf::FFI_WindowUDF;
@@ -198,15 +197,13 @@ impl PartitionEvaluator for RustPartitionEvaluator {
198197
}
199198
}
200199

201-
pub fn to_rust_partition_evaluator(evaluator: Py<PyAny>) -> PartitionEvaluatorFactory {
202-
Arc::new(move || -> Result<Box<dyn PartitionEvaluator>> {
203-
let evaluator = Python::attach(|py| {
204-
evaluator
205-
.call0(py)
206-
.map_err(|e| DataFusionError::Execution(e.to_string()))
207-
})?;
208-
Ok(Box::new(RustPartitionEvaluator::new(evaluator)))
209-
})
200+
fn instantiate_partition_evaluator(evaluator: &Py<PyAny>) -> Result<Box<dyn PartitionEvaluator>> {
201+
let instance = Python::attach(|py| {
202+
evaluator
203+
.call0(py)
204+
.map_err(|e| DataFusionError::Execution(e.to_string()))
205+
})?;
206+
Ok(Box::new(RustPartitionEvaluator::new(instance)))
210207
}
211208

212209
/// Represents a WindowUDF
@@ -234,14 +231,14 @@ impl PyWindowUDF {
234231
volatility: &str,
235232
) -> PyResult<Self> {
236233
let return_type = return_type.0;
237-
let input_types = input_types.into_iter().map(|t| t.0).collect();
234+
let input_types: Vec<DataType> = input_types.into_iter().map(|t| t.0).collect();
238235

239236
let function = WindowUDF::from(MultiColumnWindowUDF::new(
240237
name,
238+
evaluator,
241239
input_types,
242240
return_type,
243241
parse_volatility(volatility)?,
244-
to_rust_partition_evaluator(evaluator),
245242
));
246243
Ok(Self { function })
247244
}
@@ -278,42 +275,79 @@ impl PyWindowUDF {
278275
}
279276
}
280277

281-
#[derive(Hash, Eq, PartialEq)]
278+
#[derive(Debug)]
282279
pub struct MultiColumnWindowUDF {
283280
name: String,
281+
evaluator: Py<PyAny>,
284282
signature: Signature,
285283
return_type: DataType,
286-
partition_evaluator_factory: PtrEq<PartitionEvaluatorFactory>,
287-
}
288-
289-
impl std::fmt::Debug for MultiColumnWindowUDF {
290-
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
291-
f.debug_struct("WindowUDF")
292-
.field("name", &self.name)
293-
.field("signature", &self.signature)
294-
.field("return_type", &"<func>")
295-
.field("partition_evaluator_factory", &"<FUNC>")
296-
.finish()
297-
}
298284
}
299285

300286
impl MultiColumnWindowUDF {
301287
pub fn new(
302288
name: impl Into<String>,
289+
evaluator: Py<PyAny>,
303290
input_types: Vec<DataType>,
304291
return_type: DataType,
305292
volatility: Volatility,
306-
partition_evaluator_factory: PartitionEvaluatorFactory,
307293
) -> Self {
308294
let name = name.into();
309295
let signature = Signature::exact(input_types, volatility);
310296
Self {
311297
name,
298+
evaluator,
312299
signature,
313300
return_type,
314-
partition_evaluator_factory: partition_evaluator_factory.into(),
315301
}
316302
}
303+
304+
/// Stored Python callable that produces a fresh partition
305+
/// evaluator instance per partition. Consumed by the codec to
306+
/// cloudpickle the evaluator factory across process boundaries.
/// Returns a borrow of the stored handle; nothing is called or cloned here.
307+
pub(crate) fn evaluator(&self) -> &Py<PyAny> {
308+
&self.evaluator
309+
}
310+
311+
/// Borrows the return `DataType` stored on this UDF at construction.
pub(crate) fn return_type(&self) -> &DataType {
312+
&self.return_type
313+
}
314+
315+
/// Builds a UDF from already-separated components.
///
/// Forwards every argument, in the same order, to `Self::new`; the only
/// difference visible here is that `name` is a concrete `String`.
pub(crate) fn from_parts(
316+
name: String,
317+
evaluator: Py<PyAny>,
318+
input_types: Vec<DataType>,
319+
return_type: DataType,
320+
volatility: Volatility,
321+
) -> Self {
322+
Self::new(name, evaluator, input_types, return_type, volatility)
323+
}
324+
}
325+
326+
impl Eq for MultiColumnWindowUDF {}
327+
impl PartialEq for MultiColumnWindowUDF {
328+
fn eq(&self, other: &Self) -> bool {
329+
self.name == other.name
330+
&& self.signature == other.signature
331+
&& self.return_type == other.return_type
332+
&& Python::attach(|py| {
333+
self.evaluator
334+
.bind(py)
335+
.eq(other.evaluator.bind(py))
336+
.unwrap_or(false)
337+
})
338+
}
339+
}
340+
341+
impl std::hash::Hash for MultiColumnWindowUDF {
342+
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
343+
self.name.hash(state);
344+
self.signature.hash(state);
345+
self.return_type.hash(state);
346+
Python::attach(|py| {
347+
let py_hash = self.evaluator.bind(py).hash().unwrap_or(0);
348+
state.write_isize(py_hash);
349+
});
350+
}
317351
}
318352

319353
impl WindowUDFImpl for MultiColumnWindowUDF {
@@ -339,7 +373,6 @@ impl WindowUDFImpl for MultiColumnWindowUDF {
339373
&self,
340374
_partition_evaluator_args: PartitionEvaluatorArgs,
341375
) -> Result<Box<dyn PartitionEvaluator>> {
342-
let _ = _partition_evaluator_args;
343-
(self.partition_evaluator_factory)()
376+
instantiate_partition_evaluator(&self.evaluator)
344377
}
345378
}

python/tests/test_pickle_expr.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,49 @@ def fn(arr):
130130
assert decoded.canonical_name() == e.canonical_name()
131131

132132

133+
class TestWindowUDFCodec:
    """Round-trip checks for Python window UDFs encoded inline, as scalar UDFs are."""

    def _build_window_udf(self):
        """Return a window UDF named ``count_up`` backed by a locally defined evaluator."""
        from datafusion import udwf
        from datafusion.user_defined import WindowEvaluator

        class CountUpEvaluator(WindowEvaluator):
            def evaluate_all(self, values, num_rows):
                return pa.array(list(range(num_rows)))

        return udwf(
            CountUpEvaluator,
            [pa.int64()],
            pa.int64(),
            "immutable",
            name="count_up",
        )

    def test_window_udf_self_contained_blob(self):
        """The pickled expression is larger than 200 bytes."""
        window_fn = self._build_window_udf()
        expr = window_fn(col("a"))
        pickled = pickle.dumps(expr)
        assert len(pickled) > 200

    def test_window_udf_decodes_into_fresh_ctx(self):
        """Bytes from ``to_bytes`` decode against a brand-new ``SessionContext``."""
        window_fn = self._build_window_udf()
        expr = window_fn(col("a"))
        encoded = expr.to_bytes()
        new_ctx = SessionContext()
        from datafusion import Expr

        restored = Expr.from_bytes(encoded, ctx=new_ctx)
        assert "count_up" in restored.canonical_name()

    def test_window_udf_decodes_via_pickle_with_no_worker_ctx(self):
        """A plain pickle round trip keeps the UDF name in the canonical name."""
        window_fn = self._build_window_udf()
        expr = window_fn(col("a"))
        pickled = pickle.dumps(expr)
        restored = pickle.loads(pickled)
        assert "count_up" in restored.canonical_name()
174+
175+
133176
class TestWorkerCtxLifecycle:
134177
def test_set_and_clear(self):
135178
assert get_worker_ctx() is None

0 commit comments

Comments
 (0)