|
| 1 | +# Licensed to the Apache Software Foundation (ASF) under one |
| 2 | +# or more contributor license agreements. See the NOTICE file |
| 3 | +# distributed with this work for additional information |
| 4 | +# regarding copyright ownership. The ASF licenses this file |
| 5 | +# to you under the Apache License, Version 2.0 (the |
| 6 | +# "License"); you may not use this file except in compliance |
| 7 | +# with the License. You may obtain a copy of the License at |
| 8 | +# |
| 9 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 10 | +# |
| 11 | +# Unless required by applicable law or agreed to in writing, |
| 12 | +# software distributed under the License is distributed on an |
| 13 | +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| 14 | +# KIND, either express or implied. See the License for the |
| 15 | +# specific language governing permissions and limitations |
| 16 | +# under the License. |
| 17 | + |
| 18 | +"""Strict-mode Expr round-trip with an FFI-capsule scalar UDF. |
| 19 | +
|
| 20 | +Verifies the by-name path: an FFI-imported UDF (no |
| 21 | +``PythonFunctionScalarUDF`` downcast on the codec) serializes by name |
| 22 | +and resolves from the receiver's function registry on decode. Covers |
| 23 | +both the explicit ``Expr.to_bytes(ctx)`` / ``Expr.from_bytes(ctx=...)`` |
| 24 | +API and the ``pickle.dumps`` / ``pickle.loads`` route through the |
| 25 | +sender / worker context slots. |
| 26 | +""" |
| 27 | + |
| 28 | +from __future__ import annotations |
| 29 | + |
| 30 | +import pickle |
| 31 | + |
| 32 | +import pyarrow as pa |
| 33 | +import pytest |
| 34 | +from datafusion import Expr, SessionContext, col, udf |
| 35 | +from datafusion.ipc import ( |
| 36 | + clear_sender_ctx, |
| 37 | + clear_worker_ctx, |
| 38 | + set_sender_ctx, |
| 39 | + set_worker_ctx, |
| 40 | +) |
| 41 | +from datafusion_ffi_example import IsNullUDF |
| 42 | + |
| 43 | + |
@pytest.fixture(autouse=True)
def _reset_thread_locals():
    """Clear the sender and worker context slots around every test.

    The slots are wiped once before the test body runs and once more
    after it finishes, so a context installed by one test cannot leak
    into the next.
    """

    def _wipe_slots():
        # Same order on both sides of the yield: worker first, then sender.
        clear_worker_ctx()
        clear_sender_ctx()

    _wipe_slots()
    yield
    _wipe_slots()
| 52 | + |
| 53 | + |
def _strict_session_with_ffi_udf():
    """Return ``(ctx, udf)``: a session with Python UDF inlining disabled
    and the FFI ``IsNullUDF`` wrapped and registered on it.

    The wrapped UDF is returned alongside the context so callers can
    build expressions with it directly.
    """
    session = SessionContext().with_python_udf_inlining(enabled=False)
    ffi_is_null = udf(IsNullUDF())
    session.register_udf(ffi_is_null)
    return session, ffi_is_null
| 60 | + |
| 61 | + |
def test_strict_ffi_udf_expr_roundtrip_via_to_bytes():
    """Round-trip an FFI UDF expression through the explicit bytes API.

    The expression is encoded against the sender session and decoded
    against a separately built receiver session that has the same UDF
    registered. The decoded expression must still name
    ``my_custom_is_null`` and, when evaluated on the receiver, must
    flag exactly the null row.
    """
    sender, my_udf = _strict_session_with_ffi_udf()
    receiver, _ = _strict_session_with_ffi_udf()

    # Encode on one session, decode on the other.
    payload = my_udf(col("a")).to_bytes(sender)
    decoded = Expr.from_bytes(payload, ctx=receiver)

    assert "my_custom_is_null" in decoded.canonical_name()

    # Evaluate the decoded expression on the receiver over a column
    # with a single null in the third position.
    values = pa.array([1, 2, None, 4], type=pa.int64())
    batch = pa.RecordBatch.from_arrays([values], names=["a"])
    receiver.register_record_batches("t", [[batch]])

    projected = receiver.table("t").select(decoded.alias("r"))
    result_batches = projected.collect()

    expected = pa.array([False, False, True, False], type=pa.bool_())
    assert result_batches[0].column(0) == expected
| 82 | + |
| 83 | + |
def test_strict_ffi_udf_pickle_roundtrip_via_thread_locals():
    """Round-trip an FFI UDF expression through ``pickle``.

    The sender context is installed only for the duration of
    ``pickle.dumps`` and the worker context only for the duration of
    ``pickle.loads``; each slot is cleared again in a ``finally`` so a
    failure cannot leave it set. The unpickled expression must still
    name ``my_custom_is_null``.
    """
    sender, my_udf = _strict_session_with_ffi_udf()
    receiver, _ = _strict_session_with_ffi_udf()

    original = my_udf(col("a"))

    # Driver side: serialize with the sender context in place.
    set_sender_ctx(sender)
    try:
        pickled = pickle.dumps(original)
    finally:
        clear_sender_ctx()

    # Worker side: deserialize with the receiver context in place.
    set_worker_ctx(receiver)
    try:
        unpickled = pickle.loads(pickled)
    finally:
        clear_worker_ctx()

    canonical = unpickled.canonical_name()
    assert "my_custom_is_null" in canonical
| 106 | + |
| 107 | + |
def test_strict_ffi_udf_smaller_than_inline_python_udf():
    """Check that the FFI UDF payload is far smaller than an inline one.

    Encodes the FFI UDF expression against the strict session and a
    one-argument Python lambda UDF against a default session, then
    requires the FFI payload to be under a quarter of the Python
    payload's length (integer division). A payload that small indicates
    the encoder took the by-name branch rather than an inline path.
    """
    strict_ctx, my_udf = _strict_session_with_ffi_udf()
    ffi_size = len(my_udf(col("a")).to_bytes(strict_ctx))

    default_ctx = SessionContext()
    # Python UDF of the same arity, used as the size baseline.
    python_is_null = udf(
        lambda arr: pa.array([v.as_py() is None for v in arr]),
        [pa.int64()],
        pa.bool_(),
        volatility="immutable",
        name="py_is_null",
    )
    python_size = len(python_is_null(col("a")).to_bytes(default_ctx))

    size_ceiling = python_size // 4
    assert ffi_size < size_ceiling
0 commit comments