@@ -85,8 +85,8 @@ use datafusion::datasource::TableProvider;
8585use datafusion:: datasource:: file_format:: FileFormatFactory ;
8686use datafusion:: execution:: TaskContext ;
8787use datafusion:: logical_expr:: {
88- AggregateUDF , Extension , LogicalPlan , ScalarUDF , ScalarUDFImpl , TypeSignature , WindowUDF ,
89- WindowUDFImpl ,
88+ AggregateUDF , AggregateUDFImpl , Extension , LogicalPlan , ScalarUDF , ScalarUDFImpl ,
89+ TypeSignature , WindowUDF , WindowUDFImpl ,
9090} ;
9191use datafusion:: physical_expr:: PhysicalExpr ;
9292use datafusion:: physical_plan:: ExecutionPlan ;
@@ -95,6 +95,7 @@ use datafusion_proto::physical_plan::{DefaultPhysicalExtensionCodec, PhysicalExt
9595use pyo3:: prelude:: * ;
9696use pyo3:: types:: { PyBytes , PyTuple } ;
9797
98+ use crate :: udaf:: PythonFunctionAggregateUDF ;
9899use crate :: udf:: PythonFunctionScalarUDF ;
99100use crate :: udwf:: MultiColumnWindowUDF ;
100101
@@ -104,6 +105,11 @@ use crate::udwf::MultiColumnWindowUDF;
104105/// the encoder and decoder cannot drift.
105106pub ( crate ) const PY_SCALAR_UDF_MAGIC : & [ u8 ] = b"DFPYUDF1" ;
106107
108+ /// Wire-format prefix for an inlined Python aggregate UDF
109+ /// (cloudpickled tuple of name, accumulator factory, input schema,
110+ /// return type, state types schema, volatility).
111+ pub ( crate ) const PY_AGG_UDF_MAGIC : & [ u8 ] = b"DFPYUDA1" ;
112+
107113/// Wire-format prefix for an inlined Python window UDF (cloudpickled
108114/// tuple of name, evaluator factory, input schema, return type,
109115/// volatility).
@@ -205,10 +211,16 @@ impl LogicalExtensionCodec for PythonLogicalCodec {
205211 }
206212
207213 fn try_encode_udaf ( & self , node : & AggregateUDF , buf : & mut Vec < u8 > ) -> Result < ( ) > {
214+ if try_encode_python_agg_udf ( node, buf) ? {
215+ return Ok ( ( ) ) ;
216+ }
208217 self . inner . try_encode_udaf ( node, buf)
209218 }
210219
211220 fn try_decode_udaf ( & self , name : & str , buf : & [ u8 ] ) -> Result < Arc < AggregateUDF > > {
221+ if let Some ( udaf) = try_decode_python_agg_udf ( buf) ? {
222+ return Ok ( udaf) ;
223+ }
212224 self . inner . try_decode_udaf ( name, buf)
213225 }
214226
@@ -301,10 +313,16 @@ impl PhysicalExtensionCodec for PythonPhysicalCodec {
301313 }
302314
303315 fn try_encode_udaf ( & self , node : & AggregateUDF , buf : & mut Vec < u8 > ) -> Result < ( ) > {
316+ if try_encode_python_agg_udf ( node, buf) ? {
317+ return Ok ( ( ) ) ;
318+ }
304319 self . inner . try_encode_udaf ( node, buf)
305320 }
306321
307322 fn try_decode_udaf ( & self , name : & str , buf : & [ u8 ] ) -> Result < Arc < AggregateUDF > > {
323+ if let Some ( udaf) = try_decode_python_agg_udf ( buf) ? {
324+ return Ok ( udaf) ;
325+ }
308326 self . inner . try_decode_udaf ( name, buf)
309327 }
310328
@@ -613,3 +631,149 @@ fn decode_python_window_udf(py: Python<'_>, payload: &[u8]) -> PyResult<MultiCol
613631 volatility,
614632 ) )
615633}
634+
635+ // =============================================================================
636+ // Shared Python aggregate UDF encode / decode helpers
637+ //
638+ // Cloudpickle tuple shape: `(name, accumulator_factory, input_schema_bytes,
639+ // return_type_bytes, state_schema_bytes, volatility_str)`. The accumulator
640+ // factory is the Python callable that produces a new accumulator instance
641+ // per partition.
642+ // =============================================================================
643+
644+ pub ( crate ) fn try_encode_python_agg_udf ( node : & AggregateUDF , buf : & mut Vec < u8 > ) -> Result < bool > {
645+ let Some ( py_udf) = node
646+ . inner ( )
647+ . as_any ( )
648+ . downcast_ref :: < PythonFunctionAggregateUDF > ( )
649+ else {
650+ return Ok ( false ) ;
651+ } ;
652+
653+ Python :: attach ( |py| -> Result < bool > {
654+ let bytes = encode_python_agg_udf ( py, py_udf)
655+ . map_err ( |e| datafusion:: error:: DataFusionError :: External ( Box :: new ( e) ) ) ?;
656+ buf. extend_from_slice ( PY_AGG_UDF_MAGIC ) ;
657+ buf. extend_from_slice ( & bytes) ;
658+ Ok ( true )
659+ } )
660+ }
661+
662+ pub ( crate ) fn try_decode_python_agg_udf ( buf : & [ u8 ] ) -> Result < Option < Arc < AggregateUDF > > > {
663+ if buf. is_empty ( ) || !buf. starts_with ( PY_AGG_UDF_MAGIC ) {
664+ return Ok ( None ) ;
665+ }
666+ let payload = & buf[ PY_AGG_UDF_MAGIC . len ( ) ..] ;
667+
668+ Python :: attach ( |py| -> Result < Option < Arc < AggregateUDF > > > {
669+ let udf = decode_python_agg_udf ( py, payload)
670+ . map_err ( |e| datafusion:: error:: DataFusionError :: External ( Box :: new ( e) ) ) ?;
671+ Ok ( Some ( Arc :: new ( AggregateUDF :: new_from_impl ( udf) ) ) )
672+ } )
673+ }
674+
675+ fn encode_python_agg_udf ( py : Python < ' _ > , udf : & PythonFunctionAggregateUDF ) -> PyResult < Vec < u8 > > {
676+ let cloudpickle = py. import ( "cloudpickle" ) ?;
677+
678+ let signature = AggregateUDFImpl :: signature ( udf) ;
679+ let input_dtypes: Vec < arrow:: datatypes:: DataType > = match & signature. type_signature {
680+ TypeSignature :: Exact ( types) => types. clone ( ) ,
681+ other => {
682+ return Err ( pyo3:: exceptions:: PyValueError :: new_err ( format ! (
683+ "PythonFunctionAggregateUDF expected Signature::Exact, got {other:?}"
684+ ) ) ) ;
685+ }
686+ } ;
687+ let input_fields: Vec < Field > = input_dtypes
688+ . into_iter ( )
689+ . enumerate ( )
690+ . map ( |( i, dt) | Field :: new ( format ! ( "arg_{i}" ) , dt, true ) )
691+ . collect ( ) ;
692+ let input_schema_bytes = schema_to_ipc_bytes ( & Schema :: new ( input_fields) )
693+ . map_err ( |e| pyo3:: exceptions:: PyValueError :: new_err ( format ! ( "{e}" ) ) ) ?;
694+
695+ let return_schema = Schema :: new ( vec ! [ Field :: new( "result" , udf. return_type( ) . clone( ) , true ) ] ) ;
696+ let return_schema_bytes = schema_to_ipc_bytes ( & return_schema)
697+ . map_err ( |e| pyo3:: exceptions:: PyValueError :: new_err ( format ! ( "{e}" ) ) ) ?;
698+
699+ let state_fields: Vec < Field > = udf
700+ . state_fields_ref ( )
701+ . iter ( )
702+ . map ( |f| f. as_ref ( ) . clone ( ) )
703+ . collect ( ) ;
704+ let state_schema_bytes = schema_to_ipc_bytes ( & Schema :: new ( state_fields) )
705+ . map_err ( |e| pyo3:: exceptions:: PyValueError :: new_err ( format ! ( "{e}" ) ) ) ?;
706+
707+ let volatility = format ! ( "{:?}" , signature. volatility) . to_lowercase ( ) ;
708+
709+ let payload = PyTuple :: new (
710+ py,
711+ [
712+ AggregateUDFImpl :: name ( udf) . into_pyobject ( py) ?. into_any ( ) ,
713+ udf. accumulator ( ) . bind ( py) . clone ( ) . into_any ( ) ,
714+ PyBytes :: new ( py, & input_schema_bytes) . into_any ( ) ,
715+ PyBytes :: new ( py, & return_schema_bytes) . into_any ( ) ,
716+ PyBytes :: new ( py, & state_schema_bytes) . into_any ( ) ,
717+ volatility. into_pyobject ( py) ?. into_any ( ) ,
718+ ] ,
719+ ) ?;
720+
721+ let blob = cloudpickle. call_method1 ( "dumps" , ( payload, ) ) ?;
722+ blob. extract :: < Vec < u8 > > ( )
723+ }
724+
725+ fn decode_python_agg_udf ( py : Python < ' _ > , payload : & [ u8 ] ) -> PyResult < PythonFunctionAggregateUDF > {
726+ let cloudpickle = py. import ( "cloudpickle" ) ?;
727+
728+ let tuple = cloudpickle
729+ . call_method1 ( "loads" , ( PyBytes :: new ( py, payload) , ) ) ?
730+ . cast_into :: < PyTuple > ( ) ?;
731+
732+ let name: String = tuple. get_item ( 0 ) ?. extract ( ) ?;
733+ let accumulator: Py < PyAny > = tuple. get_item ( 1 ) ?. unbind ( ) ;
734+ let input_schema_bytes: Vec < u8 > = tuple. get_item ( 2 ) ?. extract ( ) ?;
735+ let return_schema_bytes: Vec < u8 > = tuple. get_item ( 3 ) ?. extract ( ) ?;
736+ let state_schema_bytes: Vec < u8 > = tuple. get_item ( 4 ) ?. extract ( ) ?;
737+ let volatility_str: String = tuple. get_item ( 5 ) ?. extract ( ) ?;
738+
739+ let input_schema = schema_from_ipc_bytes ( & input_schema_bytes)
740+ . map_err ( |e| pyo3:: exceptions:: PyValueError :: new_err ( format ! ( "{e}" ) ) ) ?;
741+ let input_types: Vec < arrow:: datatypes:: DataType > = input_schema
742+ . fields ( )
743+ . iter ( )
744+ . map ( |f| f. data_type ( ) . clone ( ) )
745+ . collect ( ) ;
746+
747+ let return_schema = schema_from_ipc_bytes ( & return_schema_bytes)
748+ . map_err ( |e| pyo3:: exceptions:: PyValueError :: new_err ( format ! ( "{e}" ) ) ) ?;
749+ let return_type = return_schema
750+ . fields ( )
751+ . first ( )
752+ . ok_or_else ( || {
753+ pyo3:: exceptions:: PyValueError :: new_err (
754+ "PythonFunctionAggregateUDF return schema must contain exactly one field" ,
755+ )
756+ } ) ?
757+ . data_type ( )
758+ . clone ( ) ;
759+
760+ let state_schema = schema_from_ipc_bytes ( & state_schema_bytes)
761+ . map_err ( |e| pyo3:: exceptions:: PyValueError :: new_err ( format ! ( "{e}" ) ) ) ?;
762+ let state_types: Vec < arrow:: datatypes:: DataType > = state_schema
763+ . fields ( )
764+ . iter ( )
765+ . map ( |f| f. data_type ( ) . clone ( ) )
766+ . collect ( ) ;
767+
768+ let volatility = datafusion_python_util:: parse_volatility ( & volatility_str)
769+ . map_err ( |e| pyo3:: exceptions:: PyValueError :: new_err ( format ! ( "{e}" ) ) ) ?;
770+
771+ Ok ( PythonFunctionAggregateUDF :: from_parts (
772+ name,
773+ accumulator,
774+ input_types,
775+ return_type,
776+ state_types,
777+ volatility,
778+ ) )
779+ }
0 commit comments