Skip to content

Commit cc1bff7

Browse files
timsaucerclaude
andcommitted
feat: pickle support for Expr via PythonLogicalCodec inline encoding
Builds on the codec consistency work in feat/proto-codecs. Python scalar UDFs are cloudpickled inline into the proto `fun_definition` field by PythonLogicalCodec / PythonPhysicalCodec, so a pickled Expr that references a Python `udf()` reconstructs on the receiver with no pre-registration. UDAFs, UDWFs, and FFI-imported UDFs still resolve through the receiver's session. Rust: * `PythonFunctionScalarUDF` regains the `func()` / `input_fields()` / `return_field()` / `volatility()` / `from_parts()` accessors the codec needs. * `crates/core/src/codec.rs` adds shared `try_encode_python_scalar_udf` / `try_decode_python_scalar_udf` helpers built on cloudpickle + pyarrow IPC for the input schema. Both `PythonLogicalCodec.try_encode_udf` and `PythonPhysicalCodec.try_encode_udf` consult the helper first and fall back to `inner` for non-Python UDFs (and the receiver's function registry on decode if the prefix does not match). Python: * `datafusion.ipc` module: thread-local `set_worker_ctx` / `clear_worker_ctx` / `get_worker_ctx` for installing a receiver `SessionContext` on a worker process. `_resolve_ctx` returns explicit > worker > fresh. * `Expr.__reduce__` returns `(Expr._reconstruct, (self.to_bytes(),))`. `_reconstruct` calls `Expr.from_bytes(buf, ctx=None)` which consults the worker context. * `Expr.from_bytes` signature switches to `(buf, ctx=None)` (was `(ctx, buf)`); no callers in main, only PR1 tests which are updated. * `datafusion.ipc` exported from the top-level package. Dependencies: * `cloudpickle>=2.0` added as a runtime dep. Lazy-imported on the encode / decode hot paths — users who never pickle a plan or expression pay only the install footprint, not import-time cost. * ruff `S301` added to the test-suite + examples ignore lists (legitimate `pickle.loads` use). Tests: * `test_pickle_expr.py` — 11 cases covering built-in expr pickle, scalar UDF self-contained blobs, closure-capturing UDFs, worker ctx lifecycle, thread-local isolation. * `test_pickle_multiprocessing.py` + `_pickle_multiprocessing_helpers.py` — parametrized over `fork`/`forkserver`/`spawn` start methods. 9 cases. Auto-skip when the sandbox blocks semaphore creation; CI runs the full matrix. * `test_expr.py` — existing `from_bytes` tests updated to new signature. 1088 root tests pass (up from 1077), 13 skipped (up from 4, the new mp cases skip locally under sandboxed semaphores). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent baef8f0 commit cc1bff7

11 files changed

Lines changed: 1565 additions & 865 deletions

File tree

crates/core/src/codec.rs

Lines changed: 146 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -77,22 +77,30 @@
7777
7878
use std::sync::Arc;
7979

80-
use arrow::datatypes::SchemaRef;
80+
use arrow::datatypes::{Field, Schema, SchemaRef};
81+
use arrow::pyarrow::ToPyArrow;
82+
use datafusion::arrow::pyarrow::FromPyArrow;
8183
use datafusion::common::{Result, TableReference};
8284
use datafusion::datasource::TableProvider;
8385
use datafusion::datasource::file_format::FileFormatFactory;
8486
use datafusion::execution::TaskContext;
85-
use datafusion::logical_expr::{AggregateUDF, Extension, LogicalPlan, ScalarUDF, WindowUDF};
87+
use datafusion::logical_expr::{
88+
AggregateUDF, Extension, LogicalPlan, ScalarUDF, ScalarUDFImpl, WindowUDF,
89+
};
8690
use datafusion::physical_expr::PhysicalExpr;
8791
use datafusion::physical_plan::ExecutionPlan;
8892
use datafusion_proto::logical_plan::{DefaultLogicalExtensionCodec, LogicalExtensionCodec};
8993
use datafusion_proto::physical_plan::{DefaultPhysicalExtensionCodec, PhysicalExtensionCodec};
94+
use pyo3::BoundObject;
95+
use pyo3::prelude::*;
96+
use pyo3::types::{PyBytes, PyTuple};
97+
98+
use crate::udf::PythonFunctionScalarUDF;
9099

91100
/// Wire-format prefix that tags a `fun_definition` payload as an
92101
/// inlined Python scalar UDF (cloudpickled tuple of name, callable,
93102
/// input schema, return field, volatility). Defined once here so
94103
/// the encoder and decoder cannot drift.
95-
#[allow(dead_code)]
96104
pub(crate) const PY_SCALAR_UDF_MAGIC: &[u8] = b"DFPYUDF1";
97105

98106
/// `LogicalExtensionCodec` parked on every `SessionContext`. Holds
@@ -177,10 +185,16 @@ impl LogicalExtensionCodec for PythonLogicalCodec {
177185
}
178186

179187
fn try_encode_udf(&self, node: &ScalarUDF, buf: &mut Vec<u8>) -> Result<()> {
188+
if try_encode_python_scalar_udf(node, buf)? {
189+
return Ok(());
190+
}
180191
self.inner.try_encode_udf(node, buf)
181192
}
182193

183194
fn try_decode_udf(&self, name: &str, buf: &[u8]) -> Result<Arc<ScalarUDF>> {
195+
if let Some(udf) = try_decode_python_scalar_udf(buf)? {
196+
return Ok(udf);
197+
}
184198
self.inner.try_decode_udf(name, buf)
185199
}
186200

@@ -249,10 +263,16 @@ impl PhysicalExtensionCodec for PythonPhysicalCodec {
249263
}
250264

251265
fn try_encode_udf(&self, node: &ScalarUDF, buf: &mut Vec<u8>) -> Result<()> {
266+
if try_encode_python_scalar_udf(node, buf)? {
267+
return Ok(());
268+
}
252269
self.inner.try_encode_udf(node, buf)
253270
}
254271

255272
fn try_decode_udf(&self, name: &str, buf: &[u8]) -> Result<Arc<ScalarUDF>> {
273+
if let Some(udf) = try_decode_python_scalar_udf(buf)? {
274+
return Ok(udf);
275+
}
256276
self.inner.try_decode_udf(name, buf)
257277
}
258278

@@ -284,3 +304,126 @@ impl PhysicalExtensionCodec for PythonPhysicalCodec {
284304
self.inner.try_decode_udwf(name, buf)
285305
}
286306
}
307+
308+
// =============================================================================
309+
// Shared Python scalar UDF encode / decode helpers
310+
//
311+
// Both `PythonLogicalCodec` and `PythonPhysicalCodec` consult these on
312+
// every `try_encode_udf` / `try_decode_udf` call. Same wire format on
313+
// both layers — a Python `ScalarUDF` referenced inside a `LogicalPlan`
314+
// or an `ExecutionPlan` round-trips identically.
315+
// =============================================================================
316+
317+
/// Encode a Python scalar UDF inline if `node` is one. Returns
318+
/// `Ok(true)` when the payload (`DFPYUDF1` prefix + cloudpickled
319+
/// tuple) was written and the caller should skip its inner codec.
320+
/// Returns `Ok(false)` for any non-Python UDF, signalling the caller
321+
/// to delegate to its `inner`.
322+
pub(crate) fn try_encode_python_scalar_udf(node: &ScalarUDF, buf: &mut Vec<u8>) -> Result<bool> {
323+
let Some(py_udf) = node
324+
.inner()
325+
.as_any()
326+
.downcast_ref::<PythonFunctionScalarUDF>()
327+
else {
328+
return Ok(false);
329+
};
330+
331+
Python::attach(|py| -> Result<bool> {
332+
let bytes = encode_python_scalar_udf(py, py_udf)
333+
.map_err(|e| datafusion::error::DataFusionError::External(Box::new(e)))?;
334+
buf.extend_from_slice(PY_SCALAR_UDF_MAGIC);
335+
buf.extend_from_slice(&bytes);
336+
Ok(true)
337+
})
338+
}
339+
340+
/// Decode an inline Python scalar UDF payload. Returns `Ok(None)`
341+
/// when `buf` does not carry the `DFPYUDF1` prefix, signalling the
342+
/// caller to delegate to its `inner` codec (and eventually the
343+
/// `FunctionRegistry`).
344+
pub(crate) fn try_decode_python_scalar_udf(buf: &[u8]) -> Result<Option<Arc<ScalarUDF>>> {
345+
if buf.is_empty() || !buf.starts_with(PY_SCALAR_UDF_MAGIC) {
346+
return Ok(None);
347+
}
348+
let payload = &buf[PY_SCALAR_UDF_MAGIC.len()..];
349+
350+
Python::attach(|py| -> Result<Option<Arc<ScalarUDF>>> {
351+
let udf = decode_python_scalar_udf(py, payload)
352+
.map_err(|e| datafusion::error::DataFusionError::External(Box::new(e)))?;
353+
Ok(Some(Arc::new(ScalarUDF::new_from_impl(udf))))
354+
})
355+
}
356+
357+
/// Build the cloudpickle payload for a `PythonFunctionScalarUDF`.
358+
///
359+
/// Layout: `cloudpickle.dumps((name, func, input_schema_bytes,
360+
/// return_field, volatility_str))`. Input fields ride along as an
361+
/// IPC-encoded pyarrow Schema so they round-trip without extra
362+
/// plumbing.
363+
fn encode_python_scalar_udf(py: Python<'_>, udf: &PythonFunctionScalarUDF) -> PyResult<Vec<u8>> {
364+
let cloudpickle = py.import("cloudpickle")?;
365+
366+
let input_schema = Schema::new(udf.input_fields().to_vec());
367+
let pa_schema_obj = input_schema.to_pyarrow(py)?;
368+
let pa_schema = pa_schema_obj.into_bound();
369+
let schema_bytes: Vec<u8> = pa_schema
370+
.call_method0("serialize")?
371+
.call_method0("to_pybytes")?
372+
.extract()?;
373+
374+
let return_field_obj = udf.return_field().as_ref().to_pyarrow(py)?;
375+
let volatility = format!("{:?}", udf.volatility()).to_lowercase();
376+
377+
let payload = PyTuple::new(
378+
py,
379+
[
380+
udf.name().into_pyobject(py)?.into_any(),
381+
udf.func().bind(py).clone().into_any(),
382+
PyBytes::new(py, &schema_bytes).into_any(),
383+
return_field_obj.into_bound(),
384+
volatility.into_pyobject(py)?.into_any(),
385+
],
386+
)?;
387+
388+
let blob = cloudpickle.call_method1("dumps", (payload,))?;
389+
blob.extract::<Vec<u8>>()
390+
}
391+
392+
/// Inverse of [`encode_python_scalar_udf`].
393+
fn decode_python_scalar_udf(py: Python<'_>, payload: &[u8]) -> PyResult<PythonFunctionScalarUDF> {
394+
let cloudpickle = py.import("cloudpickle")?;
395+
let pyarrow = py.import("pyarrow")?;
396+
397+
let tuple = cloudpickle
398+
.call_method1("loads", (PyBytes::new(py, payload),))?
399+
.cast_into::<PyTuple>()?;
400+
401+
let name: String = tuple.get_item(0)?.extract()?;
402+
let func: Py<PyAny> = tuple.get_item(1)?.unbind();
403+
let schema_bytes: Vec<u8> = tuple.get_item(2)?.extract()?;
404+
let return_field_py = tuple.get_item(3)?;
405+
let volatility_str: String = tuple.get_item(4)?.extract()?;
406+
407+
let buffer = pyarrow.call_method1("py_buffer", (PyBytes::new(py, &schema_bytes),))?;
408+
let pa_schema = pyarrow
409+
.getattr("ipc")?
410+
.call_method1("read_schema", (buffer,))?;
411+
412+
let schema = Schema::from_pyarrow_bound(&pa_schema)
413+
.map_err(|e| pyo3::exceptions::PyValueError::new_err(format!("{e}")))?;
414+
let input_fields: Vec<Field> = schema.fields().iter().map(|f| f.as_ref().clone()).collect();
415+
416+
let return_field = Field::from_pyarrow_bound(&return_field_py)
417+
.map_err(|e| pyo3::exceptions::PyValueError::new_err(format!("{e}")))?;
418+
419+
let volatility = datafusion_python_util::parse_volatility(&volatility_str)
420+
.map_err(|e| pyo3::exceptions::PyValueError::new_err(format!("{e}")))?;
421+
422+
Ok(PythonFunctionScalarUDF::from_parts(
423+
name,
424+
func,
425+
input_fields,
426+
return_field,
427+
volatility,
428+
))
429+
}

crates/core/src/udf.rs

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,13 @@ use crate::expr::PyExpr;
4343
/// This struct holds the Python written function that is a
4444
/// ScalarUDF.
4545
#[derive(Debug)]
46-
struct PythonFunctionScalarUDF {
46+
pub(crate) struct PythonFunctionScalarUDF {
4747
name: String,
4848
func: Py<PyAny>,
49-
signature: Signature,
49+
input_fields: Vec<Field>,
5050
return_field: FieldRef,
51+
signature: Signature,
52+
volatility: Volatility,
5153
}
5254

5355
impl PythonFunctionScalarUDF {
@@ -63,10 +65,40 @@ impl PythonFunctionScalarUDF {
6365
Self {
6466
name,
6567
func,
66-
signature,
68+
input_fields,
6769
return_field: Arc::new(return_field),
70+
signature,
71+
volatility,
6872
}
6973
}
74+
75+
/// Stored Python callable. Consumed by the codec to cloudpickle
76+
/// the function body across process boundaries.
77+
pub(crate) fn func(&self) -> &Py<PyAny> {
78+
&self.func
79+
}
80+
81+
pub(crate) fn input_fields(&self) -> &[Field] {
82+
&self.input_fields
83+
}
84+
85+
pub(crate) fn return_field(&self) -> &FieldRef {
86+
&self.return_field
87+
}
88+
89+
pub(crate) fn volatility(&self) -> Volatility {
90+
self.volatility
91+
}
92+
93+
pub(crate) fn from_parts(
94+
name: String,
95+
func: Py<PyAny>,
96+
input_fields: Vec<Field>,
97+
return_field: Field,
98+
volatility: Volatility,
99+
) -> Self {
100+
Self::new(name, func, input_fields, return_field, volatility)
101+
}
70102
}
71103

72104
impl Eq for PythonFunctionScalarUDF {}

pyproject.toml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,12 @@ classifiers = [
4444
"Programming Language :: Rust",
4545
]
4646
dependencies = [
47+
# cloudpickle is invoked by the Rust-side PythonLogicalCodec /
48+
# PythonPhysicalCodec via pyo3 to serialize Python scalar UDF
49+
# callables into the proto wire format. Lazy-imported on the encode
50+
# / decode hot paths, so users who never serialize a plan or
51+
# expression incur no runtime cost beyond the install footprint.
52+
"cloudpickle>=2.0",
4753
"pyarrow>=16.0.0;python_version<'3.14'",
4854
"pyarrow>=22.0.0;python_version>='3.14'",
4955
"typing-extensions;python_version<'3.13'",
@@ -120,6 +126,7 @@ extend-allowed-calls = ["datafusion.lit", "lit"]
120126
"PT011",
121127
"RUF015",
122128
"S101",
129+
"S301",
123130
"S608",
124131
"SLF",
125132
]
@@ -133,6 +140,7 @@ extend-allowed-calls = ["datafusion.lit", "lit"]
133140
"PLR2004",
134141
"RUF015",
135142
"S101",
143+
"S301",
136144
"T201",
137145
"W505",
138146
]

python/datafusion/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@
6565
import importlib_metadata # type: ignore[import]
6666

6767
# Public submodules
68-
from . import functions, object_store, substrait, unparser
68+
from . import functions, ipc, object_store, substrait, unparser
6969

7070
# The following imports are okay to remain as opaque to the user.
7171
from ._internal import Config
@@ -142,6 +142,7 @@
142142
"configure_formatter",
143143
"expr",
144144
"functions",
145+
"ipc",
145146
"lit",
146147
"literal",
147148
"object_store",

python/datafusion/expr.py

Lines changed: 39 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -436,21 +436,51 @@ def variant_name(self) -> str:
436436
def to_bytes(self, ctx: SessionContext | None = None) -> bytes:
437437
"""Serialize this expression to protobuf bytes.
438438
439-
When ``ctx`` is supplied, encoding routes through the session's
440-
installed :class:`LogicalExtensionCodec`. Without ``ctx`` a
441-
default codec is used.
439+
Python scalar UDFs are cloudpickled inline by
440+
:class:`PythonLogicalCodec`, so the returned blob is
441+
self-contained for scalar UDFs. Aggregate / window / FFI UDFs
442+
are stored by name only; the receiver must have them
443+
registered.
444+
445+
When ``ctx`` is supplied, encoding also routes through the
446+
session's installed codec stack.
442447
"""
443448
ctx_arg = ctx.ctx if ctx is not None else None
444-
return self.expr.to_bytes(ctx_arg)
449+
return bytes(self.expr.to_bytes(ctx_arg))
445450

446-
@staticmethod
447-
def from_bytes(ctx: SessionContext, data: bytes) -> Expr:
451+
@classmethod
452+
def from_bytes(cls, buf: bytes, ctx: SessionContext | None = None) -> Expr:
448453
"""Decode an expression from serialized protobuf bytes.
449454
450-
``ctx`` provides the function registry for resolving UDF
451-
references and the logical codec for in-band Python payloads.
455+
``ctx`` is the receiver :class:`SessionContext` used to resolve
456+
function references not inlined by the codec (aggregate UDFs,
457+
window UDFs, FFI UDFs). When ``ctx`` is ``None`` the worker
458+
context set via :func:`datafusion.ipc.set_worker_ctx` is
459+
consulted; if no worker context is set, a fresh
460+
:class:`SessionContext` is used.
461+
"""
462+
from datafusion.ipc import _resolve_ctx
463+
464+
resolved = _resolve_ctx(ctx)
465+
return cls(expr_internal.RawExpr.from_bytes(resolved.ctx, buf))
466+
467+
def __reduce__(self) -> tuple:
468+
"""Pickle protocol hook.
469+
470+
:class:`PythonLogicalCodec` cloudpickles referenced Python
471+
scalar UDFs directly into the proto wire format, so the
472+
returned blob is self-contained. On unpickle the bytes are
473+
decoded against the worker context set via
474+
:func:`datafusion.ipc.set_worker_ctx` (or a fresh
475+
:class:`SessionContext` if none) for any remaining
476+
registry-resolved references.
452477
"""
453-
return Expr(expr_internal.RawExpr.from_bytes(ctx.ctx, data))
478+
return (Expr._reconstruct, (self.to_bytes(),))
479+
480+
@classmethod
481+
def _reconstruct(cls, proto_bytes: bytes) -> Expr:
482+
"""Internal entry point used by :meth:`__reduce__` on unpickle."""
483+
return cls.from_bytes(proto_bytes)
454484

455485
def __richcmp__(self, other: Expr, op: int) -> Expr:
456486
"""Comparison operator."""

0 commit comments

Comments
 (0)