7878use std:: sync:: Arc ;
7979
8080use arrow:: datatypes:: { Field , Schema , SchemaRef } ;
81- use arrow:: pyarrow :: ToPyArrow ;
82- use datafusion :: arrow :: pyarrow :: FromPyArrow ;
81+ use arrow:: ipc :: reader :: StreamReader ;
82+ use arrow :: ipc :: writer :: StreamWriter ;
8383use datafusion:: common:: { Result , TableReference } ;
8484use datafusion:: datasource:: TableProvider ;
8585use datafusion:: datasource:: file_format:: FileFormatFactory ;
@@ -91,7 +91,6 @@ use datafusion::physical_expr::PhysicalExpr;
9191use datafusion:: physical_plan:: ExecutionPlan ;
9292use datafusion_proto:: logical_plan:: { DefaultLogicalExtensionCodec , LogicalExtensionCodec } ;
9393use datafusion_proto:: physical_plan:: { DefaultPhysicalExtensionCodec , PhysicalExtensionCodec } ;
94- use pyo3:: BoundObject ;
9594use pyo3:: prelude:: * ;
9695use pyo3:: types:: { PyBytes , PyTuple } ;
9796
@@ -357,14 +356,14 @@ pub(crate) fn try_decode_python_scalar_udf(buf: &[u8]) -> Result<Option<Arc<Scal
357356/// Build the cloudpickle payload for a `PythonFunctionScalarUDF`.
358357///
359358/// Layout: `cloudpickle.dumps((name, func, input_schema_bytes,
360- /// return_field , volatility_str))`. Input `DataType`s are derived
361- /// from the UDF 's `Signature` (always `TypeSignature::Exact` for
362- /// Python-defined UDFs) and packaged as a pyarrow `Schema` with
363- /// synthesized field names — the local `PythonFunctionScalarUDF `
364- /// does not retain field-level metadata, and reconstructing one on
365- /// the receiver via `from_parts` immediately collapses any incoming
366- /// `Field` info back to `DataType`, so the original sender-side
367- /// fields and receiver-side fields are functionally equivalent .
359+ /// return_schema_bytes , volatility_str))`. Both schema blobs are
360+ /// produced by arrow-rs 's native IPC stream writer — no pyarrow
361+ /// round-trip — and decoded with the matching stream reader on the
362+ /// receiver. Input `DataType`s come from the UDF's `Signature `
363+ /// (always `TypeSignature::Exact` for Python-defined UDFs);
364+ /// receiver-side `Field` metadata is irrelevant because
365+ /// `from_parts` immediately collapses incoming `Field`s back to
366+ /// `DataType`s for the reconstructed `Signature` .
368367fn encode_python_scalar_udf ( py : Python < ' _ > , udf : & PythonFunctionScalarUDF ) -> PyResult < Vec < u8 > > {
369368 let cloudpickle = py. import ( "cloudpickle" ) ?;
370369
@@ -377,29 +376,28 @@ fn encode_python_scalar_udf(py: Python<'_>, udf: &PythonFunctionScalarUDF) -> Py
377376 ) ) ) ;
378377 }
379378 } ;
380- let fields : Vec < Field > = input_dtypes
379+ let input_fields : Vec < Field > = input_dtypes
381380 . into_iter ( )
382381 . enumerate ( )
383382 . map ( |( i, dt) | Field :: new ( format ! ( "arg_{i}" ) , dt, true ) )
384383 . collect ( ) ;
385- let input_schema = Schema :: new ( fields) ;
386- let pa_schema_obj = input_schema. to_pyarrow ( py) ?;
387- let pa_schema = pa_schema_obj. into_bound ( ) ;
388- let schema_bytes: Vec < u8 > = pa_schema
389- . call_method0 ( "serialize" ) ?
390- . call_method0 ( "to_pybytes" ) ?
391- . extract ( ) ?;
392-
393- let return_field_obj = udf. return_field ( ) . as_ref ( ) . to_pyarrow ( py) ?;
384+ let input_schema = Schema :: new ( input_fields) ;
385+ let input_schema_bytes = schema_to_ipc_bytes ( & input_schema)
386+ . map_err ( |e| pyo3:: exceptions:: PyValueError :: new_err ( format ! ( "{e}" ) ) ) ?;
387+
388+ let return_schema = Schema :: new ( vec ! [ udf. return_field( ) . as_ref( ) . clone( ) ] ) ;
389+ let return_schema_bytes = schema_to_ipc_bytes ( & return_schema)
390+ . map_err ( |e| pyo3:: exceptions:: PyValueError :: new_err ( format ! ( "{e}" ) ) ) ?;
391+
394392 let volatility = format ! ( "{:?}" , signature. volatility) . to_lowercase ( ) ;
395393
396394 let payload = PyTuple :: new (
397395 py,
398396 [
399397 udf. name ( ) . into_pyobject ( py) ?. into_any ( ) ,
400398 udf. func ( ) . bind ( py) . clone ( ) . into_any ( ) ,
401- PyBytes :: new ( py, & schema_bytes ) . into_any ( ) ,
402- return_field_obj . into_bound ( ) ,
399+ PyBytes :: new ( py, & input_schema_bytes ) . into_any ( ) ,
400+ PyBytes :: new ( py , & return_schema_bytes ) . into_any ( ) ,
403401 volatility. into_pyobject ( py) ?. into_any ( ) ,
404402 ] ,
405403 ) ?;
@@ -411,29 +409,37 @@ fn encode_python_scalar_udf(py: Python<'_>, udf: &PythonFunctionScalarUDF) -> Py
411409/// Inverse of [`encode_python_scalar_udf`].
412410fn decode_python_scalar_udf ( py : Python < ' _ > , payload : & [ u8 ] ) -> PyResult < PythonFunctionScalarUDF > {
413411 let cloudpickle = py. import ( "cloudpickle" ) ?;
414- let pyarrow = py. import ( "pyarrow" ) ?;
415412
416413 let tuple = cloudpickle
417414 . call_method1 ( "loads" , ( PyBytes :: new ( py, payload) , ) ) ?
418415 . cast_into :: < PyTuple > ( ) ?;
419416
420417 let name: String = tuple. get_item ( 0 ) ?. extract ( ) ?;
421418 let func: Py < PyAny > = tuple. get_item ( 1 ) ?. unbind ( ) ;
422- let schema_bytes : Vec < u8 > = tuple. get_item ( 2 ) ?. extract ( ) ?;
423- let return_field_py = tuple. get_item ( 3 ) ?;
419+ let input_schema_bytes : Vec < u8 > = tuple. get_item ( 2 ) ?. extract ( ) ?;
420+ let return_schema_bytes : Vec < u8 > = tuple. get_item ( 3 ) ? . extract ( ) ?;
424421 let volatility_str: String = tuple. get_item ( 4 ) ?. extract ( ) ?;
425422
426- let buffer = pyarrow. call_method1 ( "py_buffer" , ( PyBytes :: new ( py, & schema_bytes) , ) ) ?;
427- let pa_schema = pyarrow
428- . getattr ( "ipc" ) ?
429- . call_method1 ( "read_schema" , ( buffer, ) ) ?;
430-
431- let schema = Schema :: from_pyarrow_bound ( & pa_schema)
423+ let input_schema = schema_from_ipc_bytes ( & input_schema_bytes)
432424 . map_err ( |e| pyo3:: exceptions:: PyValueError :: new_err ( format ! ( "{e}" ) ) ) ?;
433- let input_fields: Vec < Field > = schema. fields ( ) . iter ( ) . map ( |f| f. as_ref ( ) . clone ( ) ) . collect ( ) ;
425+ let input_fields: Vec < Field > = input_schema
426+ . fields ( )
427+ . iter ( )
428+ . map ( |f| f. as_ref ( ) . clone ( ) )
429+ . collect ( ) ;
434430
435- let return_field = Field :: from_pyarrow_bound ( & return_field_py )
431+ let return_schema = schema_from_ipc_bytes ( & return_schema_bytes )
436432 . map_err ( |e| pyo3:: exceptions:: PyValueError :: new_err ( format ! ( "{e}" ) ) ) ?;
433+ let return_field = return_schema
434+ . fields ( )
435+ . first ( )
436+ . ok_or_else ( || {
437+ pyo3:: exceptions:: PyValueError :: new_err (
438+ "PythonFunctionScalarUDF return schema must contain exactly one field" ,
439+ )
440+ } ) ?
441+ . as_ref ( )
442+ . clone ( ) ;
437443
438444 let volatility = datafusion_python_util:: parse_volatility ( & volatility_str)
439445 . map_err ( |e| pyo3:: exceptions:: PyValueError :: new_err ( format ! ( "{e}" ) ) ) ?;
@@ -446,3 +452,22 @@ fn decode_python_scalar_udf(py: Python<'_>, payload: &[u8]) -> PyResult<PythonFu
446452 volatility,
447453 ) )
448454}
455+
456+ /// Serialize a `Schema` to a self-contained IPC stream containing
457+ /// only the schema message (no record batches). Inverse:
458+ /// [`schema_from_ipc_bytes`].
459+ fn schema_to_ipc_bytes ( schema : & Schema ) -> arrow:: error:: Result < Vec < u8 > > {
460+ let mut buf: Vec < u8 > = Vec :: new ( ) ;
461+ {
462+ let mut writer = StreamWriter :: try_new ( & mut buf, schema) ?;
463+ writer. finish ( ) ?;
464+ }
465+ Ok ( buf)
466+ }
467+
468+ /// Decode an IPC stream containing only a schema message back into a
469+ /// `Schema`. Inverse: [`schema_to_ipc_bytes`].
470+ fn schema_from_ipc_bytes ( bytes : & [ u8 ] ) -> arrow:: error:: Result < Schema > {
471+ let reader = StreamReader :: try_new ( std:: io:: Cursor :: new ( bytes) , None ) ?;
472+ Ok ( reader. schema ( ) . as_ref ( ) . clone ( ) )
473+ }
0 commit comments