Skip to content

Commit f25aa87

Browse files
timsaucerclaude
andcommitted
test(ffi): strict-mode Expr round-trip with an FFI-capsule scalar UDF
The strict-mode wire format ships Python-defined UDFs by name; the existing coverage only exercised that path with `PythonFunctionScalarUDF` instances. This test uses the FFI `IsNullUDF` so the codec falls through `try_encode_python_scalar_udf` (downcast fails) and the default by-name encoding actually runs. Covers explicit `Expr.to_bytes` / `from_bytes` plus the `pickle.dumps` / `pickle.loads` route via the sender / worker thread-local slots, and asserts the strict blob is materially smaller than an inline Python-UDF blob — catching a regression that silently falls back to inlining for a non-Python UDF. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 310080c commit f25aa87

1 file changed

Lines changed: 127 additions & 0 deletions

File tree

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
"""Strict-mode Expr round-trip with an FFI-capsule scalar UDF.
19+
20+
Verifies the by-name path: an FFI-imported UDF (no
21+
``PythonFunctionScalarUDF`` downcast on the codec) serializes by name
22+
and resolves from the receiver's function registry on decode. Covers
23+
both the explicit ``Expr.to_bytes(ctx)`` / ``Expr.from_bytes(ctx=...)``
24+
API and the ``pickle.dumps`` / ``pickle.loads`` route through the
25+
sender / worker context slots.
26+
"""
27+
28+
from __future__ import annotations
29+
30+
import pickle
31+
32+
import pyarrow as pa
33+
import pytest
34+
from datafusion import Expr, SessionContext, col, udf
35+
from datafusion.ipc import (
36+
clear_sender_ctx,
37+
clear_worker_ctx,
38+
set_sender_ctx,
39+
set_worker_ctx,
40+
)
41+
from datafusion_ffi_example import IsNullUDF
42+
43+
44+
@pytest.fixture(autouse=True)
45+
def _reset_thread_locals():
46+
"""Ensure no sender / worker context leaks across tests."""
47+
clear_worker_ctx()
48+
clear_sender_ctx()
49+
yield
50+
clear_worker_ctx()
51+
clear_sender_ctx()
52+
53+
54+
def _strict_session_with_ffi_udf():
55+
"""Build a strict-mode session with the FFI ``IsNullUDF`` registered."""
56+
ctx = SessionContext().with_python_udf_inlining(enabled=False)
57+
my_udf = udf(IsNullUDF())
58+
ctx.register_udf(my_udf)
59+
return ctx, my_udf
60+
61+
62+
def test_strict_ffi_udf_expr_roundtrip_via_to_bytes():
63+
"""Strict-mode encode emits a by-name payload; receiver resolves
64+
``my_custom_is_null`` from its registered functions and the decoded
65+
expression evaluates to the same result as the original."""
66+
sender, my_udf = _strict_session_with_ffi_udf()
67+
receiver, _ = _strict_session_with_ffi_udf()
68+
69+
expr = my_udf(col("a"))
70+
blob = expr.to_bytes(sender)
71+
restored = Expr.from_bytes(blob, ctx=receiver)
72+
73+
assert "my_custom_is_null" in restored.canonical_name()
74+
75+
batch = pa.RecordBatch.from_arrays(
76+
[pa.array([1, 2, None, 4], type=pa.int64())], names=["a"]
77+
)
78+
receiver.register_record_batches("t", [[batch]])
79+
out = receiver.table("t").select(restored.alias("r")).collect()
80+
expected = pa.array([False, False, True, False], type=pa.bool_())
81+
assert out[0].column(0) == expected
82+
83+
84+
def test_strict_ffi_udf_pickle_roundtrip_via_thread_locals():
85+
"""Driver installs a strict sender context; worker installs a
86+
matching strict receiver. ``pickle.dumps`` / ``pickle.loads`` route
87+
through them and the FFI UDF resolves by name on decode."""
88+
sender, my_udf = _strict_session_with_ffi_udf()
89+
receiver, _ = _strict_session_with_ffi_udf()
90+
91+
expr = my_udf(col("a"))
92+
93+
set_sender_ctx(sender)
94+
try:
95+
blob = pickle.dumps(expr)
96+
finally:
97+
clear_sender_ctx()
98+
99+
set_worker_ctx(receiver)
100+
try:
101+
restored = pickle.loads(blob)
102+
finally:
103+
clear_worker_ctx()
104+
105+
assert "my_custom_is_null" in restored.canonical_name()
106+
107+
108+
def test_strict_ffi_udf_smaller_than_inline_python_udf():
109+
"""Sanity-check the wire size claim: strict-mode FFI UDF bytes are
110+
a small by-name payload, dramatically smaller than the inline form
111+
of a Python UDF with the same arity. Confirms the encode path
112+
actually took the by-name branch instead of falling through to an
113+
inline path."""
114+
sender, my_udf = _strict_session_with_ffi_udf()
115+
ffi_blob = my_udf(col("a")).to_bytes(sender)
116+
117+
inline_ctx = SessionContext()
118+
py_udf = udf(
119+
lambda arr: pa.array([v.as_py() is None for v in arr]),
120+
[pa.int64()],
121+
pa.bool_(),
122+
volatility="immutable",
123+
name="py_is_null",
124+
)
125+
py_blob = py_udf(col("a")).to_bytes(inline_ctx)
126+
127+
assert len(ffi_blob) < len(py_blob) // 4

0 commit comments

Comments
 (0)