Skip to content

Commit fac2c24

Browse files
timsaucer and claude committed
Change make_map to accept a Python dictionary
make_map now takes a dict for the common case and also supports separate keys/values lists for column expressions. Non-Expr keys and values are automatically converted to literals.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 4827528 commit fac2c24

File tree

2 files changed: 70 additions (+70) and 68 deletions (-68).

python/datafusion/functions.py

Lines changed: 35 additions & 30 deletions
Original file line number | Diff line number | Diff line change
@@ -18,7 +18,6 @@
1818

1919
from __future__ import annotations
2020

21-
import builtins
2221
from typing import TYPE_CHECKING, Any
2322

2423
import pyarrow as pa
@@ -3348,29 +3347,47 @@ def empty(array: Expr) -> Expr:
33483347
# map functions
33493348

33503349

3351-
def make_map(*args: Expr) -> Expr:
3352-
"""Returns a map created from key and value expressions.
3350+
def make_map(
3351+
data: dict[Any, Any] | None = None,
3352+
keys: list[Any] | None = None,
3353+
values: list[Any] | None = None,
3354+
) -> Expr:
3355+
"""Returns a map expression.
3356+
3357+
Can be called with either a Python dictionary or separate ``keys``
3358+
and ``values`` lists. Keys and values that are not already
3359+
:py:class:`~datafusion.expr.Expr` are automatically converted to
3360+
literal expressions.
33533361
3354-
Accepts an even number of arguments, alternating between keys and values.
3355-
For example, ``make_map(k1, v1, k2, v2)`` creates a map ``{k1: v1, k2: v2}``.
3362+
Args:
3363+
data: A Python dictionary of key-value pairs.
3364+
keys: A list of keys (use with ``values`` for column expressions).
3365+
values: A list of values (use with ``keys``).
33563366
33573367
Examples:
33583368
>>> ctx = dfn.SessionContext()
33593369
>>> df = ctx.from_pydict({"a": [1]})
33603370
>>> result = df.select(
3361-
... dfn.functions.make_map(
3362-
... dfn.lit("a"), dfn.lit(1),
3363-
... dfn.lit("b"), dfn.lit(2),
3364-
... ).alias("map"))
3371+
... dfn.functions.make_map({"a": 1, "b": 2}).alias("map"))
33653372
>>> result.collect_column("map")[0].as_py()
33663373
[('a', 1), ('b', 2)]
33673374
"""
3368-
if len(args) % 2 != 0:
3369-
msg = "make_map requires an even number of arguments"
3375+
if data is not None:
3376+
if keys is not None or values is not None:
3377+
msg = "Cannot specify both data and keys/values"
3378+
raise ValueError(msg)
3379+
key_list = list(data.keys())
3380+
value_list = list(data.values())
3381+
elif keys is not None and values is not None:
3382+
key_list = keys
3383+
value_list = values
3384+
else:
3385+
msg = "Must specify either data or both keys and values"
33703386
raise ValueError(msg)
3371-
keys = [args[i].expr for i in builtins.range(0, len(args), 2)]
3372-
values = [args[i].expr for i in builtins.range(1, len(args), 2)]
3373-
return Expr(f.make_map(keys, values))
3387+
3388+
key_exprs = [k if isinstance(k, Expr) else Expr.literal(k) for k in key_list]
3389+
val_exprs = [v if isinstance(v, Expr) else Expr.literal(v) for v in value_list]
3390+
return Expr(f.make_map([k.expr for k in key_exprs], [v.expr for v in val_exprs]))
33743391

33753392

33763393
def map_keys(map: Expr) -> Expr:
@@ -3381,10 +3398,7 @@ def map_keys(map: Expr) -> Expr:
33813398
>>> df = ctx.from_pydict({"a": [1]})
33823399
>>> result = df.select(
33833400
... dfn.functions.map_keys(
3384-
... dfn.functions.make_map(
3385-
... dfn.lit("x"), dfn.lit(1),
3386-
... dfn.lit("y"), dfn.lit(2),
3387-
... )
3401+
... dfn.functions.make_map({"x": 1, "y": 2})
33883402
... ).alias("keys"))
33893403
>>> result.collect_column("keys")[0].as_py()
33903404
['x', 'y']
@@ -3400,10 +3414,7 @@ def map_values(map: Expr) -> Expr:
34003414
>>> df = ctx.from_pydict({"a": [1]})
34013415
>>> result = df.select(
34023416
... dfn.functions.map_values(
3403-
... dfn.functions.make_map(
3404-
... dfn.lit("x"), dfn.lit(1),
3405-
... dfn.lit("y"), dfn.lit(2),
3406-
... )
3417+
... dfn.functions.make_map({"x": 1, "y": 2})
34073418
... ).alias("vals"))
34083419
>>> result.collect_column("vals")[0].as_py()
34093420
[1, 2]
@@ -3419,10 +3430,7 @@ def map_extract(map: Expr, key: Expr) -> Expr:
34193430
>>> df = ctx.from_pydict({"a": [1]})
34203431
>>> result = df.select(
34213432
... dfn.functions.map_extract(
3422-
... dfn.functions.make_map(
3423-
... dfn.lit("x"), dfn.lit(1),
3424-
... dfn.lit("y"), dfn.lit(2),
3425-
... ),
3433+
... dfn.functions.make_map({"x": 1, "y": 2}),
34263434
... dfn.lit("x"),
34273435
... ).alias("val"))
34283436
>>> result.collect_column("val")[0].as_py()
@@ -3439,10 +3447,7 @@ def map_entries(map: Expr) -> Expr:
34393447
>>> df = ctx.from_pydict({"a": [1]})
34403448
>>> result = df.select(
34413449
... dfn.functions.map_entries(
3442-
... dfn.functions.make_map(
3443-
... dfn.lit("x"), dfn.lit(1),
3444-
... dfn.lit("y"), dfn.lit(2),
3445-
... )
3450+
... dfn.functions.make_map({"x": 1, "y": 2})
34463451
... ).alias("entries"))
34473452
>>> result.collect_column("entries")[0].as_py()
34483453
[{'key': 'x', 'value': 1}, {'key': 'y', 'value': 2}]

python/tests/test_functions.py

Lines changed: 35 additions & 38 deletions
Original file line number | Diff line number | Diff line change
@@ -673,32 +673,50 @@ def test_make_map():
673673
batch = pa.RecordBatch.from_arrays([pa.array([1])], names=["a"])
674674
df = ctx.create_dataframe([[batch]])
675675

676+
result = df.select(f.make_map({"x": 1, "y": 2}).alias("map")).collect()[0].column(0)
677+
assert result[0].as_py() == [("x", 1), ("y", 2)]
678+
679+
680+
def test_make_map_with_expr_values():
681+
ctx = SessionContext()
682+
batch = pa.RecordBatch.from_arrays([pa.array([1])], names=["a"])
683+
df = ctx.create_dataframe([[batch]])
684+
676685
result = (
677-
df.select(
678-
f.make_map(
679-
literal("x"),
680-
literal(1),
681-
literal("y"),
682-
literal(2),
683-
).alias("map")
684-
)
686+
df.select(f.make_map({"x": literal(1), "y": literal(2)}).alias("map"))
685687
.collect()[0]
686688
.column(0)
687689
)
688690
assert result[0].as_py() == [("x", 1), ("y", 2)]
689691

690692

691-
def test_make_map_odd_args():
692-
with pytest.raises(ValueError, match="even number of arguments"):
693-
f.make_map(literal("x"), literal(1), literal("y"))
693+
def test_make_map_with_column_data():
694+
ctx = SessionContext()
695+
batch = pa.RecordBatch.from_arrays(
696+
[
697+
pa.array(["k1", "k2", "k3"]),
698+
pa.array([10, 20, 30]),
699+
],
700+
names=["keys", "vals"],
701+
)
702+
df = ctx.create_dataframe([[batch]])
703+
704+
m = f.make_map(keys=[column("keys")], values=[column("vals")])
705+
result = df.select(f.map_keys(m).alias("k")).collect()[0].column(0)
706+
for i, expected in enumerate(["k1", "k2", "k3"]):
707+
assert result[i].as_py() == [expected]
708+
709+
result = df.select(f.map_values(m).alias("v")).collect()[0].column(0)
710+
for i, expected in enumerate([10, 20, 30]):
711+
assert result[i].as_py() == [expected]
694712

695713

696714
def test_map_keys():
697715
ctx = SessionContext()
698716
batch = pa.RecordBatch.from_arrays([pa.array([1])], names=["a"])
699717
df = ctx.create_dataframe([[batch]])
700718

701-
m = f.make_map(literal("x"), literal(1), literal("y"), literal(2))
719+
m = f.make_map({"x": 1, "y": 2})
702720
result = df.select(f.map_keys(m).alias("keys")).collect()[0].column(0)
703721
assert result[0].as_py() == ["x", "y"]
704722

@@ -708,7 +726,7 @@ def test_map_values():
708726
batch = pa.RecordBatch.from_arrays([pa.array([1])], names=["a"])
709727
df = ctx.create_dataframe([[batch]])
710728

711-
m = f.make_map(literal("x"), literal(1), literal("y"), literal(2))
729+
m = f.make_map({"x": 1, "y": 2})
712730
result = df.select(f.map_values(m).alias("vals")).collect()[0].column(0)
713731
assert result[0].as_py() == [1, 2]
714732

@@ -718,7 +736,7 @@ def test_map_extract():
718736
batch = pa.RecordBatch.from_arrays([pa.array([1])], names=["a"])
719737
df = ctx.create_dataframe([[batch]])
720738

721-
m = f.make_map(literal("x"), literal(1), literal("y"), literal(2))
739+
m = f.make_map({"x": 1, "y": 2})
722740
result = (
723741
df.select(f.map_extract(m, literal("x")).alias("val")).collect()[0].column(0)
724742
)
@@ -730,7 +748,7 @@ def test_map_extract_missing_key():
730748
batch = pa.RecordBatch.from_arrays([pa.array([1])], names=["a"])
731749
df = ctx.create_dataframe([[batch]])
732750

733-
m = f.make_map(literal("x"), literal(1))
751+
m = f.make_map({"x": 1})
734752
result = (
735753
df.select(f.map_extract(m, literal("z")).alias("val")).collect()[0].column(0)
736754
)
@@ -742,7 +760,7 @@ def test_map_entries():
742760
batch = pa.RecordBatch.from_arrays([pa.array([1])], names=["a"])
743761
df = ctx.create_dataframe([[batch]])
744762

745-
m = f.make_map(literal("x"), literal(1), literal("y"), literal(2))
763+
m = f.make_map({"x": 1, "y": 2})
746764
result = df.select(f.map_entries(m).alias("entries")).collect()[0].column(0)
747765
assert result[0].as_py() == [
748766
{"key": "x", "value": 1},
@@ -755,34 +773,13 @@ def test_element_at():
755773
batch = pa.RecordBatch.from_arrays([pa.array([1])], names=["a"])
756774
df = ctx.create_dataframe([[batch]])
757775

758-
m = f.make_map(literal("a"), literal(10), literal("b"), literal(20))
776+
m = f.make_map({"a": 10, "b": 20})
759777
result = (
760778
df.select(f.element_at(m, literal("b")).alias("val")).collect()[0].column(0)
761779
)
762780
assert result[0].as_py() == [20]
763781

764782

765-
def test_map_functions_with_column_data():
766-
ctx = SessionContext()
767-
batch = pa.RecordBatch.from_arrays(
768-
[
769-
pa.array(["k1", "k2", "k3"]),
770-
pa.array([10, 20, 30]),
771-
],
772-
names=["keys", "vals"],
773-
)
774-
df = ctx.create_dataframe([[batch]])
775-
776-
m = f.make_map(column("keys"), column("vals"))
777-
result = df.select(f.map_keys(m).alias("k")).collect()[0].column(0)
778-
for i, expected in enumerate(["k1", "k2", "k3"]):
779-
assert result[i].as_py() == [expected]
780-
781-
result = df.select(f.map_values(m).alias("v")).collect()[0].column(0)
782-
for i, expected in enumerate([10, 20, 30]):
783-
assert result[i].as_py() == [expected]
784-
785-
786783
@pytest.mark.parametrize(
787784
("function", "expected_result"),
788785
[

0 commit comments

Comments
 (0)