11from __future__ import annotations
22
3- from concurrent .futures import Future
4- from threading import Thread
3+ from concurrent .futures import Future , ThreadPoolExecutor , wait
4+ import threading
5+ from os import cpu_count
56from typing import Callable , Any
67from dataclasses import dataclass
78
@@ -67,8 +68,12 @@ def __init__(self, base_session: types.BaseSession):
6768 self ._session = session .WAMPSession (base_session .serializer )
6869
6970 self ._disconnect_callback : list [Callable [[], None ] | None ] = []
71+ self ._stopped = threading .Event ()
7072
71- thread = Thread (target = self ._wait , daemon = False )
73+ # callback executor thread-pool
74+ self ._executor = ThreadPoolExecutor (max_workers = (cpu_count () or 1 ) * 4 )
75+
76+ thread = threading .Thread (target = self ._wait , daemon = True )
7277 thread .start ()
7378
7479 def _wait (self ):
@@ -80,8 +85,54 @@ def _wait(self):
8085
8186 self ._process_incoming_message (self ._session .receive (data ))
8287
83- for callback in self ._disconnect_callback :
84- callback ()
88+ # Shut down executor, cancelling anything still running
89+ self ._executor .shutdown (cancel_futures = True , wait = False )
90+
91+ if self ._disconnect_callback :
92+ with ThreadPoolExecutor (max_workers = len (self ._disconnect_callback )) as executor :
93+ # Trigger disconnect callbacks concurrently
94+ futures = [executor .submit (cb ) for cb in self ._disconnect_callback ]
95+ # Wait up to 1 second for them to finish
96+ wait (futures , timeout = 1 )
97+
98+ self ._stopped .set ()
99+
100+ def _handle_invocation (self , msg : messages .Invocation , endpoint : Callable [[types .Invocation ], types .Result ]):
101+ try :
102+ result = endpoint (types .Invocation (msg .args , msg .kwargs , msg .details ))
103+
104+ if result is None :
105+ data = self ._session .send_message (messages .Yield (messages .YieldFields (msg .request_id )))
106+ elif isinstance (result , types .Result ):
107+ data = self ._session .send_message (
108+ messages .Yield (messages .YieldFields (msg .request_id , result .args , result .kwargs , result .details ))
109+ )
110+ else :
111+ message = "Endpoint returned invalid result type. Expected types.Result or None, got: " + str (
112+ type (result )
113+ )
114+ msg_to_send = messages .Error (
115+ messages .ErrorFields (msg .TYPE , msg .request_id , xconn_uris .ERROR_INTERNAL_ERROR , [message ])
116+ )
117+ data = self ._session .send_message (msg_to_send )
118+
119+ self ._base_session .send (data )
120+ except ApplicationError as e :
121+ msg_to_send = messages .Error (messages .ErrorFields (msg .TYPE , msg .request_id , e .message , e .args ))
122+ data = self ._session .send_message (msg_to_send )
123+ self ._base_session .send (data )
124+ except Exception as e :
125+ msg_to_send = messages .Error (
126+ messages .ErrorFields (msg .TYPE , msg .request_id , xconn_uris .ERROR_RUNTIME_ERROR , [e .__str__ ()])
127+ )
128+ data = self ._session .send_message (msg_to_send )
129+ self ._base_session .send (data )
130+
131+ def _handle_event (self , msg : messages .Event , endpoint : Callable [[types .Event ], None ]):
132+ try :
133+ endpoint (types .Event (msg .args , msg .kwargs , msg .details ))
134+ except Exception as e :
135+ print (e )
85136
86137 def _process_incoming_message (self , msg : messages .Message ):
87138 if isinstance (msg , messages .Registered ):
@@ -98,28 +149,7 @@ def _process_incoming_message(self, msg: messages.Message):
98149 elif isinstance (msg , messages .Invocation ):
99150 try :
100151 endpoint = self ._registrations [msg .registration_id ]
101- result = endpoint (types .Invocation (msg .args , msg .kwargs , msg .details ))
102-
103- if result is None :
104- data = self ._session .send_message (messages .Yield (messages .YieldFields (msg .request_id )))
105- elif isinstance (result , types .Result ):
106- data = self ._session .send_message (
107- messages .Yield (messages .YieldFields (msg .request_id , result .args , result .kwargs , result .details ))
108- )
109- else :
110- message = "Endpoint returned invalid result type. Expected types.Result or None, got: " + str (
111- type (result )
112- )
113- msg_to_send = messages .Error (
114- messages .ErrorFields (msg .TYPE , msg .request_id , xconn_uris .ERROR_INTERNAL_ERROR , [message ])
115- )
116- data = self ._session .send_message (msg_to_send )
117-
118- self ._base_session .send (data )
119- except ApplicationError as e :
120- msg_to_send = messages .Error (messages .ErrorFields (msg .TYPE , msg .request_id , e .message , e .args ))
121- data = self ._session .send_message (msg_to_send )
122- self ._base_session .send (data )
152+ self ._executor .submit (self ._handle_invocation , msg , endpoint )
123153 except Exception as e :
124154 msg_to_send = messages .Error (
125155 messages .ErrorFields (msg .TYPE , msg .request_id , xconn_uris .ERROR_RUNTIME_ERROR , [e .__str__ ()])
@@ -140,7 +170,7 @@ def _process_incoming_message(self, msg: messages.Message):
140170 elif isinstance (msg , messages .Event ):
141171 try :
142172 endpoint = self ._subscriptions [msg .subscription_id ]
143- endpoint ( types . Event ( msg . args , msg . kwargs , msg . details ) )
173+ self . _executor . submit ( self . _handle_event , msg , endpoint )
144174 except Exception as e :
145175 print (e )
146176 elif isinstance (msg , messages .Error ):
@@ -295,3 +325,12 @@ def ping(self, timeout: int = 10) -> float:
295325 def _on_disconnect (self , callback : Callable [[], None ]) -> None :
296326 if callback is not None :
297327 self ._disconnect_callback .append (callback )
328+
329+ def run_forever (self ):
330+ """Block until the session is closed/disconnected."""
331+ print ("[Session] Running forever — press Ctrl+C to exit." )
332+ try :
333+ self ._stopped .wait ()
334+ except KeyboardInterrupt :
335+ print ("[Session] Interrupted — shutting down..." )
336+ self .leave ()
0 commit comments