Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 54 additions & 39 deletions xconn/async_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,49 @@ def __init__(self, base_session: types.IAsyncBaseSession):
self._loop = get_event_loop()
self.wait_task = self._loop.create_task(self._wait())

async def _handle_invocation(
self,
msg: messages.Invocation,
endpoint: Union[
Callable[[types.Invocation], types.Result], Callable[[types.Invocation], Awaitable[types.Result]]
],
):
try:
result = await endpoint(types.Invocation(msg.args, msg.kwargs, msg.details))

if result is None:
data = self._session.send_message(messages.Yield(messages.YieldFields(msg.request_id)))
elif isinstance(result, types.Result):
data = self._session.send_message(
messages.Yield(messages.YieldFields(msg.request_id, result.args, result.kwargs, result.details))
)
else:
message = "Endpoint returned invalid result type. Expected types.Result or None, got: " + str(
type(result)
)
msg_to_send = messages.Error(
messages.ErrorFields(msg.TYPE, msg.request_id, xconn_uris.ERROR_INTERNAL_ERROR, [message])
)
data = self._session.send_message(msg_to_send)

await self._base_session.send(data)
except ApplicationError as e:
msg_to_send = messages.Error(messages.ErrorFields(msg.TYPE, msg.request_id, e.message, e.args))
data = self._session.send_message(msg_to_send)
await self._base_session.send(data)
except Exception as e:
msg_to_send = messages.Error(
messages.ErrorFields(msg.TYPE, msg.request_id, xconn_uris.ERROR_RUNTIME_ERROR, [e.__str__()])
)
data = self._session.send_message(msg_to_send)
await self._base_session.send(data)

async def _handle_event(self, msg: messages.Event, endpoint: Callable[[types.Event], Awaitable[None]]):
try:
await endpoint(types.Event(msg.args, msg.kwargs, msg.details))
except Exception as e:
print(e)

async def _wait(self):
while await self._base_session.transport.is_connected():
try:
Expand All @@ -84,12 +127,11 @@ async def _wait(self):
print(e)
break

task = self._loop.create_task(self._process_incoming_message(self._session.receive(data)))
self._tasks.add(task)
task.add_done_callback(self._tasks.discard)
await self._process_incoming_message(self._session.receive(data))

for callback in self._disconnect_callback:
await callback()
if self._disconnect_callback:
callbacks = [callback() for callback in self._disconnect_callback]
await asyncio.gather(*callbacks)

async def _process_incoming_message(self, msg: messages.Message):
if isinstance(msg, messages.Registered):
Expand All @@ -104,36 +146,10 @@ async def _process_incoming_message(self, msg: messages.Message):
request = self._call_requests.pop(msg.request_id)
request.set_result(types.Result(msg.args, msg.kwargs, msg.details))
elif isinstance(msg, messages.Invocation):
try:
endpoint = self._registrations[msg.registration_id]
result = await endpoint(types.Invocation(msg.args, msg.kwargs, msg.details))

if result is None:
data = self._session.send_message(messages.Yield(messages.YieldFields(msg.request_id)))
elif isinstance(result, types.Result):
data = self._session.send_message(
messages.Yield(messages.YieldFields(msg.request_id, result.args, result.kwargs, result.details))
)
else:
message = "Endpoint returned invalid result type. Expected types.Result or None, got: " + str(
type(result)
)
msg_to_send = messages.Error(
messages.ErrorFields(msg.TYPE, msg.request_id, xconn_uris.ERROR_INTERNAL_ERROR, [message])
)
data = self._session.send_message(msg_to_send)

await self._base_session.send(data)
except ApplicationError as e:
msg_to_send = messages.Error(messages.ErrorFields(msg.TYPE, msg.request_id, e.message, e.args))
data = self._session.send_message(msg_to_send)
await self._base_session.send(data)
except Exception as e:
msg_to_send = messages.Error(
messages.ErrorFields(msg.TYPE, msg.request_id, xconn_uris.ERROR_RUNTIME_ERROR, [e.__str__()])
)
data = self._session.send_message(msg_to_send)
await self._base_session.send(data)
endpoint = self._registrations[msg.registration_id]
task = self._loop.create_task(self._handle_invocation(msg, endpoint))
self._tasks.add(task)
task.add_done_callback(self._tasks.discard)
elif isinstance(msg, messages.Subscribed):
request = self._subscribe_requests.pop(msg.request_id)
self._subscriptions[msg.subscription_id] = request.endpoint
Expand All @@ -147,10 +163,9 @@ async def _process_incoming_message(self, msg: messages.Message):
request.set_result(None)
elif isinstance(msg, messages.Event):
endpoint = self._subscriptions[msg.subscription_id]
try:
await endpoint(types.Event(msg.args, msg.kwargs, msg.details))
except Exception as e:
print(e)
task = self._loop.create_task(self._handle_event(msg, endpoint))
self._tasks.add(task)
task.add_done_callback(self._tasks.discard)
elif isinstance(msg, messages.Error):
match msg.message_type:
case messages.Call.TYPE:
Expand Down
95 changes: 67 additions & 28 deletions xconn/session.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from __future__ import annotations

from concurrent.futures import Future
from threading import Thread
from concurrent.futures import Future, ThreadPoolExecutor, wait
import threading
from os import cpu_count
from typing import Callable, Any
from dataclasses import dataclass

Expand Down Expand Up @@ -67,8 +68,12 @@ def __init__(self, base_session: types.BaseSession):
self._session = session.WAMPSession(base_session.serializer)

self._disconnect_callback: list[Callable[[], None] | None] = []
self._stopped = threading.Event()

thread = Thread(target=self._wait, daemon=False)
# callback executor thread-pool
self._executor = ThreadPoolExecutor(max_workers=(cpu_count() or 1) * 4)

thread = threading.Thread(target=self._wait, daemon=True)
thread.start()

def _wait(self):
Expand All @@ -80,8 +85,54 @@ def _wait(self):

self._process_incoming_message(self._session.receive(data))

for callback in self._disconnect_callback:
callback()
# Shut down executor, cancelling anything still running
self._executor.shutdown(cancel_futures=True, wait=False)

if self._disconnect_callback:
with ThreadPoolExecutor(max_workers=len(self._disconnect_callback)) as executor:
# Trigger disconnect callbacks concurrently
futures = [executor.submit(cb) for cb in self._disconnect_callback]
# Wait up to 1 second for them to finish
wait(futures, timeout=1)

self._stopped.set()

def _handle_invocation(self, msg: messages.Invocation, endpoint: Callable[[types.Invocation], types.Result]):
try:
result = endpoint(types.Invocation(msg.args, msg.kwargs, msg.details))

if result is None:
data = self._session.send_message(messages.Yield(messages.YieldFields(msg.request_id)))
elif isinstance(result, types.Result):
data = self._session.send_message(
messages.Yield(messages.YieldFields(msg.request_id, result.args, result.kwargs, result.details))
)
else:
message = "Endpoint returned invalid result type. Expected types.Result or None, got: " + str(
type(result)
)
msg_to_send = messages.Error(
messages.ErrorFields(msg.TYPE, msg.request_id, xconn_uris.ERROR_INTERNAL_ERROR, [message])
)
data = self._session.send_message(msg_to_send)

self._base_session.send(data)
except ApplicationError as e:
msg_to_send = messages.Error(messages.ErrorFields(msg.TYPE, msg.request_id, e.message, e.args))
data = self._session.send_message(msg_to_send)
self._base_session.send(data)
except Exception as e:
msg_to_send = messages.Error(
messages.ErrorFields(msg.TYPE, msg.request_id, xconn_uris.ERROR_RUNTIME_ERROR, [e.__str__()])
)
data = self._session.send_message(msg_to_send)
self._base_session.send(data)

def _handle_event(self, msg: messages.Event, endpoint: Callable[[types.Event], None]):
try:
endpoint(types.Event(msg.args, msg.kwargs, msg.details))
except Exception as e:
print(e)

def _process_incoming_message(self, msg: messages.Message):
if isinstance(msg, messages.Registered):
Expand All @@ -98,28 +149,7 @@ def _process_incoming_message(self, msg: messages.Message):
elif isinstance(msg, messages.Invocation):
try:
endpoint = self._registrations[msg.registration_id]
result = endpoint(types.Invocation(msg.args, msg.kwargs, msg.details))

if result is None:
data = self._session.send_message(messages.Yield(messages.YieldFields(msg.request_id)))
elif isinstance(result, types.Result):
data = self._session.send_message(
messages.Yield(messages.YieldFields(msg.request_id, result.args, result.kwargs, result.details))
)
else:
message = "Endpoint returned invalid result type. Expected types.Result or None, got: " + str(
type(result)
)
msg_to_send = messages.Error(
messages.ErrorFields(msg.TYPE, msg.request_id, xconn_uris.ERROR_INTERNAL_ERROR, [message])
)
data = self._session.send_message(msg_to_send)

self._base_session.send(data)
except ApplicationError as e:
msg_to_send = messages.Error(messages.ErrorFields(msg.TYPE, msg.request_id, e.message, e.args))
data = self._session.send_message(msg_to_send)
self._base_session.send(data)
self._executor.submit(self._handle_invocation, msg, endpoint)
except Exception as e:
msg_to_send = messages.Error(
messages.ErrorFields(msg.TYPE, msg.request_id, xconn_uris.ERROR_RUNTIME_ERROR, [e.__str__()])
Expand All @@ -140,7 +170,7 @@ def _process_incoming_message(self, msg: messages.Message):
elif isinstance(msg, messages.Event):
try:
endpoint = self._subscriptions[msg.subscription_id]
endpoint(types.Event(msg.args, msg.kwargs, msg.details))
self._executor.submit(self._handle_event, msg, endpoint)
except Exception as e:
print(e)
elif isinstance(msg, messages.Error):
Expand Down Expand Up @@ -295,3 +325,12 @@ def ping(self, timeout: int = 10) -> float:
def _on_disconnect(self, callback: Callable[[], None]) -> None:
if callback is not None:
self._disconnect_callback.append(callback)

def run_forever(self):
"""Block until the session is closed/disconnected."""
print("[Session] Running forever — press Ctrl+C to exit.")
try:
self._stopped.wait()
except KeyboardInterrupt:
print("[Session] Interrupted — shutting down...")
self.leave()
Loading