|
77 | 77 |
|
78 | 78 | use std::sync::Arc; |
79 | 79 |
|
80 | | -use arrow::datatypes::SchemaRef; |
| 80 | +use arrow::datatypes::{Field, Schema, SchemaRef}; |
| 81 | +use arrow::pyarrow::ToPyArrow; |
| 82 | +use datafusion::arrow::pyarrow::FromPyArrow; |
81 | 83 | use datafusion::common::{Result, TableReference}; |
82 | 84 | use datafusion::datasource::TableProvider; |
83 | 85 | use datafusion::datasource::file_format::FileFormatFactory; |
84 | 86 | use datafusion::execution::TaskContext; |
85 | | -use datafusion::logical_expr::{AggregateUDF, Extension, LogicalPlan, ScalarUDF, WindowUDF}; |
| 87 | +use datafusion::logical_expr::{ |
| 88 | + AggregateUDF, Extension, LogicalPlan, ScalarUDF, ScalarUDFImpl, WindowUDF, |
| 89 | +}; |
86 | 90 | use datafusion::physical_expr::PhysicalExpr; |
87 | 91 | use datafusion::physical_plan::ExecutionPlan; |
88 | 92 | use datafusion_proto::logical_plan::{DefaultLogicalExtensionCodec, LogicalExtensionCodec}; |
89 | 93 | use datafusion_proto::physical_plan::{DefaultPhysicalExtensionCodec, PhysicalExtensionCodec}; |
| 94 | +use pyo3::BoundObject; |
| 95 | +use pyo3::prelude::*; |
| 96 | +use pyo3::types::{PyBytes, PyTuple}; |
| 97 | + |
| 98 | +use crate::udf::PythonFunctionScalarUDF; |
90 | 99 |
|
/// Wire-format prefix that tags a `fun_definition` payload as an
/// inlined Python scalar UDF (cloudpickled tuple of name, callable,
/// input schema, return field, volatility). Defined once here so
/// the encoder and decoder cannot drift.
///
/// The encoder writes these bytes before the pickled payload; the
/// decoder treats any buffer that does not start with them as "not a
/// Python UDF" and defers to the inner codec.
///
/// NOTE(review): the trailing `1` presumably versions the tuple layout —
/// confirm, and bump it if the payload shape ever changes.
pub(crate) const PY_SCALAR_UDF_MAGIC: &[u8] = b"DFPYUDF1";
97 | 105 |
|
98 | 106 | /// `LogicalExtensionCodec` parked on every `SessionContext`. Holds |
@@ -177,10 +185,16 @@ impl LogicalExtensionCodec for PythonLogicalCodec { |
177 | 185 | } |
178 | 186 |
|
179 | 187 | fn try_encode_udf(&self, node: &ScalarUDF, buf: &mut Vec<u8>) -> Result<()> { |
| 188 | + if try_encode_python_scalar_udf(node, buf)? { |
| 189 | + return Ok(()); |
| 190 | + } |
180 | 191 | self.inner.try_encode_udf(node, buf) |
181 | 192 | } |
182 | 193 |
|
183 | 194 | fn try_decode_udf(&self, name: &str, buf: &[u8]) -> Result<Arc<ScalarUDF>> { |
| 195 | + if let Some(udf) = try_decode_python_scalar_udf(buf)? { |
| 196 | + return Ok(udf); |
| 197 | + } |
184 | 198 | self.inner.try_decode_udf(name, buf) |
185 | 199 | } |
186 | 200 |
|
@@ -249,10 +263,16 @@ impl PhysicalExtensionCodec for PythonPhysicalCodec { |
249 | 263 | } |
250 | 264 |
|
251 | 265 | fn try_encode_udf(&self, node: &ScalarUDF, buf: &mut Vec<u8>) -> Result<()> { |
| 266 | + if try_encode_python_scalar_udf(node, buf)? { |
| 267 | + return Ok(()); |
| 268 | + } |
252 | 269 | self.inner.try_encode_udf(node, buf) |
253 | 270 | } |
254 | 271 |
|
255 | 272 | fn try_decode_udf(&self, name: &str, buf: &[u8]) -> Result<Arc<ScalarUDF>> { |
| 273 | + if let Some(udf) = try_decode_python_scalar_udf(buf)? { |
| 274 | + return Ok(udf); |
| 275 | + } |
256 | 276 | self.inner.try_decode_udf(name, buf) |
257 | 277 | } |
258 | 278 |
|
@@ -284,3 +304,126 @@ impl PhysicalExtensionCodec for PythonPhysicalCodec { |
284 | 304 | self.inner.try_decode_udwf(name, buf) |
285 | 305 | } |
286 | 306 | } |
| 307 | + |
| 308 | +// ============================================================================= |
| 309 | +// Shared Python scalar UDF encode / decode helpers |
| 310 | +// |
| 311 | +// Both `PythonLogicalCodec` and `PythonPhysicalCodec` consult these on |
| 312 | +// every `try_encode_udf` / `try_decode_udf` call. Same wire format on |
| 313 | +// both layers — a Python `ScalarUDF` referenced inside a `LogicalPlan` |
| 314 | +// or an `ExecutionPlan` round-trips identically. |
| 315 | +// ============================================================================= |
| 316 | + |
| 317 | +/// Encode a Python scalar UDF inline if `node` is one. Returns |
| 318 | +/// `Ok(true)` when the payload (`DFPYUDF1` prefix + cloudpickled |
| 319 | +/// tuple) was written and the caller should skip its inner codec. |
| 320 | +/// Returns `Ok(false)` for any non-Python UDF, signalling the caller |
| 321 | +/// to delegate to its `inner`. |
| 322 | +pub(crate) fn try_encode_python_scalar_udf(node: &ScalarUDF, buf: &mut Vec<u8>) -> Result<bool> { |
| 323 | + let Some(py_udf) = node |
| 324 | + .inner() |
| 325 | + .as_any() |
| 326 | + .downcast_ref::<PythonFunctionScalarUDF>() |
| 327 | + else { |
| 328 | + return Ok(false); |
| 329 | + }; |
| 330 | + |
| 331 | + Python::attach(|py| -> Result<bool> { |
| 332 | + let bytes = encode_python_scalar_udf(py, py_udf) |
| 333 | + .map_err(|e| datafusion::error::DataFusionError::External(Box::new(e)))?; |
| 334 | + buf.extend_from_slice(PY_SCALAR_UDF_MAGIC); |
| 335 | + buf.extend_from_slice(&bytes); |
| 336 | + Ok(true) |
| 337 | + }) |
| 338 | +} |
| 339 | + |
| 340 | +/// Decode an inline Python scalar UDF payload. Returns `Ok(None)` |
| 341 | +/// when `buf` does not carry the `DFPYUDF1` prefix, signalling the |
| 342 | +/// caller to delegate to its `inner` codec (and eventually the |
| 343 | +/// `FunctionRegistry`). |
| 344 | +pub(crate) fn try_decode_python_scalar_udf(buf: &[u8]) -> Result<Option<Arc<ScalarUDF>>> { |
| 345 | + if buf.is_empty() || !buf.starts_with(PY_SCALAR_UDF_MAGIC) { |
| 346 | + return Ok(None); |
| 347 | + } |
| 348 | + let payload = &buf[PY_SCALAR_UDF_MAGIC.len()..]; |
| 349 | + |
| 350 | + Python::attach(|py| -> Result<Option<Arc<ScalarUDF>>> { |
| 351 | + let udf = decode_python_scalar_udf(py, payload) |
| 352 | + .map_err(|e| datafusion::error::DataFusionError::External(Box::new(e)))?; |
| 353 | + Ok(Some(Arc::new(ScalarUDF::new_from_impl(udf)))) |
| 354 | + }) |
| 355 | +} |
| 356 | + |
| 357 | +/// Build the cloudpickle payload for a `PythonFunctionScalarUDF`. |
| 358 | +/// |
| 359 | +/// Layout: `cloudpickle.dumps((name, func, input_schema_bytes, |
| 360 | +/// return_field, volatility_str))`. Input fields ride along as an |
| 361 | +/// IPC-encoded pyarrow Schema so they round-trip without extra |
| 362 | +/// plumbing. |
| 363 | +fn encode_python_scalar_udf(py: Python<'_>, udf: &PythonFunctionScalarUDF) -> PyResult<Vec<u8>> { |
| 364 | + let cloudpickle = py.import("cloudpickle")?; |
| 365 | + |
| 366 | + let input_schema = Schema::new(udf.input_fields().to_vec()); |
| 367 | + let pa_schema_obj = input_schema.to_pyarrow(py)?; |
| 368 | + let pa_schema = pa_schema_obj.into_bound(); |
| 369 | + let schema_bytes: Vec<u8> = pa_schema |
| 370 | + .call_method0("serialize")? |
| 371 | + .call_method0("to_pybytes")? |
| 372 | + .extract()?; |
| 373 | + |
| 374 | + let return_field_obj = udf.return_field().as_ref().to_pyarrow(py)?; |
| 375 | + let volatility = format!("{:?}", udf.volatility()).to_lowercase(); |
| 376 | + |
| 377 | + let payload = PyTuple::new( |
| 378 | + py, |
| 379 | + [ |
| 380 | + udf.name().into_pyobject(py)?.into_any(), |
| 381 | + udf.func().bind(py).clone().into_any(), |
| 382 | + PyBytes::new(py, &schema_bytes).into_any(), |
| 383 | + return_field_obj.into_bound(), |
| 384 | + volatility.into_pyobject(py)?.into_any(), |
| 385 | + ], |
| 386 | + )?; |
| 387 | + |
| 388 | + let blob = cloudpickle.call_method1("dumps", (payload,))?; |
| 389 | + blob.extract::<Vec<u8>>() |
| 390 | +} |
| 391 | + |
| 392 | +/// Inverse of [`encode_python_scalar_udf`]. |
| 393 | +fn decode_python_scalar_udf(py: Python<'_>, payload: &[u8]) -> PyResult<PythonFunctionScalarUDF> { |
| 394 | + let cloudpickle = py.import("cloudpickle")?; |
| 395 | + let pyarrow = py.import("pyarrow")?; |
| 396 | + |
| 397 | + let tuple = cloudpickle |
| 398 | + .call_method1("loads", (PyBytes::new(py, payload),))? |
| 399 | + .cast_into::<PyTuple>()?; |
| 400 | + |
| 401 | + let name: String = tuple.get_item(0)?.extract()?; |
| 402 | + let func: Py<PyAny> = tuple.get_item(1)?.unbind(); |
| 403 | + let schema_bytes: Vec<u8> = tuple.get_item(2)?.extract()?; |
| 404 | + let return_field_py = tuple.get_item(3)?; |
| 405 | + let volatility_str: String = tuple.get_item(4)?.extract()?; |
| 406 | + |
| 407 | + let buffer = pyarrow.call_method1("py_buffer", (PyBytes::new(py, &schema_bytes),))?; |
| 408 | + let pa_schema = pyarrow |
| 409 | + .getattr("ipc")? |
| 410 | + .call_method1("read_schema", (buffer,))?; |
| 411 | + |
| 412 | + let schema = Schema::from_pyarrow_bound(&pa_schema) |
| 413 | + .map_err(|e| pyo3::exceptions::PyValueError::new_err(format!("{e}")))?; |
| 414 | + let input_fields: Vec<Field> = schema.fields().iter().map(|f| f.as_ref().clone()).collect(); |
| 415 | + |
| 416 | + let return_field = Field::from_pyarrow_bound(&return_field_py) |
| 417 | + .map_err(|e| pyo3::exceptions::PyValueError::new_err(format!("{e}")))?; |
| 418 | + |
| 419 | + let volatility = datafusion_python_util::parse_volatility(&volatility_str) |
| 420 | + .map_err(|e| pyo3::exceptions::PyValueError::new_err(format!("{e}")))?; |
| 421 | + |
| 422 | + Ok(PythonFunctionScalarUDF::from_parts( |
| 423 | + name, |
| 424 | + func, |
| 425 | + input_fields, |
| 426 | + return_field, |
| 427 | + volatility, |
| 428 | + )) |
| 429 | +} |
0 commit comments