|
32 | 32 |
|
33 | 33 | logger = logging.getLogger(__name__) |
34 | 34 |
|
| 35 | + |
| 36 | +class _DictSchemaAdapter: |
| 37 | + """Adapts a plain JSON Schema dict to the Pydantic model class interface. |
| 38 | +
|
| 39 | + Allows modules that define ``input_schema`` / ``output_schema`` as raw |
| 40 | + dicts to work transparently with the executor, schema exporter, and any |
| 41 | + other code that calls ``model_validate``, ``model_json_schema``, or |
| 42 | + ``model_rebuild`` on a schema object. |
| 43 | +
|
| 44 | + Note: ``model_validate`` is a pass-through — no JSON Schema validation is |
| 45 | + performed. Adding real validation would require a ``jsonschema`` dependency |
| 46 | + which is not currently declared. Modules that need strict input checking |
| 47 | + should use Pydantic model classes or validate inside ``execute()``. |
| 48 | + """ |
| 49 | + |
| 50 | + def __init__(self, schema: dict[str, Any]) -> None: |
| 51 | + self._schema = schema |
| 52 | + |
| 53 | + def model_json_schema(self) -> dict[str, Any]: |
| 54 | + return self._schema |
| 55 | + |
| 56 | + def model_validate(self, data: Any) -> Any: |
| 57 | + """Pass-through: returns *data* unchanged (no validation).""" |
| 58 | + return data |
| 59 | + |
| 60 | + def model_rebuild(self) -> None: |
| 61 | + pass |
| 62 | + |
| 63 | + |
| 64 | +def _ensure_schema_adapter(module: Any) -> None: |
| 65 | + """Wrap raw dict schemas on *module* with ``_DictSchemaAdapter`` in-place.""" |
| 66 | + for attr in ("input_schema", "output_schema"): |
| 67 | + value = getattr(module, attr, None) |
| 68 | + if isinstance(value, dict): |
| 69 | + setattr(module, attr, _DictSchemaAdapter(value)) |
| 70 | + |
| 71 | + |
35 | 72 | REGISTRY_EVENTS: dict[str, str] = { |
36 | 73 | "REGISTER": "register", |
37 | 74 | "UNREGISTER": "unregister", |
@@ -400,6 +437,8 @@ def register( |
400 | 437 | if len(module_id) > MAX_MODULE_ID_LENGTH: |
401 | 438 | raise InvalidInputError(f"Module ID exceeds maximum length of {MAX_MODULE_ID_LENGTH}: {len(module_id)}") |
402 | 439 |
|
| 440 | + _ensure_schema_adapter(module) |
| 441 | + |
403 | 442 | effective_version = version or getattr(module, "version", None) or "0.0.0" |
404 | 443 |
|
405 | 444 | is_versioned = version is not None |
@@ -974,6 +1013,7 @@ def register_internal(self, module_id: str, module: Any) -> None: |
974 | 1013 |
|
975 | 1014 | Used by sys modules that use the reserved 'system.' prefix. |
976 | 1015 | """ |
| 1016 | + _ensure_schema_adapter(module) |
977 | 1017 | with self._lock: |
978 | 1018 | self._modules[module_id] = module |
979 | 1019 | self._lowercase_map[module_id.lower()] = module_id |
|
0 commit comments