Skip to content

Commit a66cdda

Browse files
authored
fix: Clean shutdown for Sink threaded server using threading.Event (#325)
Signed-off-by: Sreekanth <prsreekanth920@gmail.com>
1 parent 23bc5d0 commit a66cdda

6 files changed

Lines changed: 456 additions & 278 deletions

File tree

packages/pynumaflow/pynumaflow/shared/server.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import multiprocessing
55
import os
66
import socket
7+
import threading
78
import traceback
89

910
from google.protobuf import any_pb2
@@ -18,6 +19,7 @@
1819
from pynumaflow._constants import (
1920
_LOGGER,
2021
MULTIPROC_MAP_SOCK_ADDR,
22+
NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS,
2123
UDFType,
2224
)
2325
from pynumaflow.exceptions import SocketError
@@ -57,6 +59,7 @@ def sync_server_start(
5759
server_options=None,
5860
server_info: ServerInfo | None = None,
5961
udf_type: str = UDFType.Map,
62+
shutdown_event: threading.Event | None = None,
6063
):
6164
"""
6265
Utility function to start a sync grpc server instance.
@@ -75,6 +78,7 @@ def sync_server_start(
7578
udf_type=udf_type,
7679
server_info_file=server_info_file,
7780
server_info=server_info,
81+
shutdown_event=shutdown_event,
7882
)
7983

8084

@@ -86,10 +90,15 @@ def _run_server(
8690
udf_type: str,
8791
server_info_file: str | None = None,
8892
server_info: ServerInfo | None = None,
93+
shutdown_event: threading.Event | None = None,
8994
) -> None:
9095
"""
9196
Starts the Synchronous server instance on the given UNIX socket
9297
with given max threads. Wait for the server to terminate.
98+
99+
If *shutdown_event* is provided, a background daemon thread will wait
100+
on it and then call ``server.stop(NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS)``
101+
for a cooperative graceful shutdown (no process kill).
93102
"""
94103
server = grpc.server(
95104
ThreadPoolExecutor(
@@ -115,10 +124,21 @@ def _run_server(
115124
server.add_insecure_port(bind_address)
116125
# start the gRPC server
117126
server.start()
127+
118128
# Add the server information to the server info file if provided
119129
if server_info and server_info_file:
120130
info_server_write(server_info=server_info, info_file=server_info_file)
121131

132+
if shutdown_event is not None:
133+
134+
def _watch_for_shutdown():
135+
shutdown_event.wait()
136+
_LOGGER.info("Shutdown signal received, stopping server gracefully...")
137+
server.stop(NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS)
138+
139+
watcher = threading.Thread(target=_watch_for_shutdown, daemon=True)
140+
watcher.start()
141+
122142
_LOGGER.info("GRPC Server listening on: %s %d", bind_address, os.getpid())
123143
server.wait_for_termination()
124144

@@ -243,14 +263,14 @@ def check_instance(instance, callable_type) -> bool:
243263
return False
244264

245265

246-
def get_grpc_status(err: str):
266+
def get_grpc_status(err: str, detail: str | None = None):
247267
"""
248268
Create a grpc status object with the error details.
249269
"""
250270
details = any_pb2.Any()
251271
details.Pack(
252272
error_details_pb2.DebugInfo(
253-
detail="\n".join(traceback.format_stack()),
273+
detail=detail if detail is not None else "\n".join(traceback.format_stack()),
254274
)
255275
)
256276

@@ -295,9 +315,9 @@ def update_context_err(context: NumaflowServicerContext, e: BaseException, err_m
295315
"""
296316
trace = get_exception_traceback_str(e)
297317
_LOGGER.critical(trace)
298-
_LOGGER.critical(e.__str__())
318+
_LOGGER.critical(err_msg)
299319

300-
grpc_status = get_grpc_status(err_msg)
320+
grpc_status = get_grpc_status(err_msg, detail=trace)
301321

302322
context.set_code(grpc.StatusCode.INTERNAL)
303323
context.set_details(err_msg)

packages/pynumaflow/pynumaflow/shared/synciter.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55

66
class SyncIterator:
7-
"""A Sync Interator backed by a queue"""
7+
"""A Sync Iterator backed by a queue"""
88

99
__slots__ = "_queue"
1010

@@ -21,3 +21,7 @@ def read_iterator(self):
2121

2222
def put(self, item):
2323
self._queue.put(item)
24+
25+
def close(self):
26+
"""Unblock any thread waiting on read_iterator() by injecting STREAM_EOF."""
27+
self._queue.put(STREAM_EOF)

packages/pynumaflow/pynumaflow/shared/thread_with_return.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ class ThreadWithReturnValue(Thread):
55
"""
66
A custom Thread class that allows the target function to return a value.
77
This class extends the built-in threading.Thread class.
8+
Exceptions raised by the target are captured and re-raised on join().
89
"""
910

1011
def __init__(self, group=None, target=None, name=None, args=(), kwargs={}, verbose=None):
@@ -23,32 +24,42 @@ def __init__(self, group=None, target=None, name=None, args=(), kwargs={}, verbo
2324
Thread.__init__(self, group, target, name, args, kwargs)
2425
# Variable to store the return value of the target function
2526
self._return = None
27+
self._exception: BaseException | None = None
2628

2729
def run(self):
2830
"""
2931
Run the thread.
3032
3133
This method is overridden from the Thread class.
3234
It calls the target function and saves the return value.
35+
If the target raises, the exception is captured for re-raising on join().
3336
"""
3437
if self._target is not None:
35-
# Execute target and store the result
36-
self._return = self._target(*self._args, **self._kwargs)
38+
try:
39+
# Execute target and store the result
40+
self._return = self._target(*self._args, **self._kwargs)
41+
except BaseException as exc:
42+
self._exception = exc
3743

38-
def join(self, *args):
44+
def join(self, timeout=None):
3945
"""
4046
Wait for the thread to complete and return the result.
4147
4248
This method is overridden from the Thread class.
43-
It calls the parent class's join() method and then returns the stored return value.
49+
It calls the parent class's join() method, re-raises any captured
50+
exception, and then returns the stored return value.
4451
4552
Parameters:
46-
*args: Variable length argument list to pass to the join() method.
53+
timeout: Seconds to wait (None means wait indefinitely).
4754
4855
Returns:
4956
The return value from the target function.
57+
58+
Raises:
59+
BaseException: If the target function raised during run().
5060
"""
51-
# Call the parent class's join() method to wait for the thread to finish
52-
Thread.join(self, *args)
61+
Thread.join(self, timeout=timeout)
62+
if self._exception is not None:
63+
raise self._exception
5364
# Return the result of the target function
5465
return self._return

packages/pynumaflow/pynumaflow/sinker/server.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os
2+
import sys
23

34
from pynumaflow.info.types import ServerInfo, ContainerType, MINIMUM_NUMAFLOW_VERSION
45
from pynumaflow.sinker.servicer.sync_servicer import SyncSinkServicer
@@ -120,6 +121,7 @@ def start(self):
120121
)
121122
serv_info = ServerInfo.get_default_server_info()
122123
serv_info.minimum_numaflow_version = MINIMUM_NUMAFLOW_VERSION[ContainerType.Sinker]
124+
123125
# Start the server
124126
sync_server_start(
125127
servicer=self.servicer,
@@ -129,4 +131,9 @@ def start(self):
129131
server_options=self._server_options,
130132
udf_type=UDFType.Sink,
131133
server_info=serv_info,
134+
shutdown_event=self.servicer.shutdown_event,
132135
)
136+
137+
if self.servicer.error:
138+
_LOGGER.critical("Server exiting due to UDF error: %s", self.servicer.error)
139+
sys.exit(1)

packages/pynumaflow/pynumaflow/sinker/servicer/sync_servicer.py

Lines changed: 27 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1+
import threading
12
from collections.abc import Iterator
23

4+
import grpc
35

4-
from pynumaflow._constants import _LOGGER, STREAM_EOF
6+
from pynumaflow._constants import _LOGGER, STREAM_EOF, ERR_UDF_EXCEPTION_STRING
57
from pynumaflow.proto.sinker import sink_pb2_grpc, sink_pb2
6-
from pynumaflow.shared.server import exit_on_error
8+
from pynumaflow.shared.server import update_context_err
79
from pynumaflow.shared.synciter import SyncIterator
810
from pynumaflow.shared.thread_with_return import ThreadWithReturnValue
911
from pynumaflow.sinker._dtypes import SinkSyncCallable
@@ -24,6 +26,8 @@ class SyncSinkServicer(sink_pb2_grpc.SinkServicer):
2426

2527
def __init__(self, handler: SinkSyncCallable):
2628
self.handler: SinkSyncCallable = handler
29+
self.shutdown_event: threading.Event = threading.Event()
30+
self.error: BaseException | None = None
2731

2832
def SinkFn(
2933
self, request_iterator: Iterator[sink_pb2.SinkRequest], context: NumaflowServicerContext
@@ -32,6 +36,7 @@ def SinkFn(
3236
Applies a sink function to datum elements.
3337
"""
3438

39+
req_queue = None
3540
try:
3641
# The first message to be received should be a valid handshake
3742
req = next(request_iterator)
@@ -78,23 +83,31 @@ def SinkFn(
7883
if cur_task:
7984
cur_task.join()
8085

86+
except grpc.RpcError:
87+
_LOGGER.warning("gRPC stream closed, shutting down the server.")
88+
if req_queue is not None:
89+
req_queue.close()
90+
self.shutdown_event.set()
91+
return
92+
8193
except BaseException as err:
82-
# Handle exceptions
83-
err_msg = f"UDSinkError: {repr(err)}"
94+
err_msg = f"UDSinkError, {ERR_UDF_EXCEPTION_STRING}: {repr(err)}"
8495
_LOGGER.critical(err_msg, exc_info=True)
85-
exit_on_error(context, err_msg)
96+
update_context_err(context, err, err_msg)
97+
# Unblock the handler thread if it is waiting on queue.get()
98+
# (e.g. gRPC stream broke while the handler was waiting for the next message).
99+
# This lets it exit gracefully and release any user-held resources
100+
# before the process shuts down.
101+
if req_queue is not None:
102+
req_queue.close()
103+
self.error = err
104+
self.shutdown_event.set()
86105
return
87106

88107
def _invoke_sink(self, request_queue: SyncIterator, context: NumaflowServicerContext):
89-
try:
90-
# Invoke the handler function with the request queue
91-
rspns = self.handler(request_queue.read_iterator())
92-
return build_sink_resp_results(rspns)
93-
except BaseException as err:
94-
err_msg = f"UDSinkError: {repr(err)}"
95-
_LOGGER.critical(err_msg, exc_info=True)
96-
exit_on_error(context, err_msg)
97-
raise err
108+
# Invoke the handler function with the request queue
109+
rspns = self.handler(request_queue.read_iterator())
110+
return build_sink_resp_results(rspns)
98111

99112
def IsReady(self, request, context: NumaflowServicerContext) -> sink_pb2.ReadyResponse:
100113
"""

0 commit comments

Comments
 (0)