Skip to content

Commit 63b0b0c

Browse files
authored
run callbacks in a threadpool, less syscalls in rawsocket (#212)
* run callbacks in a threadpool, less syscalls in rawsocket * add Session.run_forever() to let app run * set threadpool size to cpu_count * 4 * async: only use task in callbacks code
1 parent e39440f commit 63b0b0c

File tree

3 files changed

+220
-113
lines changed

3 files changed

+220
-113
lines changed

xconn/async_session.py

Lines changed: 54 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,49 @@ def __init__(self, base_session: types.IAsyncBaseSession):
7676
self._loop = get_event_loop()
7777
self.wait_task = self._loop.create_task(self._wait())
7878

79+
async def _handle_invocation(
80+
self,
81+
msg: messages.Invocation,
82+
endpoint: Union[
83+
Callable[[types.Invocation], types.Result], Callable[[types.Invocation], Awaitable[types.Result]]
84+
],
85+
):
86+
try:
87+
result = await endpoint(types.Invocation(msg.args, msg.kwargs, msg.details))
88+
89+
if result is None:
90+
data = self._session.send_message(messages.Yield(messages.YieldFields(msg.request_id)))
91+
elif isinstance(result, types.Result):
92+
data = self._session.send_message(
93+
messages.Yield(messages.YieldFields(msg.request_id, result.args, result.kwargs, result.details))
94+
)
95+
else:
96+
message = "Endpoint returned invalid result type. Expected types.Result or None, got: " + str(
97+
type(result)
98+
)
99+
msg_to_send = messages.Error(
100+
messages.ErrorFields(msg.TYPE, msg.request_id, xconn_uris.ERROR_INTERNAL_ERROR, [message])
101+
)
102+
data = self._session.send_message(msg_to_send)
103+
104+
await self._base_session.send(data)
105+
except ApplicationError as e:
106+
msg_to_send = messages.Error(messages.ErrorFields(msg.TYPE, msg.request_id, e.message, e.args))
107+
data = self._session.send_message(msg_to_send)
108+
await self._base_session.send(data)
109+
except Exception as e:
110+
msg_to_send = messages.Error(
111+
messages.ErrorFields(msg.TYPE, msg.request_id, xconn_uris.ERROR_RUNTIME_ERROR, [e.__str__()])
112+
)
113+
data = self._session.send_message(msg_to_send)
114+
await self._base_session.send(data)
115+
116+
async def _handle_event(self, msg: messages.Event, endpoint: Callable[[types.Event], Awaitable[None]]):
117+
try:
118+
await endpoint(types.Event(msg.args, msg.kwargs, msg.details))
119+
except Exception as e:
120+
print(e)
121+
79122
async def _wait(self):
80123
while await self._base_session.transport.is_connected():
81124
try:
@@ -84,12 +127,11 @@ async def _wait(self):
84127
print(e)
85128
break
86129

87-
task = self._loop.create_task(self._process_incoming_message(self._session.receive(data)))
88-
self._tasks.add(task)
89-
task.add_done_callback(self._tasks.discard)
130+
await self._process_incoming_message(self._session.receive(data))
90131

91-
for callback in self._disconnect_callback:
92-
await callback()
132+
if self._disconnect_callback:
133+
callbacks = [callback() for callback in self._disconnect_callback]
134+
await asyncio.gather(*callbacks)
93135

94136
async def _process_incoming_message(self, msg: messages.Message):
95137
if isinstance(msg, messages.Registered):
@@ -104,36 +146,10 @@ async def _process_incoming_message(self, msg: messages.Message):
104146
request = self._call_requests.pop(msg.request_id)
105147
request.set_result(types.Result(msg.args, msg.kwargs, msg.details))
106148
elif isinstance(msg, messages.Invocation):
107-
try:
108-
endpoint = self._registrations[msg.registration_id]
109-
result = await endpoint(types.Invocation(msg.args, msg.kwargs, msg.details))
110-
111-
if result is None:
112-
data = self._session.send_message(messages.Yield(messages.YieldFields(msg.request_id)))
113-
elif isinstance(result, types.Result):
114-
data = self._session.send_message(
115-
messages.Yield(messages.YieldFields(msg.request_id, result.args, result.kwargs, result.details))
116-
)
117-
else:
118-
message = "Endpoint returned invalid result type. Expected types.Result or None, got: " + str(
119-
type(result)
120-
)
121-
msg_to_send = messages.Error(
122-
messages.ErrorFields(msg.TYPE, msg.request_id, xconn_uris.ERROR_INTERNAL_ERROR, [message])
123-
)
124-
data = self._session.send_message(msg_to_send)
125-
126-
await self._base_session.send(data)
127-
except ApplicationError as e:
128-
msg_to_send = messages.Error(messages.ErrorFields(msg.TYPE, msg.request_id, e.message, e.args))
129-
data = self._session.send_message(msg_to_send)
130-
await self._base_session.send(data)
131-
except Exception as e:
132-
msg_to_send = messages.Error(
133-
messages.ErrorFields(msg.TYPE, msg.request_id, xconn_uris.ERROR_RUNTIME_ERROR, [e.__str__()])
134-
)
135-
data = self._session.send_message(msg_to_send)
136-
await self._base_session.send(data)
149+
endpoint = self._registrations[msg.registration_id]
150+
task = self._loop.create_task(self._handle_invocation(msg, endpoint))
151+
self._tasks.add(task)
152+
task.add_done_callback(self._tasks.discard)
137153
elif isinstance(msg, messages.Subscribed):
138154
request = self._subscribe_requests.pop(msg.request_id)
139155
self._subscriptions[msg.subscription_id] = request.endpoint
@@ -147,10 +163,9 @@ async def _process_incoming_message(self, msg: messages.Message):
147163
request.set_result(None)
148164
elif isinstance(msg, messages.Event):
149165
endpoint = self._subscriptions[msg.subscription_id]
150-
try:
151-
await endpoint(types.Event(msg.args, msg.kwargs, msg.details))
152-
except Exception as e:
153-
print(e)
166+
task = self._loop.create_task(self._handle_event(msg, endpoint))
167+
self._tasks.add(task)
168+
task.add_done_callback(self._tasks.discard)
154169
elif isinstance(msg, messages.Error):
155170
match msg.message_type:
156171
case messages.Call.TYPE:

xconn/session.py

Lines changed: 67 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
from __future__ import annotations
22

3-
from concurrent.futures import Future
4-
from threading import Thread
3+
from concurrent.futures import Future, ThreadPoolExecutor, wait
4+
import threading
5+
from os import cpu_count
56
from typing import Callable, Any
67
from dataclasses import dataclass
78

@@ -67,8 +68,12 @@ def __init__(self, base_session: types.BaseSession):
6768
self._session = session.WAMPSession(base_session.serializer)
6869

6970
self._disconnect_callback: list[Callable[[], None] | None] = []
71+
self._stopped = threading.Event()
7072

71-
thread = Thread(target=self._wait, daemon=False)
73+
# callback executor thread-pool
74+
self._executor = ThreadPoolExecutor(max_workers=(cpu_count() or 1) * 4)
75+
76+
thread = threading.Thread(target=self._wait, daemon=True)
7277
thread.start()
7378

7479
def _wait(self):
@@ -80,8 +85,54 @@ def _wait(self):
8085

8186
self._process_incoming_message(self._session.receive(data))
8287

83-
for callback in self._disconnect_callback:
84-
callback()
88+
# Shut down executor, cancelling anything still running
89+
self._executor.shutdown(cancel_futures=True, wait=False)
90+
91+
if self._disconnect_callback:
92+
with ThreadPoolExecutor(max_workers=len(self._disconnect_callback)) as executor:
93+
# Trigger disconnect callbacks concurrently
94+
futures = [executor.submit(cb) for cb in self._disconnect_callback]
95+
# Wait up to 1 second for them to finish
96+
wait(futures, timeout=1)
97+
98+
self._stopped.set()
99+
100+
def _handle_invocation(self, msg: messages.Invocation, endpoint: Callable[[types.Invocation], types.Result]):
101+
try:
102+
result = endpoint(types.Invocation(msg.args, msg.kwargs, msg.details))
103+
104+
if result is None:
105+
data = self._session.send_message(messages.Yield(messages.YieldFields(msg.request_id)))
106+
elif isinstance(result, types.Result):
107+
data = self._session.send_message(
108+
messages.Yield(messages.YieldFields(msg.request_id, result.args, result.kwargs, result.details))
109+
)
110+
else:
111+
message = "Endpoint returned invalid result type. Expected types.Result or None, got: " + str(
112+
type(result)
113+
)
114+
msg_to_send = messages.Error(
115+
messages.ErrorFields(msg.TYPE, msg.request_id, xconn_uris.ERROR_INTERNAL_ERROR, [message])
116+
)
117+
data = self._session.send_message(msg_to_send)
118+
119+
self._base_session.send(data)
120+
except ApplicationError as e:
121+
msg_to_send = messages.Error(messages.ErrorFields(msg.TYPE, msg.request_id, e.message, e.args))
122+
data = self._session.send_message(msg_to_send)
123+
self._base_session.send(data)
124+
except Exception as e:
125+
msg_to_send = messages.Error(
126+
messages.ErrorFields(msg.TYPE, msg.request_id, xconn_uris.ERROR_RUNTIME_ERROR, [e.__str__()])
127+
)
128+
data = self._session.send_message(msg_to_send)
129+
self._base_session.send(data)
130+
131+
def _handle_event(self, msg: messages.Event, endpoint: Callable[[types.Event], None]):
132+
try:
133+
endpoint(types.Event(msg.args, msg.kwargs, msg.details))
134+
except Exception as e:
135+
print(e)
85136

86137
def _process_incoming_message(self, msg: messages.Message):
87138
if isinstance(msg, messages.Registered):
@@ -98,28 +149,7 @@ def _process_incoming_message(self, msg: messages.Message):
98149
elif isinstance(msg, messages.Invocation):
99150
try:
100151
endpoint = self._registrations[msg.registration_id]
101-
result = endpoint(types.Invocation(msg.args, msg.kwargs, msg.details))
102-
103-
if result is None:
104-
data = self._session.send_message(messages.Yield(messages.YieldFields(msg.request_id)))
105-
elif isinstance(result, types.Result):
106-
data = self._session.send_message(
107-
messages.Yield(messages.YieldFields(msg.request_id, result.args, result.kwargs, result.details))
108-
)
109-
else:
110-
message = "Endpoint returned invalid result type. Expected types.Result or None, got: " + str(
111-
type(result)
112-
)
113-
msg_to_send = messages.Error(
114-
messages.ErrorFields(msg.TYPE, msg.request_id, xconn_uris.ERROR_INTERNAL_ERROR, [message])
115-
)
116-
data = self._session.send_message(msg_to_send)
117-
118-
self._base_session.send(data)
119-
except ApplicationError as e:
120-
msg_to_send = messages.Error(messages.ErrorFields(msg.TYPE, msg.request_id, e.message, e.args))
121-
data = self._session.send_message(msg_to_send)
122-
self._base_session.send(data)
152+
self._executor.submit(self._handle_invocation, msg, endpoint)
123153
except Exception as e:
124154
msg_to_send = messages.Error(
125155
messages.ErrorFields(msg.TYPE, msg.request_id, xconn_uris.ERROR_RUNTIME_ERROR, [e.__str__()])
@@ -140,7 +170,7 @@ def _process_incoming_message(self, msg: messages.Message):
140170
elif isinstance(msg, messages.Event):
141171
try:
142172
endpoint = self._subscriptions[msg.subscription_id]
143-
endpoint(types.Event(msg.args, msg.kwargs, msg.details))
173+
self._executor.submit(self._handle_event, msg, endpoint)
144174
except Exception as e:
145175
print(e)
146176
elif isinstance(msg, messages.Error):
@@ -295,3 +325,12 @@ def ping(self, timeout: int = 10) -> float:
295325
def _on_disconnect(self, callback: Callable[[], None]) -> None:
296326
if callback is not None:
297327
self._disconnect_callback.append(callback)
328+
329+
def run_forever(self):
330+
"""Block until the session is closed/disconnected."""
331+
print("[Session] Running forever — press Ctrl+C to exit.")
332+
try:
333+
self._stopped.wait()
334+
except KeyboardInterrupt:
335+
print("[Session] Interrupted — shutting down...")
336+
self.leave()

0 commit comments

Comments
 (0)