|
12 | 12 | from codex.protocol import types as protocol |
13 | 13 |
|
14 | 14 | type RequestHandler[RequestT: BaseModel] = Callable[[RequestT], object | Awaitable[object]] |
15 | | -Notification = BaseModel |
| 15 | +type Notification = protocol.ServerNotificationValue | GenericNotification |
| 16 | +type ServerRequest = protocol.ServerRequestValue | GenericServerRequest |
16 | 17 |
|
17 | 18 |
|
18 | 19 | def method_name(message: BaseModel) -> str: |
@@ -123,7 +124,7 @@ def parse_notification(message: JsonObject, *, strict: bool) -> Notification: |
123 | 124 | raise AppServerProtocolError(_notification_error_message(message)) from exc |
124 | 125 |
|
125 | 126 |
|
126 | | -def parse_server_request(message: JsonObject, *, strict: bool) -> BaseModel: |
| 127 | +def parse_server_request(message: JsonObject, *, strict: bool) -> ServerRequest: |
127 | 128 | method = message.get("method") |
128 | 129 | try: |
129 | 130 | return protocol.ServerRequest.model_validate(message).root |
@@ -151,16 +152,21 @@ def _build_known_methods(*, root_model: type[BaseModel]) -> frozenset[str]: |
151 | 152 | root_field = getattr(root_model, "model_fields", {}).get("root") |
152 | 153 | if root_field is None: |
153 | 154 | return frozenset() |
| 155 | + annotation = _unwrap_type_alias(root_field.annotation) |
154 | 156 | methods = { |
155 | 157 | method |
156 | | - for candidate in get_args(root_field.annotation) |
| 158 | + for candidate in get_args(annotation) |
157 | 159 | if isinstance(candidate, type) and issubclass(candidate, BaseModel) |
158 | 160 | for method in [_candidate_method_literal(candidate)] |
159 | 161 | if method is not None |
160 | 162 | } |
161 | 163 | return frozenset(methods) |
162 | 164 |
|
163 | 165 |
|
| 166 | +def _unwrap_type_alias(annotation: object) -> object: |
| 167 | + return getattr(annotation, "__value__", annotation) |
| 168 | + |
| 169 | + |
164 | 170 | def _candidate_method_literal(candidate: type[BaseModel]) -> str | None: |
165 | 171 | model_fields = getattr(candidate, "model_fields", None) |
166 | 172 | if not isinstance(model_fields, dict) or "method" not in model_fields: |
|
0 commit comments