|
2 | 2 |
|
3 | 3 | from concurrent.futures import Future |
4 | 4 | from threading import Thread |
5 | | -from typing import Callable, Any |
| 5 | +from typing import Callable, Any, TypeVar, Type |
6 | 6 | from dataclasses import dataclass |
7 | 7 |
|
8 | 8 | from wampproto import messages, session, uris |
9 | 9 |
|
10 | 10 | from xconn import types, exception, uris as xconn_uris |
| 11 | +from xconn.codec import Codec |
11 | 12 | from xconn.exception import ApplicationError |
12 | 13 | from xconn.helpers import exception_from_error, SessionScopeIDGenerator |
13 | 14 |
|
| 15 | +TReq = TypeVar("TReq") |
| 16 | +TRes = TypeVar("TRes") |
| 17 | + |
14 | 18 |
|
15 | 19 | @dataclass |
16 | 20 | class RegisterRequest: |
@@ -90,6 +94,8 @@ def __init__(self, base_session: types.BaseSession): |
90 | 94 |
|
91 | 95 | self._disconnect_callback: list[Callable[[], None] | None] = [] |
92 | 96 |
|
| 97 | + self._payload_codec: Codec = None |
| 98 | + |
93 | 99 | thread = Thread(target=self._wait, daemon=False) |
94 | 100 | thread.start() |
95 | 101 |
|
@@ -192,6 +198,18 @@ def _process_incoming_message(self, msg: messages.Message): |
192 | 198 | else: |
193 | 199 | raise ValueError("received unknown message") |
194 | 200 |
|
| 201 | + def set_payload_codec(self, codec: Codec) -> None: |
| 202 | + self._payload_codec = codec |
| 203 | + |
| 204 | + def call_object(self, procedure: str, request: TReq, return_type: Type[TRes] = None) -> TReq | None: |
| 205 | + if self._payload_codec is None: |
| 206 | + raise ValueError("no payload codec set") |
| 207 | + |
| 208 | + encoded = self._payload_codec.encode(request) |
| 209 | + result = self.call(procedure, [encoded]) |
| 210 | + |
| 211 | + return self._payload_codec.decode(result.args[0], return_type) |
| 212 | + |
195 | 213 | def call( |
196 | 214 | self, |
197 | 215 | procedure: str, |
|
0 commit comments