Skip to content

Commit 96c3a14

Browse files
timsaucerclaude
andcommitted
feat: accept variadic field path in get_field
Upstream exposes both `get_field(expr, name)` and `get_field_path(expr, [names...])`, but both ultimately call the same scalar UDF with a base expression plus one or more name args. Collapse the Python surface into a single variadic `get_field(expr, *names)` that accepts either a one-step lookup or a path of names, dispatching through a single Rust binding. Note in `.ai/skills/check-upstream/SKILL.md` that `get_field_path` is covered by the variadic form so future audits do not flag it as a gap. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 95eee71 commit 96c3a14

4 files changed

Lines changed: 76 additions & 13 deletions

File tree

.ai/skills/check-upstream/SKILL.md

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,11 +66,17 @@ The user may specify an area via `$ARGUMENTS`. If no area is specified or "all"
6666
- Python API: `python/datafusion/functions.py` — each function wraps a call to `datafusion._internal.functions`
6767
- Rust bindings: `crates/core/src/functions.rs``#[pyfunction]` definitions registered via `init_module()`
6868

69+
**Evaluated and not requiring separate Python exposure:**
70+
- `get_field_path` — already covered by `get_field(expr, *names)`, which takes a
71+
variadic field path and dispatches to the same underlying
72+
`functions::core::get_field` UDF as the upstream `get_field_path` helper.
73+
6974
**How to check:**
7075
1. Fetch the upstream scalar function documentation page
7176
2. Compare against functions listed in `python/datafusion/functions.py` (check the `__all__` list and function definitions)
7277
3. A function is covered if it exists in the Python API — it does NOT need a dedicated Rust `#[pyfunction]`. Many functions are aliases that reuse another function's Rust binding.
73-
4. Only report functions that are missing from the Python `__all__` list / function definitions
78+
4. Check against the "evaluated and not requiring exposure" list before flagging as a gap
79+
5. Only report functions that are missing from the Python `__all__` list / function definitions
7480

7581
### 2. Aggregate Functions
7682

crates/core/src/functions.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -574,10 +574,10 @@ expr_fn!(union_tag, arg1);
574574
expr_fn!(random);
575575

576576
#[pyfunction]
577-
fn get_field(expr: PyExpr, name: PyExpr) -> PyExpr {
578-
functions::core::get_field()
579-
.call(vec![expr.into(), name.into()])
580-
.into()
577+
fn get_field(expr: PyExpr, names: Vec<PyExpr>) -> PyExpr {
578+
let mut args = vec![expr.into()];
579+
args.extend(names.into_iter().map(Into::into));
580+
functions::core::get_field().call(args).into()
581581
}
582582

583583
#[pyfunction]

python/datafusion/functions.py

Lines changed: 34 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2727,14 +2727,24 @@ def arrow_metadata(expr: Expr, key: Expr | str | None = None) -> Expr:
27272727
return Expr(f.arrow_metadata(expr.expr, key.expr))
27282728

27292729

2730-
def get_field(expr: Expr, name: Expr | str) -> Expr:
2731-
"""Extracts a field from a struct or map by name.
2730+
def get_field(expr: Expr, *names: Expr | str) -> Expr:
2731+
"""Extracts a (possibly nested) field from a struct or map by name.
27322732
2733-
When the field name is a static string, the bracket operator
2734-
``expr["field"]`` is a convenient shorthand. Use ``get_field``
2735-
when the field name is a dynamic expression.
2733+
Pass one name for a single-level lookup, or several names to walk a path
2734+
of nested struct/map fields in a single ``get_field`` call. For a single
2735+
static-string name, ``expr["field"]`` is a convenient shorthand; use
2736+
``get_field`` when the field name is a dynamic
2737+
:py:class:`~datafusion.expr.Expr` or when traversing multiple levels at
2738+
once.
2739+
2740+
Args:
2741+
expr: The struct or map expression to read from.
2742+
*names: One or more field names (``str``) or expressions
2743+
(:py:class:`~datafusion.expr.Expr`).
27362744
27372745
Examples:
2746+
Single-level lookup:
2747+
27382748
>>> ctx = dfn.SessionContext()
27392749
>>> df = ctx.from_pydict({"a": [1], "b": [2]})
27402750
>>> df = df.with_column(
@@ -2756,10 +2766,26 @@ def get_field(expr: Expr, name: Expr | str) -> Expr:
27562766
... )
27572767
>>> result.collect_column("x_val")[0].as_py()
27582768
1
2769+
2770+
Multi-level lookup:
2771+
2772+
>>> df = df.with_column(
2773+
... "outer",
2774+
... dfn.functions.named_struct([("inner", dfn.col("s"))]),
2775+
... )
2776+
>>> result = df.select(
2777+
... dfn.functions.get_field(
2778+
... dfn.col("outer"), "inner", "x"
2779+
... ).alias("x_val")
2780+
... )
2781+
>>> result.collect_column("x_val")[0].as_py()
2782+
1
27592783
"""
2760-
if isinstance(name, str):
2761-
name = Expr.string_literal(name)
2762-
return Expr(f.get_field(expr.expr, name.expr))
2784+
if not names:
2785+
msg = "get_field requires at least one field name"
2786+
raise ValueError(msg)
2787+
resolved = [Expr.string_literal(n) if isinstance(n, str) else n for n in names]
2788+
return Expr(f.get_field(expr.expr, [n.expr for n in resolved]))
27632789

27642790

27652791
def union_extract(union_expr: Expr, field_name: Expr | str) -> Expr:

python/tests/test_functions.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1957,6 +1957,37 @@ def test_get_field(df):
19571957
assert result.column(1) == pa.array([4, 5, 6])
19581958

19591959

1960+
def test_get_field_path(df):
1961+
df = df.with_column(
1962+
"outer",
1963+
f.named_struct(
1964+
[
1965+
(
1966+
"inner",
1967+
f.named_struct(
1968+
[
1969+
("x", column("a")),
1970+
("y", column("b")),
1971+
]
1972+
),
1973+
),
1974+
]
1975+
),
1976+
)
1977+
result = df.select(
1978+
f.get_field(column("outer"), "inner", "x").alias("x_val"),
1979+
f.get_field(column("outer"), "inner", "y").alias("y_val"),
1980+
).collect()[0]
1981+
1982+
assert result.column(0) == pa.array(["Hello", "World", "!"], type=pa.string_view())
1983+
assert result.column(1) == pa.array([4, 5, 6])
1984+
1985+
1986+
def test_get_field_requires_a_name():
1987+
with pytest.raises(ValueError, match="at least one field name"):
1988+
f.get_field(column("s"))
1989+
1990+
19601991
def test_arrow_metadata():
19611992
ctx = SessionContext()
19621993
field = pa.field("val", pa.int64(), metadata={"key1": "value1", "key2": "value2"})

0 commit comments

Comments
 (0)