42 changes: 32 additions & 10 deletions udf/worker/README.md
@@ -19,7 +19,7 @@ WorkerDispatcher -- manages workers, creates sessions
|
v
WorkerSession -- one UDF execution
- |  1. session.init(InitMessage(payload, inputSchema, outputSchema))
+ |  1. session.init(Init proto)
| 2. val results = session.process(inputBatches)
| 3. session.close()
```
@@ -34,12 +34,13 @@ provisioning service or daemon).
```
udf/worker/
├── proto/
- │   worker_spec.proto  -- UDFWorkerSpecification protobuf (+ generated Java classes)
+ │   worker_spec.proto  -- UDFWorkerSpecification protobuf
+ │   udf_protocol.proto -- UDF execution protocol (Init, UdfPayload, ...)
│ common.proto -- shared enums (UDFWorkerDataFormat, etc.)
└── core/ -- abstract interfaces
WorkerDispatcher.scala -- creates sessions, manages worker lifecycle
- WorkerSession.scala -- per-UDF init/process/cancel/close + InitMessage
+ WorkerSession.scala -- per-UDF init/process/cancel/close
WorkerConnection.scala -- transport channel abstraction
WorkerSecurityScope.scala -- security boundary for worker pooling
@@ -55,6 +56,19 @@ worker creation where Spark spawns local OS processes. Future packages
(e.g., `core/indirect/`) can implement alternative creation modes such as
obtaining workers from a provisioning service or daemon.

## Wire protocol

Each UDF execution uses a single bidirectional `Execute` gRPC stream.

```
Engine -> Worker: Init -> PayloadChunk* -> (DataRequest)* -> (Finish | Cancel)
Worker -> Engine: InitResponse -> (DataResponse)* -> (ExecutionError)? -> (FinishResponse | CancelResponse)
```

See `udf/worker/proto/src/main/protobuf/udf_protocol.proto` for the complete
protocol definition, ordering invariants, and error contract.
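
As a rough sketch of how an engine-side client might drive that ordering
(the `ExecuteRequest` envelope and its oneof setter names below are
assumptions for illustration; only the individual message names come from
the diagram above):

```scala
import io.grpc.stub.StreamObserver

// Hedged sketch only -- consult udf_protocol.proto for the real envelope
// and field names. Assumes the generated proto classes are on the classpath.
def driveExecution(
    stream: StreamObserver[ExecuteRequest],
    init: Init,
    payloadChunks: Iterator[PayloadChunk]): Unit = {
  stream.onNext(ExecuteRequest.newBuilder().setInit(init).build())
  payloadChunks.foreach { chunk =>
    stream.onNext(ExecuteRequest.newBuilder().setPayloadChunk(chunk).build())
  }
  // DataRequest messages may be interleaved here while the worker consumes
  // input; the worker streams DataResponse messages back on its side.
  stream.onNext(ExecuteRequest.newBuilder()
    .setFinish(Finish.getDefaultInstance)
    .build())
  stream.onCompleted() // expect FinishResponse (or CancelResponse) back
}
```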

### Direct worker creation

`DirectWorkerDispatcher` spawns worker processes locally. On the first
@@ -76,10 +90,12 @@ Workers are terminated via SIGTERM/SIGKILL when the dispatcher is closed.

```scala
import org.apache.spark.udf.worker.{
-   DirectWorker, ProcessCallable, UDFProtoCommunicationPattern,
-   UDFWorkerDataFormat, UDFWorkerProperties, UDFWorkerSpecification,
-   UnixDomainSocket, WorkerCapabilities, WorkerConnectionSpec, WorkerEnvironment}
+   DirectWorker, Init, ProcessCallable, UdfPayload,
+   UDFProtoCommunicationPattern, UDFWorkerDataFormat, UDFWorkerProperties,
+   UDFWorkerSpecification, UnixDomainSocket, WorkerCapabilities,
+   WorkerConnectionSpec, WorkerEnvironment}
import org.apache.spark.udf.worker.core._
import com.google.protobuf.ByteString

// 1. Define a worker spec (direct creation mode).
val spec = UDFWorkerSpecification.newBuilder()
@@ -112,10 +128,16 @@ val dispatcher: WorkerDispatcher = ...
val session = dispatcher.createSession(securityScope = None)
try {
// 4. Initialize with the serialized function and schemas.
-   session.init(InitMessage(
-     functionPayload = serializedFunction,
-     inputSchema = arrowInputSchema,
-     outputSchema = arrowOutputSchema))
+   session.init(Init.newBuilder()
+     .setProtocolVersion(1)
+     .setUdf(UdfPayload.newBuilder()
+       .setPayload(ByteString.copyFrom(serializedFunction))
+       .setFormat(payloadFormat) // worker-recognized tag
+       .build())
+     .setDataFormat(UDFWorkerDataFormat.ARROW)
+     .setInputSchema(ByteString.copyFrom(arrowInputSchema))
+     .setOutputSchema(ByteString.copyFrom(arrowOutputSchema))
+     .build())

// 5. Process data -- Iterator in, Iterator out.
val results: Iterator[Array[Byte]] =
@@ -27,6 +27,14 @@ import org.apache.spark.udf.worker.UDFWorkerSpecification
* as security scope). It owns the underlying worker processes and connections,
* handling pooling, reuse, and lifecycle behind the scenes. Spark interacts with
* workers exclusively through the [[WorkerSession]]s returned by [[createSession]].
*
* '''Worker invalidation:''' if a session terminates with a transport error, the
* worker that backed it MUST NOT be returned to any reuse pool. A transport
* error leaves the worker in an unknown state; only workers that complete
* sessions cleanly are eligible for reuse. Implementations are responsible for
* tracking this condition -- typically [[WorkerSession.doProcess]] flags the
* worker as invalid before [[WorkerSession.doClose]] releases it, so the
* dispatcher can distinguish a clean release from a failed one.
*/
@Experimental
trait WorkerDispatcher extends AutoCloseable {
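
To make the invalidation contract concrete, a minimal sketch of the
bookkeeping a dispatcher implementation might keep -- `PooledWorker`,
`markInvalid`, and `WorkerPool` are hypothetical names, not part of this
trait:

```scala
import java.util.concurrent.ConcurrentLinkedQueue
import java.util.concurrent.atomic.AtomicBoolean

// Hypothetical pooling sketch: only workers that complete a session
// cleanly may re-enter the idle queue.
final class PooledWorker {
  private val invalid = new AtomicBoolean(false)
  def markInvalid(): Unit = invalid.set(true) // set on transport error
  def isInvalid: Boolean = invalid.get()
  def terminate(): Unit = { /* SIGTERM, then SIGKILL */ }
}

final class WorkerPool {
  private val idle = new ConcurrentLinkedQueue[PooledWorker]()
  // Called when a session releases its worker: a worker flagged invalid
  // during processing is terminated instead of being offered for reuse.
  def release(worker: PooledWorker): Unit =
    if (worker.isInvalid) worker.terminate() else idle.offer(worker)
}
```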
@@ -19,31 +19,7 @@ package org.apache.spark.udf.worker.core
import java.util.concurrent.atomic.AtomicBoolean

import org.apache.spark.annotation.Experimental

- /**
-  * :: Experimental ::
-  * Carries all information needed to initialize a UDF execution on a worker.
-  *
-  * This message is passed to [[WorkerSession#init]] and contains the function
-  * definition, schemas, and any additional configuration.
-  *
-  * Placeholder: will be replaced by a generated proto message once the
-  * UDF wire protocol lands. Do not rely on case-class equality --
-  * `Array[Byte]` fields compare by reference.
-  *
-  * @param functionPayload serialized function (e.g., pickled Python, JVM bytes)
-  * @param inputSchema serialized input schema (e.g., Arrow schema bytes)
-  * @param outputSchema serialized output schema (e.g., Arrow schema bytes)
-  * @param properties additional key-value configuration. Can carry
-  *                   protocol-specific or engine-specific metadata that
-  *                   does not yet have a dedicated field.
-  */
- @Experimental
- case class InitMessage(
-     functionPayload: Array[Byte],
-     inputSchema: Array[Byte],
-     outputSchema: Array[Byte],
-     properties: Map[String, String] = Map.empty)
+ import org.apache.spark.udf.worker.Init

/**
* :: Experimental ::
@@ -62,7 +38,11 @@ case class InitMessage(
* {{{
* val session = dispatcher.createSession(securityScope = None)
* try {
- * session.init(InitMessage(functionPayload, inputSchema, outputSchema))
+ * session.init(Init.newBuilder()
+ *   .setProtocolVersion(1)
+ *   .setUdf(UdfPayload.newBuilder().setPayload(callable).setFormat(fmt).build())
+ *   .setDataFormat(UDFWorkerDataFormat.ARROW)
+ *   .build())
* val results = session.process(inputBatches)
* results.foreach(handleBatch)
* } finally {
@@ -74,7 +54,8 @@
* - [[init]] must be called exactly once before [[process]].
* - [[process]] must be called at most once per session.
* - [[close]] must always be called (use try-finally).
- * - [[cancel]] may be called at any time to abort execution.
+ * - [[cancel]] may be called at any time from any execution context.
+ *   See [[cancel]] for the full contract.
*
* The lifecycle is enforced here: [[init]] and [[process]] are `final`
* and delegate to [[doInit]] / [[doProcess]] after AtomicBoolean guards.
@@ -93,10 +74,11 @@ abstract class WorkerSession extends AutoCloseable {
*
* Throws `IllegalStateException` if called more than once.
*
- * @param message the initialization parameters including the serialized
- *                function, input/output schemas, and configuration.
+ * @param message the [[Init]] message carrying the UDF body, data
+ *                format, optional schemas, and any session context
+ *                the worker needs to start processing.
  */
- final def init(message: InitMessage): Unit = {
+ final def init(message: Init): Unit = {
if (!initialized.compareAndSet(false, true)) {
throw new IllegalStateException("init has already been called on this session")
}
@@ -108,7 +90,7 @@ abstract class WorkerSession extends AutoCloseable {
*
* Follows Spark's Iterator-to-Iterator pattern: input batches are streamed
* to the worker, and result batches are lazily pulled from the returned
- * iterator. The session sends a Finish signal to the worker when the input
+ * iterator. The session sends a finish signal to the worker when the input
* iterator is exhausted.
*
* Must be called after [[init]] and at most once per session.
@@ -127,22 +109,64 @@ abstract class WorkerSession extends AutoCloseable {
doProcess(input)
}

- /** Subclass hook for [[init]]. Called once, after the guard. */
- protected def doInit(message: InitMessage): Unit
+ /**
+  * Subclass hook for [[init]]. Called once, after the guard.
+  * The session MUST NOT be activated before this call, since
+  * [[cancel]] before [[init]] is contractually a no-op.
+  */
+ protected def doInit(message: Init): Unit

/** Subclass hook for [[process]]. Called at most once, after the guard. */
protected def doProcess(input: Iterator[Array[Byte]]): Iterator[Array[Byte]]

/**
* Requests cancellation of the current UDF execution.
*
- * '''Thread-safety:''' implementations must allow [[cancel]] to be called
- * from a thread different from the one driving [[process]] (typically a
- * task interruption thread). It may be invoked at any point after
- * [[init]] and should be a no-op if execution has already finished.
+ * '''Thread-safety:''' [[cancel]] may be called concurrently with
+ * [[process]] from any execution context.
+ *
+ * '''Lifecycle:''' [[cancel]] is idempotent and safe at any point in
+ * the session's life:
+ *  - before [[init]] -- a no-op; the session may still be closed
+ *    normally via [[close]].
+ *  - between [[init]] and [[process]] -- signals that the session
+ *    should be terminated; the caller should not invoke [[process]]
+ *    and should call [[close]] to release resources.
+ *    Implementations SHOULD surface this as an error if [[process]]
+ *    is subsequently invoked despite the cancellation.
+ *  - during [[process]] (data flowing or awaiting completion) --
+ *    requests the worker to abort on a best-effort basis.
+ *  - after [[process]] has returned (session already terminated) --
+ *    a no-op.
+ *
+ * Implementations are responsible for the lifecycle-aware behavior
+ * described above so that the caller does not need to coordinate
+ * with the execution context driving [[process]].
*/
def cancel(): Unit

- /** Closes this session and releases resources. */
- override def close(): Unit
+ /**
+  * Closes this session and releases resources. Idempotent; safe to
+  * call from a `finally` block regardless of whether [[init]],
+  * [[process]], or [[cancel]] have been invoked.
+  *
+  * If [[init]] was called but [[process]] was not (e.g. an exception
+  * was thrown between the two), [[close]] signals cancellation to the
+  * worker before releasing resources so it can clean up
+  * deterministically. Subclasses implement [[doClose]] for resource
+  * teardown; the base class handles the cancel-before-close guarantee
+  * automatically.
+  */
+ final override def close(): Unit = {
+   if (initialized.get() && !processed.get()) {
+     cancel()
+   }
+   doClose()
+ }
+
+ /**
+  * Subclass hook for [[close]]. The base class guarantees that
+  * [[cancel]] has already been called if [[init]] was invoked but
+  * [[process]] was not.
+  */
+ protected def doClose(): Unit
}
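
Putting the lifecycle contract together, a caller might wire a session as
follows; `onTaskInterrupt` is a hypothetical stand-in for whatever
interruption hook the engine provides:

```scala
// Illustrative caller honoring the WorkerSession contract: init exactly
// once, process at most once, close in a finally, cancel from any thread.
def runUdf(
    dispatcher: WorkerDispatcher,
    init: Init,
    batches: Iterator[Array[Byte]],
    onTaskInterrupt: (() => Unit) => Unit, // hypothetical hook
    handleBatch: Array[Byte] => Unit): Unit = {
  val session = dispatcher.createSession(securityScope = None)
  onTaskInterrupt(() => session.cancel()) // safe from any execution context
  try {
    session.init(init)
    // Iterator-to-Iterator: result batches are pulled lazily, so drain
    // them before close() releases the worker.
    session.process(batches).foreach(handleBatch)
  } finally {
    session.close() // idempotent; cancels first if process() never ran
  }
}
```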
@@ -373,7 +373,8 @@ abstract class DirectWorkerDispatcher(
"DirectWorker.runner must have at least one entry in command or arguments")
val workerId = UUID.randomUUID().toString
val address = newEndpointAddress(workerId)
- // Proto contract: the engine must pass --id and --connection.
+ // The engine injects --connection (the socket address it manages) and
+ // --id (an internal correlation identifier) into the worker command.
val cmd = baseCmd ++ Seq("--id", workerId, "--connection", address)
val env = runner.getEnvironmentVariablesMap.asScala.toMap
val outputFile = Files.createTempFile("udf-worker-", ".log")
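
To make the injection concrete, with a runner command of
`python3 my_worker.py` the spawned command line looks like the following
(the address shape is illustrative; it comes from `newEndpointAddress`):

```scala
// Hypothetical values showing the flag injection above.
val baseCmd  = Seq("python3", "my_worker.py")
val workerId = java.util.UUID.randomUUID().toString
val address  = s"/tmp/udf-worker-$workerId.sock" // shape assumed
val cmd      = baseCmd ++ Seq("--id", workerId, "--connection", address)
// => python3 my_worker.py --id <uuid> --connection /tmp/udf-worker-<uuid>.sock
```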
@@ -27,7 +27,7 @@ import org.apache.spark.udf.worker.core.{WorkerConnection, WorkerSession}
*
* This is the session type returned by [[DirectWorkerDispatcher]]. It ties
* the session lifecycle to the worker's ref-count: the dispatcher increments
- * the count before construction, and [[close]] decrements it, so the
+ * the count before construction, and [[doClose]] decrements it, so the
* dispatcher knows when a worker process is idle and can be terminated or
* reused.
*
Expand All @@ -48,7 +48,7 @@ abstract class DirectWorkerSession(
/** The connection to the worker for this session. */
def connection: WorkerConnection = workerProcess.connection

- override def close(): Unit = {
+ override protected def doClose(): Unit = {
if (released.compareAndSet(false, true)) {
workerProcess.releaseSession()
}
@@ -27,7 +27,7 @@ import org.scalatest.BeforeAndAfterEach
import org.scalatest.funsuite.AnyFunSuite

import org.apache.spark.udf.worker.{
-   DirectWorker, LocalTcpConnection, ProcessCallable, UDFWorkerProperties,
+   DirectWorker, Init, LocalTcpConnection, ProcessCallable, UDFWorkerProperties,
UDFWorkerSpecification, UnixDomainSocket, WorkerConnectionSpec,
WorkerEnvironment}
import org.apache.spark.udf.worker.core.direct.{DirectUnixSocketWorkerDispatcher,
@@ -51,14 +51,14 @@ class SocketFileConnection(socketPath: String)
* TODO: [[cancel]] is a no-op here. Once a concrete [[DirectWorkerSession]]
* with real data-plane wiring lands, add tests exercising cancel() in
* particular: cancel from a different thread than process(), cancel
- * after process() has returned, and cancel before init (should be a
- * no-op). Tracking the thread-safety contract in the docstring on
+ * after process() has returned, and cancel before init (should be a no-op).
+ * See the thread-safety contract in the docstring on
* [[org.apache.spark.udf.worker.core.WorkerSession.cancel]].
*/
class StubWorkerSession(
workerProcess: DirectWorkerProcess) extends DirectWorkerSession(workerProcess) {

- override protected def doInit(message: InitMessage): Unit = {}
+ override protected def doInit(message: Init): Unit = {}

override protected def doProcess(
input: Iterator[Array[Byte]]): Iterator[Array[Byte]] =
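
As a starting point for the TODO above, a sketch of the cancel-before-init
case; `createDispatcher()` is a placeholder for this suite's fixture setup:

```scala
// Hypothetical test sketch: cancel before init must be a no-op and the
// session must still close cleanly afterwards.
test("cancel before init is a no-op") {
  val dispatcher = createDispatcher() // placeholder fixture
  val session = dispatcher.createSession(securityScope = None)
  try {
    session.cancel() // must not throw and must not activate the session
  } finally {
    session.close() // normal close still succeeds
  }
}
```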