@@ -668,6 +668,121 @@ def test_array_function_obj_tests(stmt, py_expr):
668668 assert a == b
669669
670670
def test_make_map():
    """Check that ``make_map`` builds a map from alternating key/value literals."""
    session = SessionContext()
    # One row is sufficient: every argument passed to make_map is a literal.
    seed_batch = pa.RecordBatch.from_arrays([pa.array([1])], names=["a"])
    frame = session.create_dataframe([[seed_batch]])

    map_expr = f.make_map(literal("x"), literal(1), literal("y"), literal(2))
    batches = frame.select(map_expr.alias("map")).collect()
    map_column = batches[0].column(0)

    # The map scalar is expected to convert to a list of (key, value) pairs.
    assert map_column[0].as_py() == [("x", 1), ("y", 2)]
689+
690+
def test_make_map_odd_args():
    """Check that ``make_map`` raises ``ValueError`` for an odd argument count."""
    with pytest.raises(ValueError, match="even number of arguments"):
        # Three arguments: the trailing key "y" has no matching value.
        f.make_map(*[literal(item) for item in ("x", 1, "y")])
694+
695+
def test_map_keys():
    """Check that ``map_keys`` returns the keys of a literal map as a list."""
    session = SessionContext()
    seed_batch = pa.RecordBatch.from_arrays([pa.array([1])], names=["a"])
    frame = session.create_dataframe([[seed_batch]])

    map_expr = f.make_map(literal("x"), literal(1), literal("y"), literal(2))
    projected = frame.select(f.map_keys(map_expr).alias("keys"))
    key_column = projected.collect()[0].column(0)

    assert key_column[0].as_py() == ["x", "y"]
704+
705+
def test_map_values():
    """Check that ``map_values`` returns the values of a literal map as a list."""
    session = SessionContext()
    seed_batch = pa.RecordBatch.from_arrays([pa.array([1])], names=["a"])
    frame = session.create_dataframe([[seed_batch]])

    map_expr = f.make_map(literal("x"), literal(1), literal("y"), literal(2))
    projected = frame.select(f.map_values(map_expr).alias("vals"))
    value_column = projected.collect()[0].column(0)

    assert value_column[0].as_py() == [1, 2]
714+
715+
def test_map_extract():
    """Check that ``map_extract`` returns the value stored under an existing key."""
    session = SessionContext()
    seed_batch = pa.RecordBatch.from_arrays([pa.array([1])], names=["a"])
    frame = session.create_dataframe([[seed_batch]])

    map_expr = f.make_map(literal("x"), literal(1), literal("y"), literal(2))
    lookup = f.map_extract(map_expr, literal("x")).alias("val")
    extracted = frame.select(lookup).collect()[0].column(0)

    # The looked-up value is expected wrapped in a single-element list.
    assert extracted[0].as_py() == [1]
726+
727+
def test_map_extract_missing_key():
    """Check that ``map_extract`` yields ``[None]`` for a key absent from the map."""
    session = SessionContext()
    seed_batch = pa.RecordBatch.from_arrays([pa.array([1])], names=["a"])
    frame = session.create_dataframe([[seed_batch]])

    map_expr = f.make_map(literal("x"), literal(1))
    # "z" is not among the map's keys.
    lookup = f.map_extract(map_expr, literal("z")).alias("val")
    extracted = frame.select(lookup).collect()[0].column(0)

    assert extracted[0].as_py() == [None]
738+
739+
def test_map_entries():
    """Check that ``map_entries`` returns one key/value record per map entry."""
    session = SessionContext()
    seed_batch = pa.RecordBatch.from_arrays([pa.array([1])], names=["a"])
    frame = session.create_dataframe([[seed_batch]])

    map_expr = f.make_map(literal("x"), literal(1), literal("y"), literal(2))
    projected = frame.select(f.map_entries(map_expr).alias("entries"))
    entry_column = projected.collect()[0].column(0)

    expected_entries = [
        {"key": "x", "value": 1},
        {"key": "y", "value": 2},
    ]
    assert entry_column[0].as_py() == expected_entries
751+
752+
def test_element_at():
    """Check that ``element_at`` returns the value stored under the given map key."""
    session = SessionContext()
    seed_batch = pa.RecordBatch.from_arrays([pa.array([1])], names=["a"])
    frame = session.create_dataframe([[seed_batch]])

    map_expr = f.make_map(literal("a"), literal(10), literal("b"), literal(20))
    lookup = f.element_at(map_expr, literal("b")).alias("val")
    extracted = frame.select(lookup).collect()[0].column(0)

    # The looked-up value is expected wrapped in a single-element list.
    assert extracted[0].as_py() == [20]
763+
764+
def test_map_functions_with_column_data():
    """Check ``map_keys``/``map_values`` on a map built from column expressions.

    Each row contributes one key and one value, so every row is expected to
    produce a single-entry map.
    """
    expected_keys = ["k1", "k2", "k3"]
    expected_vals = [10, 20, 30]

    session = SessionContext()
    source_batch = pa.RecordBatch.from_arrays(
        [pa.array(["k1", "k2", "k3"]), pa.array([10, 20, 30])],
        names=["keys", "vals"],
    )
    frame = session.create_dataframe([[source_batch]])

    map_expr = f.make_map(column("keys"), column("vals"))

    key_column = frame.select(f.map_keys(map_expr).alias("k")).collect()[0].column(0)
    actual_keys = [key_column[row].as_py() for row, _ in enumerate(expected_keys)]
    assert actual_keys == [[key] for key in expected_keys]

    val_column = frame.select(f.map_values(map_expr).alias("v")).collect()[0].column(0)
    actual_vals = [val_column[row].as_py() for row, _ in enumerate(expected_vals)]
    assert actual_vals == [[val] for val in expected_vals]
784+
785+
671786@pytest .mark .parametrize (
672787 ("function" , "expected_result" ),
673788 [
0 commit comments