Skip to content

Commit fd46c94

Browse files
timsaucerclaude
andcommitted
refactor(codec): use arrow-rs native IPC for schema serialization
Previous round-trip went Rust Schema -> pyarrow Schema -> IPC bytes -> cloudpickle tuple -> pyarrow Schema -> Rust Schema, for both the input schema and the return Field. Two unnecessary pyarrow trips on each side. Replace with `StreamWriter::try_new(&mut buf, &schema)?.finish()?` on the encoder and `StreamReader::try_new(cursor, None)?.schema()` on the decoder. Both ends produce / consume the same Arrow IPC stream bytes — arrow-rs writes a schema-only stream, arrow-rs reads it back, no PyArrow involvement. Tuple shape changes slightly: the fourth field is now a one-field `return_schema_bytes` IPC blob instead of a pickled pyarrow `Field`. Keeps everything in `Vec<u8>` form before cloudpickle picks it up. `pyarrow.ipc.read_schema` and the `ToPyArrow` / `FromPyArrow` traits on `Schema` / `Field` are no longer needed on the codec hot path, shaving a noticeable chunk of pyarrow function dispatch from each encode / decode call. Pickle tests still green. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 0450713 commit fd46c94

1 file changed

Lines changed: 59 additions & 34 deletions

File tree

crates/core/src/codec.rs

Lines changed: 59 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,8 @@
7878
use std::sync::Arc;
7979

8080
use arrow::datatypes::{Field, Schema, SchemaRef};
81-
use arrow::pyarrow::ToPyArrow;
82-
use datafusion::arrow::pyarrow::FromPyArrow;
81+
use arrow::ipc::reader::StreamReader;
82+
use arrow::ipc::writer::StreamWriter;
8383
use datafusion::common::{Result, TableReference};
8484
use datafusion::datasource::TableProvider;
8585
use datafusion::datasource::file_format::FileFormatFactory;
@@ -91,7 +91,6 @@ use datafusion::physical_expr::PhysicalExpr;
9191
use datafusion::physical_plan::ExecutionPlan;
9292
use datafusion_proto::logical_plan::{DefaultLogicalExtensionCodec, LogicalExtensionCodec};
9393
use datafusion_proto::physical_plan::{DefaultPhysicalExtensionCodec, PhysicalExtensionCodec};
94-
use pyo3::BoundObject;
9594
use pyo3::prelude::*;
9695
use pyo3::types::{PyBytes, PyTuple};
9796

@@ -357,14 +356,14 @@ pub(crate) fn try_decode_python_scalar_udf(buf: &[u8]) -> Result<Option<Arc<Scal
357356
/// Build the cloudpickle payload for a `PythonFunctionScalarUDF`.
358357
///
359358
/// Layout: `cloudpickle.dumps((name, func, input_schema_bytes,
360-
/// return_field, volatility_str))`. Input `DataType`s are derived
361-
/// from the UDF's `Signature` (always `TypeSignature::Exact` for
362-
/// Python-defined UDFs) and packaged as a pyarrow `Schema` with
363-
/// synthesized field names — the local `PythonFunctionScalarUDF`
364-
/// does not retain field-level metadata, and reconstructing one on
365-
/// the receiver via `from_parts` immediately collapses any incoming
366-
/// `Field` info back to `DataType`, so the original sender-side
367-
/// fields and receiver-side fields are functionally equivalent.
359+
/// return_schema_bytes, volatility_str))`. Both schema blobs are
360+
/// produced by arrow-rs's native IPC stream writer — no pyarrow
361+
/// round-trip — and decoded with the matching stream reader on the
362+
/// receiver. Input `DataType`s come from the UDF's `Signature`
363+
/// (always `TypeSignature::Exact` for Python-defined UDFs);
364+
/// receiver-side `Field` metadata is irrelevant because
365+
/// `from_parts` immediately collapses incoming `Field`s back to
366+
/// `DataType`s for the reconstructed `Signature`.
368367
fn encode_python_scalar_udf(py: Python<'_>, udf: &PythonFunctionScalarUDF) -> PyResult<Vec<u8>> {
369368
let cloudpickle = py.import("cloudpickle")?;
370369

@@ -377,29 +376,28 @@ fn encode_python_scalar_udf(py: Python<'_>, udf: &PythonFunctionScalarUDF) -> Py
377376
)));
378377
}
379378
};
380-
let fields: Vec<Field> = input_dtypes
379+
let input_fields: Vec<Field> = input_dtypes
381380
.into_iter()
382381
.enumerate()
383382
.map(|(i, dt)| Field::new(format!("arg_{i}"), dt, true))
384383
.collect();
385-
let input_schema = Schema::new(fields);
386-
let pa_schema_obj = input_schema.to_pyarrow(py)?;
387-
let pa_schema = pa_schema_obj.into_bound();
388-
let schema_bytes: Vec<u8> = pa_schema
389-
.call_method0("serialize")?
390-
.call_method0("to_pybytes")?
391-
.extract()?;
392-
393-
let return_field_obj = udf.return_field().as_ref().to_pyarrow(py)?;
384+
let input_schema = Schema::new(input_fields);
385+
let input_schema_bytes = schema_to_ipc_bytes(&input_schema)
386+
.map_err(|e| pyo3::exceptions::PyValueError::new_err(format!("{e}")))?;
387+
388+
let return_schema = Schema::new(vec![udf.return_field().as_ref().clone()]);
389+
let return_schema_bytes = schema_to_ipc_bytes(&return_schema)
390+
.map_err(|e| pyo3::exceptions::PyValueError::new_err(format!("{e}")))?;
391+
394392
let volatility = format!("{:?}", signature.volatility).to_lowercase();
395393

396394
let payload = PyTuple::new(
397395
py,
398396
[
399397
udf.name().into_pyobject(py)?.into_any(),
400398
udf.func().bind(py).clone().into_any(),
401-
PyBytes::new(py, &schema_bytes).into_any(),
402-
return_field_obj.into_bound(),
399+
PyBytes::new(py, &input_schema_bytes).into_any(),
400+
PyBytes::new(py, &return_schema_bytes).into_any(),
403401
volatility.into_pyobject(py)?.into_any(),
404402
],
405403
)?;
@@ -411,29 +409,37 @@ fn encode_python_scalar_udf(py: Python<'_>, udf: &PythonFunctionScalarUDF) -> Py
411409
/// Inverse of [`encode_python_scalar_udf`].
412410
fn decode_python_scalar_udf(py: Python<'_>, payload: &[u8]) -> PyResult<PythonFunctionScalarUDF> {
413411
let cloudpickle = py.import("cloudpickle")?;
414-
let pyarrow = py.import("pyarrow")?;
415412

416413
let tuple = cloudpickle
417414
.call_method1("loads", (PyBytes::new(py, payload),))?
418415
.cast_into::<PyTuple>()?;
419416

420417
let name: String = tuple.get_item(0)?.extract()?;
421418
let func: Py<PyAny> = tuple.get_item(1)?.unbind();
422-
let schema_bytes: Vec<u8> = tuple.get_item(2)?.extract()?;
423-
let return_field_py = tuple.get_item(3)?;
419+
let input_schema_bytes: Vec<u8> = tuple.get_item(2)?.extract()?;
420+
let return_schema_bytes: Vec<u8> = tuple.get_item(3)?.extract()?;
424421
let volatility_str: String = tuple.get_item(4)?.extract()?;
425422

426-
let buffer = pyarrow.call_method1("py_buffer", (PyBytes::new(py, &schema_bytes),))?;
427-
let pa_schema = pyarrow
428-
.getattr("ipc")?
429-
.call_method1("read_schema", (buffer,))?;
430-
431-
let schema = Schema::from_pyarrow_bound(&pa_schema)
423+
let input_schema = schema_from_ipc_bytes(&input_schema_bytes)
432424
.map_err(|e| pyo3::exceptions::PyValueError::new_err(format!("{e}")))?;
433-
let input_fields: Vec<Field> = schema.fields().iter().map(|f| f.as_ref().clone()).collect();
425+
let input_fields: Vec<Field> = input_schema
426+
.fields()
427+
.iter()
428+
.map(|f| f.as_ref().clone())
429+
.collect();
434430

435-
let return_field = Field::from_pyarrow_bound(&return_field_py)
431+
let return_schema = schema_from_ipc_bytes(&return_schema_bytes)
436432
.map_err(|e| pyo3::exceptions::PyValueError::new_err(format!("{e}")))?;
433+
let return_field = return_schema
434+
.fields()
435+
.first()
436+
.ok_or_else(|| {
437+
pyo3::exceptions::PyValueError::new_err(
438+
"PythonFunctionScalarUDF return schema must contain exactly one field",
439+
)
440+
})?
441+
.as_ref()
442+
.clone();
437443

438444
let volatility = datafusion_python_util::parse_volatility(&volatility_str)
439445
.map_err(|e| pyo3::exceptions::PyValueError::new_err(format!("{e}")))?;
@@ -446,3 +452,22 @@ fn decode_python_scalar_udf(py: Python<'_>, payload: &[u8]) -> PyResult<PythonFu
446452
volatility,
447453
))
448454
}
455+
456+
/// Serialize a `Schema` to a self-contained IPC stream containing
457+
/// only the schema message (no record batches). Inverse:
458+
/// [`schema_from_ipc_bytes`].
459+
fn schema_to_ipc_bytes(schema: &Schema) -> arrow::error::Result<Vec<u8>> {
460+
let mut buf: Vec<u8> = Vec::new();
461+
{
462+
let mut writer = StreamWriter::try_new(&mut buf, schema)?;
463+
writer.finish()?;
464+
}
465+
Ok(buf)
466+
}
467+
468+
/// Decode an IPC stream containing only a schema message back into a
469+
/// `Schema`. Inverse: [`schema_to_ipc_bytes`].
470+
fn schema_from_ipc_bytes(bytes: &[u8]) -> arrow::error::Result<Schema> {
471+
let reader = StreamReader::try_new(std::io::Cursor::new(bytes), None)?;
472+
Ok(reader.schema().as_ref().clone())
473+
}

0 commit comments

Comments
 (0)