
Commit b56627d

Use an explicit wait in a dataframe query during testing to check for keyboard interrupts
1 parent a922967 commit b56627d

1 file changed: +15 −89 lines

python/tests/test_dataframe.py

Lines changed: 15 additions & 89 deletions
@@ -37,6 +37,7 @@
     WindowFrame,
     column,
     literal,
+    udf,
 )
 from datafusion import (
     col as df_col,
@@ -3190,6 +3191,13 @@ def test_fill_null_all_null_column(ctx):
     assert result.column(1).to_pylist() == ["filled", "filled", "filled"]
 
 
+@udf([pa.int64()], pa.int64(), "immutable")
+def slow_udf(x: pa.Array) -> pa.Array:
+    # This must be longer than the check interval in wait_for_future
+    time.sleep(2.0)
+    return x
+
+
 def test_collect_interrupted():
     """Test that a long-running query can be interrupted with Ctrl-C.
 
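The added slow_udf is the explicit wait referenced in the commit message: its 2-second sleep is intentionally longer than the interval at which wait_for_future checks for pending signals, so a query invoking it stays blocked long enough for an interrupt to be delivered. A stand-alone sketch of that premise, assuming the top-level datafusion exports (SessionContext, column, udf) used elsewhere in the test file; this is not part of the diff:

import time

import pyarrow as pa
from datafusion import SessionContext, column, udf


@udf([pa.int64()], pa.int64(), "immutable")
def slow_udf(x: pa.Array) -> pa.Array:
    # Must be longer than the check interval in wait_for_future
    time.sleep(2.0)
    return x


ctx = SessionContext()
df = ctx.from_pydict({"a": [1, 2, 3]}).select(slow_udf(column("a")))

start = time.monotonic()
df.collect()
# The sleep inside the UDF dominates, so this should report roughly 2 s or more.
print(f"collect() took {time.monotonic() - start:.1f}s")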
@@ -3198,50 +3206,7 @@ def test_collect_interrupted():
     """
     # Create a context and a DataFrame with a query that will run for a while
     ctx = SessionContext()
-
-    # Create a recursive computation that will run for some time
-    batches = []
-    for i in range(10):
-        batch = pa.RecordBatch.from_arrays(
-            [
-                pa.array(list(range(i * 1000, (i + 1) * 1000))),
-                pa.array([f"value_{j}" for j in range(i * 1000, (i + 1) * 1000)]),
-            ],
-            names=["a", "b"],
-        )
-        batches.append(batch)
-
-    # Register tables
-    ctx.register_record_batches("t1", [batches])
-    ctx.register_record_batches("t2", [batches])
-
-    # Create a large join operation that will take time to process
-    df = ctx.sql("""
-        WITH t1_expanded AS (
-            SELECT
-                a,
-                b,
-                CAST(a AS DOUBLE) / 1.5 AS c,
-                CAST(a AS DOUBLE) * CAST(a AS DOUBLE) AS d
-            FROM t1
-            CROSS JOIN (SELECT 1 AS dummy FROM t1 LIMIT 5)
-        ),
-        t2_expanded AS (
-            SELECT
-                a,
-                b,
-                CAST(a AS DOUBLE) * 2.5 AS e,
-                CAST(a AS DOUBLE) * CAST(a AS DOUBLE) * CAST(a AS DOUBLE) AS f
-            FROM t2
-            CROSS JOIN (SELECT 1 AS dummy FROM t2 LIMIT 5)
-        )
-        SELECT
-            t1.a, t1.b, t1.c, t1.d,
-            t2.a AS a2, t2.b AS b2, t2.e, t2.f
-        FROM t1_expanded t1
-        JOIN t2_expanded t2 ON t1.a % 100 = t2.a % 100
-        WHERE t1.a > 100 AND t2.a > 100
-    """)
+    df = ctx.from_pydict({"a": [1, 2, 3]}).select(slow_udf(column("a")))
 
     # Flag to track if the query was interrupted
     interrupted = False
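With the heavy self-join replaced by the three-row frame above, the interrupt window comes entirely from slow_udf. Outside pytest, the behaviour the test relies on can be approximated with a sketch like the following; _thread.interrupt_main() stands in for pressing Ctrl-C, and depending on where the signal is observed it may also surface wrapped in a generic error, which the later hunk handles:

import _thread
import threading
import time

import pyarrow as pa
from datafusion import SessionContext, column, udf


@udf([pa.int64()], pa.int64(), "immutable")
def slow_udf(x: pa.Array) -> pa.Array:
    time.sleep(2.0)  # longer than wait_for_future's signal-check interval
    return x


ctx = SessionContext()
df = ctx.from_pydict({"a": [1, 2, 3]}).select(slow_udf(column("a")))

# Simulate Ctrl-C shortly after collect() starts blocking in the UDF.
timer = threading.Timer(0.5, _thread.interrupt_main)
timer.start()
try:
    df.collect()
except KeyboardInterrupt:
    print("query interrupted")
finally:
    timer.cancel()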
@@ -3298,7 +3263,10 @@ def trigger_interrupt():
     except KeyboardInterrupt:
         interrupted = True
     except Exception as e:
-        interrupt_error = e
+        if "KeyboardInterrupt" in str(e):
+            interrupted = True
+        else:
+            interrupt_error = e
 
     # Assert that the query was interrupted properly
     if not interrupted:
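The broadened except clause covers the two forms the interrupt can take: the pending KeyboardInterrupt may be re-raised directly, or it may be converted into an ordinary exception whose message mentions it. A minimal, self-contained illustration of the check; run_query is a hypothetical stand-in for the blocking df.collect() call, rigged here to produce the wrapped form:

def run_query() -> None:
    # Hypothetical stand-in for df.collect(); simulates an engine error that
    # wraps the pending interrupt instead of re-raising KeyboardInterrupt.
    raise RuntimeError("External error: KeyboardInterrupt")


interrupted = False
interrupt_error = None

try:
    run_query()
except KeyboardInterrupt:
    interrupted = True
except Exception as e:
    if "KeyboardInterrupt" in str(e):
        interrupted = True
    else:
        interrupt_error = e

assert interrupted, f"query was not interrupted: {interrupt_error}"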
@@ -3308,7 +3276,7 @@ def trigger_interrupt():
     interrupt_thread.join(timeout=1.0)
 
 
-def test_arrow_c_stream_interrupted():  # noqa: C901 PLR0915
+def test_arrow_c_stream_interrupted():  # noqa: C901
     """__arrow_c_stream__ responds to ``KeyboardInterrupt`` signals.
 
     Similar to ``test_collect_interrupted`` this test issues a long running
@@ -3318,49 +3286,7 @@ def test_arrow_c_stream_interrupted():  # noqa: C901 PLR0915
     """
 
     ctx = SessionContext()
-
-    batches = []
-    for i in range(10):
-        batch = pa.RecordBatch.from_arrays(
-            [
-                pa.array(list(range(i * 1000, (i + 1) * 1000))),
-                pa.array([f"value_{j}" for j in range(i * 1000, (i + 1) * 1000)]),
-            ],
-            names=["a", "b"],
-        )
-        batches.append(batch)
-
-    ctx.register_record_batches("t1", [batches])
-    ctx.register_record_batches("t2", [batches])
-
-    df = ctx.sql(
-        """
-        WITH t1_expanded AS (
-            SELECT
-                a,
-                b,
-                CAST(a AS DOUBLE) / 1.5 AS c,
-                CAST(a AS DOUBLE) * CAST(a AS DOUBLE) AS d
-            FROM t1
-            CROSS JOIN (SELECT 1 AS dummy FROM t1 LIMIT 5)
-        ),
-        t2_expanded AS (
-            SELECT
-                a,
-                b,
-                CAST(a AS DOUBLE) * 2.5 AS e,
-                CAST(a AS DOUBLE) * CAST(a AS DOUBLE) * CAST(a AS DOUBLE) AS f
-            FROM t2
-            CROSS JOIN (SELECT 1 AS dummy FROM t2 LIMIT 5)
-        )
-        SELECT
-            t1.a, t1.b, t1.c, t1.d,
-            t2.a AS a2, t2.b AS b2, t2.e, t2.f
-        FROM t1_expanded t1
-        JOIN t2_expanded t2 ON t1.a % 100 = t2.a % 100
-        WHERE t1.a > 100 AND t2.a > 100
-        """
-    )
+    df = ctx.from_pydict({"a": [1, 2, 3]}).select(slow_udf(column("a")))
 
     reader = pa.RecordBatchReader.from_stream(df)
 
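The __arrow_c_stream__ test now uses the same slimmed-down query. A sketch of that path outside the test harness, under the same assumptions as above (slow_udf provides the blocking window and the interrupt is simulated from a timer thread):

import _thread
import threading
import time

import pyarrow as pa
from datafusion import SessionContext, column, udf


@udf([pa.int64()], pa.int64(), "immutable")
def slow_udf(x: pa.Array) -> pa.Array:
    time.sleep(2.0)
    return x


ctx = SessionContext()
df = ctx.from_pydict({"a": [1, 2, 3]}).select(slow_udf(column("a")))

# pyarrow drives execution through the DataFrame's __arrow_c_stream__ export.
reader = pa.RecordBatchReader.from_stream(df)

threading.Timer(0.5, _thread.interrupt_main).start()
try:
    for batch in reader:  # blocks inside slow_udf until the interrupt is seen
        print(batch.num_rows)
except KeyboardInterrupt:
    print("stream interrupted")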