Skip to content

Commit 14af180

Browse files
timsaucerclaude
andcommitted
Add map functions (make_map, map_keys, map_values, map_extract, map_entries, element_at)
Closes #1448 Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent ad8d41f commit 14af180

File tree

2 files changed

+147
-0
lines changed

2 files changed

+147
-0
lines changed

crates/core/src/functions.rs

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,18 @@ fn array_cat(exprs: Vec<PyExpr>) -> PyExpr {
9393
array_concat(exprs)
9494
}
9595

96+
#[pyfunction]
97+
fn make_map(keys: Vec<PyExpr>, values: Vec<PyExpr>) -> PyExpr {
98+
let keys = keys.into_iter().map(|x| x.into()).collect();
99+
let values = values.into_iter().map(|x| x.into()).collect();
100+
datafusion::functions_nested::map::map(keys, values).into()
101+
}
102+
103+
#[pyfunction]
104+
fn element_at(map: PyExpr, key: PyExpr) -> PyExpr {
105+
datafusion::functions_nested::expr_fn::map_extract(map.into(), key.into()).into()
106+
}
107+
96108
#[pyfunction]
97109
#[pyo3(signature = (array, element, index=None))]
98110
fn array_position(array: PyExpr, element: PyExpr, index: Option<i64>) -> PyExpr {
@@ -665,6 +677,12 @@ array_fn!(cardinality, array);
665677
array_fn!(flatten, array);
666678
array_fn!(range, start stop step);
667679

680+
// Map Functions
681+
array_fn!(map_keys, map);
682+
array_fn!(map_values, map);
683+
array_fn!(map_extract, map key);
684+
array_fn!(map_entries, map);
685+
668686
aggregate_function!(array_agg);
669687
aggregate_function!(max);
670688
aggregate_function!(min);
@@ -1124,6 +1142,14 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
11241142
m.add_wrapped(wrap_pyfunction!(flatten))?;
11251143
m.add_wrapped(wrap_pyfunction!(cardinality))?;
11261144

1145+
// Map Functions
1146+
m.add_wrapped(wrap_pyfunction!(make_map))?;
1147+
m.add_wrapped(wrap_pyfunction!(map_keys))?;
1148+
m.add_wrapped(wrap_pyfunction!(map_values))?;
1149+
m.add_wrapped(wrap_pyfunction!(map_extract))?;
1150+
m.add_wrapped(wrap_pyfunction!(map_entries))?;
1151+
m.add_wrapped(wrap_pyfunction!(element_at))?;
1152+
11271153
// Window Functions
11281154
m.add_wrapped(wrap_pyfunction!(lead))?;
11291155
m.add_wrapped(wrap_pyfunction!(lag))?;

python/datafusion/functions.py

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
from __future__ import annotations
2020

21+
import builtins
2122
from typing import TYPE_CHECKING, Any
2223

2324
import pyarrow as pa
@@ -137,6 +138,7 @@
137138
"degrees",
138139
"dense_rank",
139140
"digest",
141+
"element_at",
140142
"empty",
141143
"encode",
142144
"ends_with",
@@ -200,6 +202,11 @@
200202
"make_array",
201203
"make_date",
202204
"make_list",
205+
"make_map",
206+
"map_entries",
207+
"map_extract",
208+
"map_keys",
209+
"map_values",
203210
"max",
204211
"md5",
205212
"mean",
@@ -3338,6 +3345,120 @@ def empty(array: Expr) -> Expr:
33383345
return array_empty(array)
33393346

33403347

3348+
# map functions
3349+
3350+
3351+
def make_map(*args: Expr) -> Expr:
3352+
"""Returns a map created from key and value expressions.
3353+
3354+
Accepts an even number of arguments, alternating between keys and values.
3355+
For example, ``make_map(k1, v1, k2, v2)`` creates a map ``{k1: v1, k2: v2}``.
3356+
3357+
Examples:
3358+
>>> ctx = dfn.SessionContext()
3359+
>>> df = ctx.from_pydict({"a": [1]})
3360+
>>> result = df.select(
3361+
... dfn.functions.make_map(
3362+
... dfn.lit("a"), dfn.lit(1),
3363+
... dfn.lit("b"), dfn.lit(2),
3364+
... ).alias("map"))
3365+
>>> result.collect_column("map")[0].as_py()
3366+
[('a', 1), ('b', 2)]
3367+
"""
3368+
if len(args) % 2 != 0:
3369+
msg = "make_map requires an even number of arguments"
3370+
raise ValueError(msg)
3371+
keys = [args[i].expr for i in builtins.range(0, len(args), 2)]
3372+
values = [args[i].expr for i in builtins.range(1, len(args), 2)]
3373+
return Expr(f.make_map(keys, values))
3374+
3375+
3376+
def map_keys(map: Expr) -> Expr:
3377+
"""Returns a list of all keys in the map.
3378+
3379+
Examples:
3380+
>>> ctx = dfn.SessionContext()
3381+
>>> df = ctx.from_pydict({"a": [1]})
3382+
>>> result = df.select(
3383+
... dfn.functions.map_keys(
3384+
... dfn.functions.make_map(
3385+
... dfn.lit("x"), dfn.lit(1),
3386+
... dfn.lit("y"), dfn.lit(2),
3387+
... )
3388+
... ).alias("keys"))
3389+
>>> result.collect_column("keys")[0].as_py()
3390+
['x', 'y']
3391+
"""
3392+
return Expr(f.map_keys(map.expr))
3393+
3394+
3395+
def map_values(map: Expr) -> Expr:
3396+
"""Returns a list of all values in the map.
3397+
3398+
Examples:
3399+
>>> ctx = dfn.SessionContext()
3400+
>>> df = ctx.from_pydict({"a": [1]})
3401+
>>> result = df.select(
3402+
... dfn.functions.map_values(
3403+
... dfn.functions.make_map(
3404+
... dfn.lit("x"), dfn.lit(1),
3405+
... dfn.lit("y"), dfn.lit(2),
3406+
... )
3407+
... ).alias("vals"))
3408+
>>> result.collect_column("vals")[0].as_py()
3409+
[1, 2]
3410+
"""
3411+
return Expr(f.map_values(map.expr))
3412+
3413+
3414+
def map_extract(map: Expr, key: Expr) -> Expr:
3415+
"""Returns the value for the given key in the map, or an empty list if absent.
3416+
3417+
Examples:
3418+
>>> ctx = dfn.SessionContext()
3419+
>>> df = ctx.from_pydict({"a": [1]})
3420+
>>> result = df.select(
3421+
... dfn.functions.map_extract(
3422+
... dfn.functions.make_map(
3423+
... dfn.lit("x"), dfn.lit(1),
3424+
... dfn.lit("y"), dfn.lit(2),
3425+
... ),
3426+
... dfn.lit("x"),
3427+
... ).alias("val"))
3428+
>>> result.collect_column("val")[0].as_py()
3429+
[1]
3430+
"""
3431+
return Expr(f.map_extract(map.expr, key.expr))
3432+
3433+
3434+
def map_entries(map: Expr) -> Expr:
3435+
"""Returns a list of all entries (key-value struct pairs) in the map.
3436+
3437+
Examples:
3438+
>>> ctx = dfn.SessionContext()
3439+
>>> df = ctx.from_pydict({"a": [1]})
3440+
>>> result = df.select(
3441+
... dfn.functions.map_entries(
3442+
... dfn.functions.make_map(
3443+
... dfn.lit("x"), dfn.lit(1),
3444+
... dfn.lit("y"), dfn.lit(2),
3445+
... )
3446+
... ).alias("entries"))
3447+
>>> result.collect_column("entries")[0].as_py()
3448+
[{'key': 'x', 'value': 1}, {'key': 'y', 'value': 2}]
3449+
"""
3450+
return Expr(f.map_entries(map.expr))
3451+
3452+
3453+
def element_at(map: Expr, key: Expr) -> Expr:
3454+
"""Returns the value for the given key in the map, or an empty list if absent.
3455+
3456+
See Also:
3457+
This is an alias for :py:func:`map_extract`.
3458+
"""
3459+
return map_extract(map, key)
3460+
3461+
33413462
# aggregate functions
33423463
def approx_distinct(
33433464
expression: Expr,

0 commit comments

Comments
 (0)