@@ -1937,6 +1937,100 @@ def float_parser(row):
19371937 )
19381938
19391939
@pytest.mark.flaky(retries=2, delay=120)
def test_df_apply_axis_1_args(session, scalars_dfs):
    """Row-wise ``DataFrame.apply`` forwarding extra positional ``args`` to a remote function.

    Deploys a three-argument remote function, then checks that:
      * applying it to a DataFrame with the wrong column count raises ``ValueError``;
      * applying it to a DataFrame with mismatched column dtypes raises ``ValueError``;
      * on a compatible DataFrame the BigFrames result matches the pandas equivalent.
    """
    columns = ["int64_col", "int64_too"]
    scalars_df, scalars_pandas_df = scalars_dfs

    try:

        def the_sum(s1, s2, x):
            return s1 + s2 + x

        the_sum_mf = session.remote_function(
            input_types=[int, int, int],
            output_type=int,
            reuse=False,
            cloud_function_service_account="default",
        )(the_sum)

        extra_args = (1,)

        # Three columns supplied where the function (minus its args) wants two.
        expected_count_error = "^Column count mismatch: BigFrames BigQuery function expected 2 columns from DataFrame but received 3\\.$"
        with pytest.raises(ValueError, match=expected_count_error):
            scalars_df[columns + ["float64_col"]].apply(
                the_sum_mf, axis=1, args=extra_args
            )

        # Right column count, wrong dtype on one column.
        expected_dtype_error = "^Data type mismatch: BigFrames BigQuery function takes arguments of types .* but DataFrame dtypes are .*"
        with pytest.raises(ValueError, match=expected_dtype_error):
            scalars_df[columns].assign(
                int64_col=lambda df: df["int64_col"].astype("Float64")
            ).apply(the_sum_mf, axis=1, args=extra_args)

        # Happy path: builtin sum(row, 1) on the pandas side equals
        # the_sum(s1, s2, 1) on the BigFrames side.
        bf_frame = scalars_df[columns].dropna()
        bf_result = bf_frame.apply(the_sum_mf, axis=1, args=extra_args).to_pandas()
        pd_frame = scalars_pandas_df[columns].dropna()
        pd_result = pd_frame.apply(sum, axis=1, args=extra_args)

        pandas.testing.assert_series_equal(pd_result, bf_result, check_dtype=False)

    finally:
        # clean up the gcp assets created for the remote function.
        cleanup_function_assets(the_sum_mf, session.bqclient, ignore_failures=False)
1989+
@pytest.mark.flaky(retries=2, delay=120)
def test_df_apply_axis_1_series_args(session, scalars_dfs):
    """Row-processor remote function (Series input) combined with extra ``args``.

    Deploys a remote function whose first parameter is a row Series and whose
    remaining parameters arrive via ``apply(..., args=...)``; verifies parity
    with pandas both for the deployed function and for a handle re-read via
    ``read_gbq_function(..., is_row_processor=True)``.
    """
    columns = ["int64_col", "float64_col"]
    scalars_df, scalars_pandas_df = scalars_dfs

    try:

        @session.remote_function(
            input_types=[bigframes.series.Series, float, str, bool],
            output_type=list[str],
            reuse=False,
            cloud_function_service_account="default",
        )
        def foo_list(x, y0: float, y1, y2) -> list[str]:
            if y2:
                return [str(x["int64_col"]), str(y0), str(y1), str(y2)]
            return [str(x["float64_col"])]

        first_args = (12.34, "hello world", True)
        bf_result = (
            scalars_df[columns].apply(foo_list, axis=1, args=first_args).to_pandas()
        )
        pd_result = scalars_pandas_df[columns].apply(foo_list, axis=1, args=first_args)

        # Ignore any dtype difference.
        pandas.testing.assert_series_equal(bf_result, pd_result, check_dtype=False)

        # Re-read the deployed function as a row processor and repeat with
        # different args (including y2=False to exercise the other branch).
        second_args = (43.21, "xxx3yyy", False)
        foo_list_ref = session.read_gbq_function(
            foo_list.bigframes_bigquery_function, is_row_processor=True
        )
        bf_result = (
            scalars_df[columns].apply(foo_list_ref, axis=1, args=second_args).to_pandas()
        )
        pd_result = scalars_pandas_df[columns].apply(foo_list, axis=1, args=second_args)

        # Ignore any dtype difference.
        pandas.testing.assert_series_equal(bf_result, pd_result, check_dtype=False)

    finally:
        # Clean up the gcp assets created for the remote function.
        cleanup_function_assets(foo_list, session.bqclient, ignore_failures=False)
2033+
19402034@pytest .mark .parametrize (
19412035 ("memory_mib_args" , "expected_memory" ),
19422036 [
0 commit comments