|
2 | 2 |
|
3 | 3 | import collections.abc |
4 | 4 | import copy |
5 | | -import inspect |
6 | 5 | import logging |
7 | 6 | import pathlib |
8 | | -import platform |
9 | 7 | import sys |
10 | 8 | import warnings |
11 | | -from types import GenericAlias, MappingProxyType |
12 | | -from typing import Any, Callable, ClassVar, Dict, Generic, List, Optional, Tuple, Type, TypeVar, Union, get_args, get_origin |
| 9 | +from types import MappingProxyType |
| 10 | +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar |
13 | 11 |
|
14 | 12 | from omegaconf import DictConfig |
15 | | -from packaging import version |
16 | 13 | from pydantic import ( |
17 | 14 | BaseModel as PydanticBaseModel, |
18 | 15 | ConfigDict, |
@@ -89,66 +86,7 @@ def get_registry_dependencies(self, types: Optional[Tuple["ModelType"]] = None) |
89 | 86 | return deps |
90 | 87 |
|
91 | 88 |
|
92 | | -# Pydantic 2 has different handling of serialization. |
93 | | -# This requires some workarounds at the moment until the feature is added to easily get a mode that |
94 | | -# is compatible with Pydantic 1 |
95 | | -# This is done by adjusting annotations via a MetaClass for any annotation that includes a BaseModel, |
96 | | -# such that the new annotation contains SerializeAsAny |
97 | | -# https://docs.pydantic.dev/latest/concepts/serialization/#serializing-with-duck-typing |
98 | | -# https://github.com/pydantic/pydantic/issues/6423 |
99 | | -# https://github.com/pydantic/pydantic-core/pull/740 |
100 | | -# See https://github.com/pydantic/pydantic/issues/6381 for inspiration on implementation |
101 | | -# NOTE: For this logic to be removed, require https://github.com/pydantic/pydantic-core/pull/1478 |
102 | | -from pydantic._internal._model_construction import ModelMetaclass # noqa: E402 |
103 | | - |
104 | | -_IS_PY39 = version.parse(platform.python_version()) < version.parse("3.10") |
105 | | - |
106 | | - |
107 | | -def _adjust_annotations(annotation): |
108 | | - origin = get_origin(annotation) |
109 | | - args = get_args(annotation) |
110 | | - if not _IS_PY39: |
111 | | - from types import UnionType |
112 | | - |
113 | | - if origin is UnionType: |
114 | | - origin = Union |
115 | | - |
116 | | - if isinstance(annotation, GenericAlias) or (inspect.isclass(annotation) and issubclass(annotation, PydanticBaseModel)): |
117 | | - return SerializeAsAny[annotation] |
118 | | - elif origin and args: |
119 | | - # Filter out typing.Type and generic types |
120 | | - if origin is type or (inspect.isclass(origin) and issubclass(origin, Generic)): |
121 | | - return annotation |
122 | | - elif origin is ClassVar: # ClassVar doesn't accept a tuple of length 1 in py39 |
123 | | - return ClassVar[_adjust_annotations(args[0])] |
124 | | - else: |
125 | | - try: |
126 | | - return origin[tuple(_adjust_annotations(arg) for arg in args)] |
127 | | - except TypeError: |
128 | | - raise TypeError(f"Could not adjust annotations for {origin}") |
129 | | - else: |
130 | | - return annotation |
131 | | - |
132 | | - |
133 | | -class _SerializeAsAnyMeta(ModelMetaclass): |
134 | | - def __new__(self, name: str, bases: Tuple[type], namespaces: Dict[str, Any], **kwargs): |
135 | | - annotations: dict = namespaces.get("__annotations__", {}) |
136 | | - |
137 | | - for base in bases: |
138 | | - for base_ in base.__mro__: |
139 | | - if base_ is PydanticBaseModel: |
140 | | - annotations.update(base_.__annotations__) |
141 | | - |
142 | | - for field, annotation in annotations.items(): |
143 | | - if not field.startswith("__"): |
144 | | - annotations[field] = _adjust_annotations(annotation) |
145 | | - |
146 | | - namespaces["__annotations__"] = annotations |
147 | | - |
148 | | - return super().__new__(self, name, bases, namespaces, **kwargs) |
149 | | - |
150 | | - |
151 | | -class BaseModel(PydanticBaseModel, _RegistryMixin, metaclass=_SerializeAsAnyMeta): |
| 89 | +class BaseModel(PydanticBaseModel, _RegistryMixin): |
152 | 90 | """BaseModel is a base class for all pydantic models within the cubist flow framework. |
153 | 91 |
|
154 | 92 | This gives us a way to add functionality to the framework, including |
@@ -182,6 +120,17 @@ def type_(self) -> PyObjectPath: |
182 | 120 | ser_json_timedelta="float", |
183 | 121 | ) |
184 | 122 |
|
| 123 | + # https://docs.pydantic.dev/latest/concepts/serialization/#overriding-the-serialize_as_any-default-false |
| 124 | + def model_dump(self, **kwargs) -> dict[str, Any]: |
| 125 | + if not kwargs.get("serialize_as_any"): |
| 126 | + kwargs["serialize_as_any"] = True |
| 127 | + return super().model_dump(**kwargs) |
| 128 | + |
| 129 | + def model_dump_json(self, **kwargs) -> str: |
| 130 | + if not kwargs.get("serialize_as_any"): |
| 131 | + kwargs["serialize_as_any"] = True |
| 132 | + return super().model_dump_json(**kwargs) |
| 133 | + |
185 | 134 | def __str__(self): |
186 | 135 | # Because the standard string representation does not include class name |
187 | 136 | return repr(self) |
@@ -251,7 +200,7 @@ def _base_model_validator(cls, v, handler, info): |
251 | 200 |
|
252 | 201 | if isinstance(v, PydanticBaseModel): |
253 | 202 | # Coerce from one BaseModel type to another (because it worked automatically in v1) |
254 | | - v = v.model_dump(exclude={"type_"}) |
| 203 | + v = v.model_dump(serialize_as_any=True, exclude={"type_"}) |
255 | 204 |
|
256 | 205 | return handler(v) |
257 | 206 |
|
@@ -376,7 +325,8 @@ def _validate_name(cls, v): |
376 | 325 | @model_serializer(mode="wrap") |
377 | 326 | def _registry_serializer(self, handler): |
378 | 327 | values = handler(self) |
379 | | - values["models"] = self._models |
| 328 | + models_serialized = {k: model.model_dump(serialize_as_any=True, by_alias=True) for k, model in self._models.items()} |
| 329 | + values["models"] = models_serialized |
380 | 330 | return values |
381 | 331 |
|
382 | 332 | @property |
|
0 commit comments