3737 WindowFrame ,
3838 column ,
3939 literal ,
40+ udf ,
4041)
4142from datafusion import (
4243 col as df_col ,
@@ -3190,6 +3191,13 @@ def test_fill_null_all_null_column(ctx):
31903191 assert result .column (1 ).to_pylist () == ["filled" , "filled" , "filled" ]
31913192
31923193
3194+ @udf ([pa .int64 ()], pa .int64 (), "immutable" )
3195+ def slow_udf (x : pa .Array ) -> pa .Array :
3196+ # This must be longer than the check interval in wait_for_future
3197+ time .sleep (2.0 )
3198+ return x
3199+
3200+
31933201def test_collect_interrupted ():
31943202 """Test that a long-running query can be interrupted with Ctrl-C.
31953203
@@ -3198,50 +3206,7 @@ def test_collect_interrupted():
31983206 """
31993207 # Create a context and a DataFrame with a query that will run for a while
32003208 ctx = SessionContext ()
3201-
3202- # Create a recursive computation that will run for some time
3203- batches = []
3204- for i in range (10 ):
3205- batch = pa .RecordBatch .from_arrays (
3206- [
3207- pa .array (list (range (i * 1000 , (i + 1 ) * 1000 ))),
3208- pa .array ([f"value_{ j } " for j in range (i * 1000 , (i + 1 ) * 1000 )]),
3209- ],
3210- names = ["a" , "b" ],
3211- )
3212- batches .append (batch )
3213-
3214- # Register tables
3215- ctx .register_record_batches ("t1" , [batches ])
3216- ctx .register_record_batches ("t2" , [batches ])
3217-
3218- # Create a large join operation that will take time to process
3219- df = ctx .sql ("""
3220- WITH t1_expanded AS (
3221- SELECT
3222- a,
3223- b,
3224- CAST(a AS DOUBLE) / 1.5 AS c,
3225- CAST(a AS DOUBLE) * CAST(a AS DOUBLE) AS d
3226- FROM t1
3227- CROSS JOIN (SELECT 1 AS dummy FROM t1 LIMIT 5)
3228- ),
3229- t2_expanded AS (
3230- SELECT
3231- a,
3232- b,
3233- CAST(a AS DOUBLE) * 2.5 AS e,
3234- CAST(a AS DOUBLE) * CAST(a AS DOUBLE) * CAST(a AS DOUBLE) AS f
3235- FROM t2
3236- CROSS JOIN (SELECT 1 AS dummy FROM t2 LIMIT 5)
3237- )
3238- SELECT
3239- t1.a, t1.b, t1.c, t1.d,
3240- t2.a AS a2, t2.b AS b2, t2.e, t2.f
3241- FROM t1_expanded t1
3242- JOIN t2_expanded t2 ON t1.a % 100 = t2.a % 100
3243- WHERE t1.a > 100 AND t2.a > 100
3244- """ )
3209+ df = ctx .from_pydict ({"a" : [1 , 2 , 3 ]}).select (slow_udf (column ("a" )))
32453210
32463211 # Flag to track if the query was interrupted
32473212 interrupted = False
@@ -3298,7 +3263,10 @@ def trigger_interrupt():
32983263 except KeyboardInterrupt :
32993264 interrupted = True
33003265 except Exception as e :
3301- interrupt_error = e
3266+ if "KeyboardInterrupt" in str (e ):
3267+ interrupted = True
3268+ else :
3269+ interrupt_error = e
33023270
33033271 # Assert that the query was interrupted properly
33043272 if not interrupted :
@@ -3308,7 +3276,7 @@ def trigger_interrupt():
33083276 interrupt_thread .join (timeout = 1.0 )
33093277
33103278
3311- def test_arrow_c_stream_interrupted (): # noqa: C901 PLR0915
3279+ def test_arrow_c_stream_interrupted (): # noqa: C901
33123280 """__arrow_c_stream__ responds to ``KeyboardInterrupt`` signals.
33133281
33143282 Similar to ``test_collect_interrupted`` this test issues a long running
@@ -3318,49 +3286,7 @@ def test_arrow_c_stream_interrupted(): # noqa: C901 PLR0915
33183286 """
33193287
33203288 ctx = SessionContext ()
3321-
3322- batches = []
3323- for i in range (10 ):
3324- batch = pa .RecordBatch .from_arrays (
3325- [
3326- pa .array (list (range (i * 1000 , (i + 1 ) * 1000 ))),
3327- pa .array ([f"value_{ j } " for j in range (i * 1000 , (i + 1 ) * 1000 )]),
3328- ],
3329- names = ["a" , "b" ],
3330- )
3331- batches .append (batch )
3332-
3333- ctx .register_record_batches ("t1" , [batches ])
3334- ctx .register_record_batches ("t2" , [batches ])
3335-
3336- df = ctx .sql (
3337- """
3338- WITH t1_expanded AS (
3339- SELECT
3340- a,
3341- b,
3342- CAST(a AS DOUBLE) / 1.5 AS c,
3343- CAST(a AS DOUBLE) * CAST(a AS DOUBLE) AS d
3344- FROM t1
3345- CROSS JOIN (SELECT 1 AS dummy FROM t1 LIMIT 5)
3346- ),
3347- t2_expanded AS (
3348- SELECT
3349- a,
3350- b,
3351- CAST(a AS DOUBLE) * 2.5 AS e,
3352- CAST(a AS DOUBLE) * CAST(a AS DOUBLE) * CAST(a AS DOUBLE) AS f
3353- FROM t2
3354- CROSS JOIN (SELECT 1 AS dummy FROM t2 LIMIT 5)
3355- )
3356- SELECT
3357- t1.a, t1.b, t1.c, t1.d,
3358- t2.a AS a2, t2.b AS b2, t2.e, t2.f
3359- FROM t1_expanded t1
3360- JOIN t2_expanded t2 ON t1.a % 100 = t2.a % 100
3361- WHERE t1.a > 100 AND t2.a > 100
3362- """
3363- )
3289+ df = ctx .from_pydict ({"a" : [1 , 2 , 3 ]}).select (slow_udf (column ("a" )))
33643290
33653291 reader = pa .RecordBatchReader .from_stream (df )
33663292
0 commit comments