1818import pyarrow as pa
1919import pytest
2020
21- import bigframes .pandas as bpd
22-
2321from ...utils import assert_series_equal
2422
2523
3230 pytest .param (slice (0 , 2 , None ), id = "default_step_slice" ),
3331 ],
3432)
35- def test_getitem (key ):
33+ @pytest .mark .parametrize (
34+ ("column_name" , "dtype" ),
35+ [
36+ pytest .param ("int_list_col" , pd .ArrowDtype (pa .list_ (pa .int64 ()))),
37+ pytest .param ("bool_list_col" , pd .ArrowDtype (pa .list_ (pa .bool_ ()))),
38+ pytest .param ("float_list_col" , pd .ArrowDtype (pa .list_ (pa .float64 ()))),
39+ pytest .param ("date_list_col" , pd .ArrowDtype (pa .list_ (pa .date32 ()))),
40+ pytest .param ("date_time_list_col" , pd .ArrowDtype (pa .list_ (pa .timestamp ("us" )))),
41+ pytest .param ("numeric_list_col" , pd .ArrowDtype (pa .list_ (pa .decimal128 (38 , 9 )))),
42+ pytest .param ("string_list_col" , pd .ArrowDtype (pa .list_ (pa .string ()))),
43+ ],
44+ )
45+ def test_getitem (key , column_name , dtype , repeated_df , repeated_pandas_df ):
3646 if packaging .version .Version (pd .__version__ ) < packaging .version .Version ("2.2.0" ):
3747 pytest .skip (
3848 "https://pandas.pydata.org/docs/whatsnew/v2.2.0.html#series-list-accessor-for-pyarrow-list-data"
3949 )
40- data = [[1 ], [2 , 3 ], [4 , 5 , 6 ]]
41- s = bpd .Series (data , dtype = pd .ArrowDtype (pa .list_ (pa .int64 ())))
42- pd_s = pd .Series (data , dtype = pd .ArrowDtype (pa .list_ (pa .int64 ())))
4350
44- bf_result = s .list [key ].to_pandas ()
45- pd_result = pd_s .list [key ]
51+ bf_result = repeated_df [ column_name ] .list [key ].to_pandas ()
52+ pd_result = repeated_pandas_df [ column_name ]. astype ( dtype ) .list [key ]
4653
47- assert_series_equal (pd_result , bf_result , check_dtype = False , check_index_type = False )
54+ assert_series_equal (
55+ pd_result ,
56+ bf_result ,
57+ check_dtype = False ,
58+ check_index_type = False ,
59+ check_names = False ,
60+ )
4861
4962
5063@pytest .mark .parametrize (
@@ -60,24 +73,36 @@ def test_getitem(key):
6073 (slice (0 , 2 , 2 ), pytest .raises (NotImplementedError )),
6174 ],
6275)
63- def test_getitem_notsupported (key , expectation ):
64- data = [[1 ], [2 , 3 ], [4 , 5 , 6 ]]
65- s = bpd .Series (data , dtype = pd .ArrowDtype (pa .list_ (pa .int64 ())))
66-
76+ def test_getitem_notsupported (key , expectation , repeated_df ):
6777 with expectation as e :
68- assert s .list [key ] == e
78+ assert repeated_df [ "int_list_col" ] .list [key ] == e
6979
7080
71- def test_len ():
81+ @pytest .mark .parametrize (
82+ ("column_name" , "dtype" ),
83+ [
84+ pytest .param ("int_list_col" , pd .ArrowDtype (pa .list_ (pa .int64 ()))),
85+ pytest .param ("bool_list_col" , pd .ArrowDtype (pa .list_ (pa .bool_ ()))),
86+ pytest .param ("float_list_col" , pd .ArrowDtype (pa .list_ (pa .float64 ()))),
87+ pytest .param ("date_list_col" , pd .ArrowDtype (pa .list_ (pa .date32 ()))),
88+ pytest .param ("date_time_list_col" , pd .ArrowDtype (pa .list_ (pa .timestamp ("us" )))),
89+ pytest .param ("numeric_list_col" , pd .ArrowDtype (pa .list_ (pa .decimal128 (38 , 9 )))),
90+ pytest .param ("string_list_col" , pd .ArrowDtype (pa .list_ (pa .string ()))),
91+ ],
92+ )
93+ def test_len (column_name , dtype , repeated_df , repeated_pandas_df ):
7294 if packaging .version .Version (pd .__version__ ) < packaging .version .Version ("2.2.0" ):
7395 pytest .skip (
7496 "https://pandas.pydata.org/docs/whatsnew/v2.2.0.html#series-list-accessor-for-pyarrow-list-data"
7597 )
76- data = [[], [1 ], [1 , 2 ], [1 , 2 , 3 ]]
77- s = bpd .Series (data , dtype = pd .ArrowDtype (pa .list_ (pa .int64 ())))
78- pd_s = pd .Series (data , dtype = pd .ArrowDtype (pa .list_ (pa .int64 ())))
7998
80- bf_result = s .list .len ().to_pandas ()
81- pd_result = pd_s .list .len ()
99+ bf_result = repeated_df [ column_name ] .list .len ().to_pandas ()
100+ pd_result = repeated_pandas_df [ column_name ]. astype ( dtype ) .list .len ()
82101
83- assert_series_equal (pd_result , bf_result , check_dtype = False , check_index_type = False )
102+ assert_series_equal (
103+ pd_result ,
104+ bf_result ,
105+ check_dtype = False ,
106+ check_index_type = False ,
107+ check_names = False ,
108+ )
0 commit comments