Skip to content

Commit cce1305

Browse files
timsaucerclaude
andcommitted
Add unit tests for map functions
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 14af180 commit cce1305

File tree

1 file changed

+115
-0
lines changed

1 file changed

+115
-0
lines changed

python/tests/test_functions.py

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -668,6 +668,121 @@ def test_array_function_obj_tests(stmt, py_expr):
668668
assert a == b
669669

670670

671+
def test_make_map():
672+
ctx = SessionContext()
673+
batch = pa.RecordBatch.from_arrays([pa.array([1])], names=["a"])
674+
df = ctx.create_dataframe([[batch]])
675+
676+
result = (
677+
df.select(
678+
f.make_map(
679+
literal("x"),
680+
literal(1),
681+
literal("y"),
682+
literal(2),
683+
).alias("map")
684+
)
685+
.collect()[0]
686+
.column(0)
687+
)
688+
assert result[0].as_py() == [("x", 1), ("y", 2)]
689+
690+
691+
def test_make_map_odd_args():
692+
with pytest.raises(ValueError, match="even number of arguments"):
693+
f.make_map(literal("x"), literal(1), literal("y"))
694+
695+
696+
def test_map_keys():
697+
ctx = SessionContext()
698+
batch = pa.RecordBatch.from_arrays([pa.array([1])], names=["a"])
699+
df = ctx.create_dataframe([[batch]])
700+
701+
m = f.make_map(literal("x"), literal(1), literal("y"), literal(2))
702+
result = df.select(f.map_keys(m).alias("keys")).collect()[0].column(0)
703+
assert result[0].as_py() == ["x", "y"]
704+
705+
706+
def test_map_values():
707+
ctx = SessionContext()
708+
batch = pa.RecordBatch.from_arrays([pa.array([1])], names=["a"])
709+
df = ctx.create_dataframe([[batch]])
710+
711+
m = f.make_map(literal("x"), literal(1), literal("y"), literal(2))
712+
result = df.select(f.map_values(m).alias("vals")).collect()[0].column(0)
713+
assert result[0].as_py() == [1, 2]
714+
715+
716+
def test_map_extract():
717+
ctx = SessionContext()
718+
batch = pa.RecordBatch.from_arrays([pa.array([1])], names=["a"])
719+
df = ctx.create_dataframe([[batch]])
720+
721+
m = f.make_map(literal("x"), literal(1), literal("y"), literal(2))
722+
result = (
723+
df.select(f.map_extract(m, literal("x")).alias("val")).collect()[0].column(0)
724+
)
725+
assert result[0].as_py() == [1]
726+
727+
728+
def test_map_extract_missing_key():
729+
ctx = SessionContext()
730+
batch = pa.RecordBatch.from_arrays([pa.array([1])], names=["a"])
731+
df = ctx.create_dataframe([[batch]])
732+
733+
m = f.make_map(literal("x"), literal(1))
734+
result = (
735+
df.select(f.map_extract(m, literal("z")).alias("val")).collect()[0].column(0)
736+
)
737+
assert result[0].as_py() == [None]
738+
739+
740+
def test_map_entries():
741+
ctx = SessionContext()
742+
batch = pa.RecordBatch.from_arrays([pa.array([1])], names=["a"])
743+
df = ctx.create_dataframe([[batch]])
744+
745+
m = f.make_map(literal("x"), literal(1), literal("y"), literal(2))
746+
result = df.select(f.map_entries(m).alias("entries")).collect()[0].column(0)
747+
assert result[0].as_py() == [
748+
{"key": "x", "value": 1},
749+
{"key": "y", "value": 2},
750+
]
751+
752+
753+
def test_element_at():
754+
ctx = SessionContext()
755+
batch = pa.RecordBatch.from_arrays([pa.array([1])], names=["a"])
756+
df = ctx.create_dataframe([[batch]])
757+
758+
m = f.make_map(literal("a"), literal(10), literal("b"), literal(20))
759+
result = (
760+
df.select(f.element_at(m, literal("b")).alias("val")).collect()[0].column(0)
761+
)
762+
assert result[0].as_py() == [20]
763+
764+
765+
def test_map_functions_with_column_data():
766+
ctx = SessionContext()
767+
batch = pa.RecordBatch.from_arrays(
768+
[
769+
pa.array(["k1", "k2", "k3"]),
770+
pa.array([10, 20, 30]),
771+
],
772+
names=["keys", "vals"],
773+
)
774+
df = ctx.create_dataframe([[batch]])
775+
776+
m = f.make_map(column("keys"), column("vals"))
777+
result = df.select(f.map_keys(m).alias("k")).collect()[0].column(0)
778+
for i, expected in enumerate(["k1", "k2", "k3"]):
779+
assert result[i].as_py() == [expected]
780+
781+
result = df.select(f.map_values(m).alias("v")).collect()[0].column(0)
782+
for i, expected in enumerate([10, 20, 30]):
783+
assert result[i].as_py() == [expected]
784+
785+
671786
@pytest.mark.parametrize(
672787
("function", "expected_result"),
673788
[

0 commit comments

Comments
 (0)