Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,13 @@ package org.apache.spark.sql.connect.service

import java.nio.charset.StandardCharsets
import java.nio.file.Files
import java.util.concurrent.{TimeoutException, TimeUnit}

import scala.collection.mutable
import scala.jdk.CollectionConverters._
import scala.sys.process.Process
import scala.util.Random
import scala.util.control.NonFatal

import com.google.common.collect.Lists
import org.scalatest.time.SpanSugar._
Expand All @@ -37,8 +39,10 @@ import org.apache.spark.sql.connect.common.InvalidPlanInput
import org.apache.spark.sql.connect.config.Connect
import org.apache.spark.sql.connect.planner.{PythonStreamingQueryListener, SparkConnectPlanner, StreamingForeachBatchHelper}
import org.apache.spark.sql.connect.planner.StreamingForeachBatchHelper.RunnerCleaner
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.pipelines.graph.{DataflowGraph, PipelineUpdateContextImpl}
import org.apache.spark.sql.pipelines.logging.PipelineEvent
import org.apache.spark.sql.streaming.StreamingQueryListener
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.util.ArrayImplicits._

Expand Down Expand Up @@ -228,15 +232,99 @@ class SparkConnectSessionHolderSuite extends SharedSparkSession {
}
}

test("python foreachBatch process: process terminates after query is stopped") {
// scalastyle:off assume
assume(IntegratedUDFTestUtils.shouldTestPandasUDFs)
assume(PythonTestDepsChecker.isConnectDepsAvailable)
// scalastyle:on assume
// Same semantics as SparkFunSuite.retry, but prints the retry events to stdout so they
// appear in the GitHub Actions job log. SparkFunSuite.retry uses log4j, which in our test
// setup only writes to target/unit-tests.log (surfaced as an artifact, not in the live log).
// TODO(SPARK-56586): consolidate with SparkFunSuite.retry once that helper supports
// console-visible retry notices.
private def retryWithVisibleLog(maxAttempts: Int)(body: => Unit): Unit = {
  // Tail-recursive attempt loop: run `body`; on a non-fatal failure before the last
  // attempt, log to the console, reset the fixture via the suite hooks, and recurse.
  // Fatal errors (OOM, InterruptedException, ...) propagate immediately from the `try`.
  @scala.annotation.tailrec
  def attemptOnce(attempt: Int): Unit = {
    val failure =
      try {
        body
        None
      } catch {
        case NonFatal(t) => Some(t)
      }
    failure match {
      case None => ()
      case Some(t) if attempt >= maxAttempts =>
        // Out of retries: surface the last failure unchanged.
        throw t
      case Some(t) =>
        // scalastyle:off println
        println(
          s"===== Attempt $attempt/$maxAttempts failed " +
          s"(${t.getClass.getSimpleName}: ${t.getMessage}); retrying =====")
        // scalastyle:on println
        // Re-run the suite's per-test teardown/setup so the next attempt starts clean.
        afterEach()
        beforeEach()
        attemptOnce(attempt + 1)
    }
  }
  attemptOnce(1)
}

/**
 * Runs `body` on a fresh daemon thread and waits up to `timeoutMillis` for it to finish.
 *
 * If the worker finishes in time, any throwable it raised is rethrown on the calling
 * thread. If it does not finish, this method: (1) prints the worker's stack trace to
 * stdout for post-mortem diagnostics, (2) invokes `onTimeout` to release whatever
 * resource the worker is blocked on, (3) interrupts the worker, (4) waits a 30s grace
 * period for the worker's own cleanup to run, then (5) throws a TimeoutException —
 * with the worker's own failure attached as the cause if the body actually completed
 * (with an error) during the grace window.
 *
 * @param timeoutMillis maximum time to wait for `body` to complete normally
 * @param onTimeout best-effort resource-release hook invoked only on timeout
 * @param body the test body to execute on the worker thread
 */
private def awaitTestBodyInNewThread(timeoutMillis: Long, onTimeout: () => Unit)(
body: => Unit): Unit = {
// @volatile: written by the worker thread, read by the test thread after join().
@volatile var error: Throwable = null
val runnable: Runnable = () => {
try {
body
} catch {
// Capture everything (including fatal errors) so the test thread can rethrow it.
case t: Throwable => error = t
}
}
val worker = new Thread(runnable, s"${getClass.getSimpleName}-testBody-worker")
// Daemon so a permanently stuck worker cannot keep the JVM alive after the suite ends.
worker.setDaemon(true)
worker.start()
worker.join(timeoutMillis)
if (worker.isAlive) {
// Capture the worker's stack so post-mortem diagnostics can identify which leaked
// thread belongs to which attempt without a separate jstack.
// scalastyle:off println
println(
s"===== Test body did not complete within $timeoutMillis ms " +
s"(thread=${worker.getName}, state=${worker.getState}); stack trace follows =====")
worker.getStackTrace.foreach(frame => println(s"  at $frame"))
// scalastyle:on println
// Best-effort: release any resource the worker is blocked on so it can unwind its own
// finally and stop holding global state (SparkConnectService, listeners, ...).
onTimeout()
// Also interrupt the worker so any interruptible blocking call (e.g. the Thread.join
// inside StreamExecution.interruptAndAwaitExecutionThreadTermination, which fires when
// spark.sql.streaming.stopTimeout is 0/infinite) wakes up. onTimeout() handles socket
// hangs; this handles the join/wait family.
worker.interrupt()
// Give the now-unblocked worker time to finish its cleanup before we declare defeat.
// 30s covers the body's 4s Thread.sleep plus SparkConnectService.stop().
val gracePeriodMs = 30.seconds.toMillis
worker.join(gracePeriodMs)
val te = new TimeoutException(
s"Test body did not complete within $timeoutMillis ms " +
s"(after a $gracePeriodMs ms post-cleanup grace period)")
// If the body actually finished during the grace window, surface the original failure
// as the cause so a slow assertion failure is not misreported as a pure hang.
if (!worker.isAlive && error != null) te.initCause(error)
throw te
}
// Normal completion path: rethrow whatever the worker captured.
if (error != null) throw error
}

private def runPythonForeachBatchTerminationTestBody(sessionHolder: SessionHolder): Unit = {
// Suffix query names so a retry after a timed-out attempt does not collide with a leaked
// query from the previous attempt (the leaked thread can still hold the old query name in
// spark.streams.active).
val suffix = s"_${System.nanoTime()}"
val q1Name = s"foreachBatch_termination_test_q1$suffix"
val q2Name = s"foreachBatch_termination_test_q2$suffix"

// Snapshot listeners before this attempt registers anything so we can scope cleanup and
// assertions to listeners we added -- even if a previous timed-out attempt leaked a worker
// whose own finally is racing with us.
val baselineListeners = spark.streams.listListeners().toSet
var capturedServer: AnyRef = null
var ourNewListeners = Set.empty[StreamingQueryListener]

val sessionHolder = SparkConnectTestUtils.createDummySessionHolder(spark)
try {
SparkConnectService.start(spark.sparkContext)
// Identity-check the server in `finally`: a leaked finally from a previous attempt's
// worker thread must not tear down a service belonging to a later attempt.
capturedServer = SparkConnectService.server

val pythonFn = dummyPythonFunction(sessionHolder)(streamingForeachBatchFunction)
val (fn1, cleaner1) =
Expand All @@ -249,7 +337,7 @@ class SparkConnectSessionHolderSuite extends SharedSparkSession {
.load()
.writeStream
.format("memory")
.queryName("foreachBatch_termination_test_q1")
.queryName(q1Name)
.foreachBatch(fn1)
.start()

Expand All @@ -258,7 +346,7 @@ class SparkConnectSessionHolderSuite extends SharedSparkSession {
.load()
.writeStream
.format("memory")
.queryName("foreachBatch_termination_test_q2")
.queryName(q2Name)
.foreachBatch(fn2)
.start()

Expand All @@ -267,6 +355,10 @@ class SparkConnectSessionHolderSuite extends SharedSparkSession {
sessionHolder.streamingForeachBatchRunnerCleanerCache
.registerCleanerForQuery(query2, cleaner2)

// The first registerCleanerForQuery lazily registers the cleaner listener. Capture the
// listeners we added so finally only removes ours, not a concurrent attempt's.
ourNewListeners = spark.streams.listListeners().toSet -- baselineListeners

val (runner1, runner2) =
(cleaner1.asInstanceOf[RunnerCleaner].runner, cleaner2.asInstanceOf[RunnerCleaner].runner)

Expand All @@ -288,14 +380,73 @@ class SparkConnectSessionHolderSuite extends SharedSparkSession {
assert(runner2.isWorkerStopped().get)
}

assert(spark.streams.active.isEmpty) // no running query
assert(spark.streams.listListeners().length == 1) // only process termination listener
// Only check this attempt's queries stopped (a previous timed-out attempt may have
// leaked queries into spark.streams.active that we cannot synchronously clean up).
assert(!spark.streams.active.exists(q => q.name == q1Name || q.name == q2Name))
// Attempt-scoped variant of the original `listListeners().length == 1` assertion:
// exactly one new listener (the cleaner listener) should have been registered by this
// attempt, regardless of any listeners a leaked previous attempt may still hold.
assert(
ourNewListeners.size == 1,
s"expected exactly 1 new listener registered by this attempt, " +
s"got ${ourNewListeners.size}")
} finally {
SparkConnectService.stop()
// Wait for things to calm down.
Thread.sleep(4.seconds.toMillis)
// remove process termination listener
spark.streams.listListeners().foreach(spark.streams.removeListener)
// Only stop the service if it is still the one this attempt started; otherwise a
// leaked finally from a previous attempt would shut down the live service belonging to
// whichever attempt is currently running.
if (capturedServer != null && (SparkConnectService.server eq capturedServer)) {
SparkConnectService.stop()
// Wait for things to calm down.
Thread.sleep(4.seconds.toMillis)
}
// Remove only the listeners this attempt registered; never touch a concurrent
// attempt's process-termination listener.
ourNewListeners.foreach(spark.streams.removeListener)
}
}

test("python foreachBatch process: process terminates after query is stopped") {
  // scalastyle:off assume
  assume(IntegratedUDFTestUtils.shouldTestPandasUDFs)
  assume(PythonTestDepsChecker.isConnectDepsAvailable)
  // scalastyle:on assume

  // Cap query.stop() at 30s instead of the default stopTimeout of 0 (wait forever), so a
  // stuck streaming execution thread cannot turn this test into an unkillable hang. 30s is
  // well above the body's normal ~5s runtime yet small enough that the retry wrapper below
  // can still recover inside its 2-minute per-attempt budget.
  withSQLConf(SQLConf.STREAMING_STOP_TIMEOUT.key -> "30000") {
    retryWithVisibleLog(maxAttempts = 3) {
      // The holder is created out here (not inside the body) so the timeout hook can reach
      // into it. The body itself runs on a separate daemon thread; if it hangs in a
      // non-interruptible socket read, closing the cleaner cache below closes the Python
      // worker sockets, unblocking dataIn.readInt so the leaked thread can run its own
      // finally (stop SparkConnectService, remove listeners) before the next attempt.
      val holder = SparkConnectTestUtils.createDummySessionHolder(spark)
      // Timeout hook: best-effort cleanup of the Python worker resources. A failure here is
      // printed (not swallowed silently) because it is exactly the case where the next
      // attempt is most likely to hang too, and losing it would make diagnosis impossible.
      val releaseWorkerResources: () => Unit = () => {
        try {
          holder.streamingForeachBatchRunnerCleanerCache.cleanUpAll()
        } catch {
          case t: Throwable =>
            // scalastyle:off println
            println(
              s"===== onTimeout cleanUpAll suppressed " +
              s"${t.getClass.getSimpleName}: ${t.getMessage} =====")
            // scalastyle:on println
        }
      }
      // Per-attempt budget: the body needs up to 30s stop timeout + 30s `eventually` + a 4s
      // sleep + service start/stop overhead; 2 minutes leaves headroom on slow CI runners
      // while strictly bounding the original 150-minute hang.
      awaitTestBodyInNewThread(
        timeoutMillis = TimeUnit.MINUTES.toMillis(2),
        onTimeout = releaseWorkerResources) {
        runPythonForeachBatchTerminationTestBody(holder)
      }
    }
  }
}

Expand Down