Skip to content

Commit 5367fd3

Browse files
committed
Fixes for serialize_as_any in pydantic
Signed-off-by: Nijat Khanbabayev <nijat.khanbabayev@cubistsystematic.com>
1 parent cc8f0ee commit 5367fd3

File tree

5 files changed

+144
-111
lines changed

5 files changed

+144
-111
lines changed

ccflow/base.py

Lines changed: 17 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,14 @@
22

33
import collections.abc
44
import copy
5-
import inspect
65
import logging
76
import pathlib
8-
import platform
97
import sys
108
import warnings
11-
from types import GenericAlias, MappingProxyType
12-
from typing import Any, Callable, ClassVar, Dict, Generic, List, Optional, Tuple, Type, TypeVar, Union, get_args, get_origin
9+
from types import MappingProxyType
10+
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar
1311

1412
from omegaconf import DictConfig
15-
from packaging import version
1613
from pydantic import (
1714
BaseModel as PydanticBaseModel,
1815
ConfigDict,
@@ -89,66 +86,7 @@ def get_registry_dependencies(self, types: Optional[Tuple["ModelType"]] = None)
8986
return deps
9087

9188

92-
# Pydantic 2 has different handling of serialization.
93-
# This requires some workarounds at the moment until the feature is added to easily get a mode that
94-
# is compatible with Pydantic 1
95-
# This is done by adjusting annotations via a MetaClass for any annotation that includes a BaseModel,
96-
# such that the new annotation contains SerializeAsAny
97-
# https://docs.pydantic.dev/latest/concepts/serialization/#serializing-with-duck-typing
98-
# https://github.com/pydantic/pydantic/issues/6423
99-
# https://github.com/pydantic/pydantic-core/pull/740
100-
# See https://github.com/pydantic/pydantic/issues/6381 for inspiration on implementation
101-
# NOTE: For this logic to be removed, require https://github.com/pydantic/pydantic-core/pull/1478
102-
from pydantic._internal._model_construction import ModelMetaclass # noqa: E402
103-
104-
_IS_PY39 = version.parse(platform.python_version()) < version.parse("3.10")
105-
106-
107-
def _adjust_annotations(annotation):
108-
origin = get_origin(annotation)
109-
args = get_args(annotation)
110-
if not _IS_PY39:
111-
from types import UnionType
112-
113-
if origin is UnionType:
114-
origin = Union
115-
116-
if isinstance(annotation, GenericAlias) or (inspect.isclass(annotation) and issubclass(annotation, PydanticBaseModel)):
117-
return SerializeAsAny[annotation]
118-
elif origin and args:
119-
# Filter out typing.Type and generic types
120-
if origin is type or (inspect.isclass(origin) and issubclass(origin, Generic)):
121-
return annotation
122-
elif origin is ClassVar: # ClassVar doesn't accept a tuple of length 1 in py39
123-
return ClassVar[_adjust_annotations(args[0])]
124-
else:
125-
try:
126-
return origin[tuple(_adjust_annotations(arg) for arg in args)]
127-
except TypeError:
128-
raise TypeError(f"Could not adjust annotations for {origin}")
129-
else:
130-
return annotation
131-
132-
133-
class _SerializeAsAnyMeta(ModelMetaclass):
134-
def __new__(self, name: str, bases: Tuple[type], namespaces: Dict[str, Any], **kwargs):
135-
annotations: dict = namespaces.get("__annotations__", {})
136-
137-
for base in bases:
138-
for base_ in base.__mro__:
139-
if base_ is PydanticBaseModel:
140-
annotations.update(base_.__annotations__)
141-
142-
for field, annotation in annotations.items():
143-
if not field.startswith("__"):
144-
annotations[field] = _adjust_annotations(annotation)
145-
146-
namespaces["__annotations__"] = annotations
147-
148-
return super().__new__(self, name, bases, namespaces, **kwargs)
149-
150-
151-
class BaseModel(PydanticBaseModel, _RegistryMixin, metaclass=_SerializeAsAnyMeta):
89+
class BaseModel(PydanticBaseModel, _RegistryMixin):
15290
"""BaseModel is a base class for all pydantic models within the cubist flow framework.
15391
15492
This gives us a way to add functionality to the framework, including
@@ -182,6 +120,17 @@ def type_(self) -> PyObjectPath:
182120
ser_json_timedelta="float",
183121
)
184122

123+
# https://docs.pydantic.dev/latest/concepts/serialization/#overriding-the-serialize_as_any-default-false
124+
def model_dump(self, **kwargs) -> dict[str, Any]:
125+
if not kwargs.get("serialize_as_any"):
126+
kwargs["serialize_as_any"] = True
127+
return super().model_dump(**kwargs)
128+
129+
def model_dump_json(self, **kwargs) -> str:
130+
if not kwargs.get("serialize_as_any"):
131+
kwargs["serialize_as_any"] = True
132+
return super().model_dump_json(**kwargs)
133+
185134
def __str__(self):
186135
# Because the standard string representation does not include class name
187136
return repr(self)
@@ -251,7 +200,7 @@ def _base_model_validator(cls, v, handler, info):
251200

252201
if isinstance(v, PydanticBaseModel):
253202
# Coerce from one BaseModel type to another (because it worked automatically in v1)
254-
v = v.model_dump(exclude={"type_"})
203+
v = v.model_dump(serialize_as_any=True, exclude={"type_"})
255204

256205
return handler(v)
257206

@@ -376,7 +325,8 @@ def _validate_name(cls, v):
376325
@model_serializer(mode="wrap")
377326
def _registry_serializer(self, handler):
378327
values = handler(self)
379-
values["models"] = self._models
328+
models_serialized = {k: model.model_dump(serialize_as_any=True, by_alias=True) for k, model in self._models.items()}
329+
values["models"] = models_serialized
380330
return values
381331

382332
@property

ccflow/callable.py

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,18 @@
1717
from inspect import Signature, isclass, signature
1818
from typing import Any, ClassVar, Dict, Generic, List, Optional, Tuple, Type, TypeVar
1919

20-
from pydantic import BaseModel as PydanticBaseModel, ConfigDict, Field, InstanceOf, PrivateAttr, TypeAdapter, field_validator, model_validator
20+
from pydantic import (
21+
BaseModel as PydanticBaseModel,
22+
ConfigDict,
23+
Field,
24+
InstanceOf,
25+
PrivateAttr,
26+
SerializerFunctionWrapHandler,
27+
TypeAdapter,
28+
field_validator,
29+
model_serializer,
30+
model_validator,
31+
)
2132
from typing_extensions import override
2233

2334
from .base import (
@@ -426,6 +437,36 @@ def __call__(self) -> ResultType:
426437
else:
427438
return fn(self.context)
428439

440+
# When serialize_as_any=True, pydantic may detect repeated object ids in nested graphs
441+
# (e.g., shared default lists) and raise a circular reference error during serialization.
442+
# For computing cache keys, fall back to a minimal, stable representation if such an error occurs.
443+
# This is similar to how the pydantic docs here:
444+
# https://docs.pydantic.dev/latest/concepts/forward_annotations/#cyclic-references
445+
# handle cyclic references during serialization.
446+
@model_serializer(mode="wrap")
447+
def _serialize_model_evaluation_context(self, handler: SerializerFunctionWrapHandler):
448+
try:
449+
return handler(self)
450+
except ValueError as exc:
451+
msg = str(exc)
452+
if "Circular reference" not in msg and "id repeated" not in msg:
453+
raise
454+
# Minimal, stable representation sufficient for cache-key tokenization
455+
try:
456+
model_repr = self.model.model_dump(mode="python", serialize_as_any=True, by_alias=True)
457+
except Exception:
458+
model_repr = repr(self.model)
459+
try:
460+
context_repr = self.context.model_dump(mode="python", serialize_as_any=True, by_alias=True)
461+
except Exception:
462+
context_repr = repr(self.context)
463+
return dict(
464+
fn=self.fn,
465+
model=model_repr,
466+
context=context_repr,
467+
options=dict(self.options),
468+
)
469+
429470

430471
class EvaluatorBase(_CallableModel, abc.ABC):
431472
"""Base class for evaluators, which are higher-order models that evaluate ModelAndContext.

ccflow/tests/test_base_serialize.py

Lines changed: 1 addition & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
11
import pickle
2-
import platform
32
import unittest
4-
from typing import Annotated, ClassVar, Dict, List, Optional, Type, Union
3+
from typing import Annotated, Optional
54

65
import numpy as np
7-
from packaging import version
86
from pydantic import BaseModel as PydanticBaseModel, ConfigDict, Field, ValidationError
97

108
from ccflow import BaseModel, NDArray
@@ -213,45 +211,6 @@ class C(PydanticBaseModel):
213211
# C implements the normal pydantic BaseModel which should allow extra fields.
214212
_ = C(extra_field1=1)
215213

216-
def test_serialize_as_any(self):
217-
# https://docs.pydantic.dev/latest/concepts/serialization/#serializing-with-duck-typing
218-
# https://github.com/pydantic/pydantic/issues/6423
219-
# This test could be removed once there is a different solution to the issue above
220-
from pydantic import SerializeAsAny
221-
from pydantic.types import constr
222-
223-
if version.parse(platform.python_version()) >= version.parse("3.10"):
224-
pipe_union = A | int
225-
else:
226-
pipe_union = Union[A, int]
227-
228-
class MyNestedModel(BaseModel):
229-
a1: A
230-
a2: Optional[Union[A, int]]
231-
a3: Dict[str, Optional[List[A]]]
232-
a4: ClassVar[A]
233-
a5: Type[A]
234-
a6: constr(min_length=1)
235-
a7: pipe_union
236-
237-
target = {
238-
"a1": SerializeAsAny[A],
239-
"a2": Optional[Union[SerializeAsAny[A], int]],
240-
"a4": ClassVar[SerializeAsAny[A]],
241-
"a5": Type[A],
242-
"a6": constr(min_length=1), # Uses Annotation
243-
"a7": Union[SerializeAsAny[A], int],
244-
}
245-
target["a3"] = dict[str, Optional[list[SerializeAsAny[A]]]]
246-
annotations = MyNestedModel.__annotations__
247-
self.assertEqual(str(annotations["a1"]), str(target["a1"]))
248-
self.assertEqual(str(annotations["a2"]), str(target["a2"]))
249-
self.assertEqual(str(annotations["a3"]), str(target["a3"]))
250-
self.assertEqual(str(annotations["a4"]), str(target["a4"]))
251-
self.assertEqual(str(annotations["a5"]), str(target["a5"]))
252-
self.assertEqual(str(annotations["a6"]), str(target["a6"]))
253-
self.assertEqual(str(annotations["a7"]), str(target["a7"]))
254-
255214
def test_pickle_consistency(self):
256215
model = MultiAttributeModel(z=1, y="test", x=3.14, w=True)
257216
serialized = pickle.dumps(model)
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
import json
2+
from datetime import date
3+
4+
from ccflow import DateContext
5+
from ccflow.callable import ModelEvaluationContext
6+
from ccflow.evaluators import GraphEvaluator, LoggingEvaluator, MultiEvaluator
7+
from ccflow.tests.evaluators.util import NodeModel
8+
9+
# NOTE: for these tests, round-tripping via JSON does not work
10+
# because the ModelEvaluationContext just has an InstanceOf validation check
11+
# and so we do not actually construct a full MEC on load.
12+
13+
14+
def _make_nested_mec(model):
15+
ctx = DateContext(date=date(2022, 1, 1))
16+
mec = model.__call__.get_evaluation_context(model, ctx)
17+
assert isinstance(mec, ModelEvaluationContext)
18+
# ensure nested: outer model is an evaluator, inner is a ModelEvaluationContext
19+
assert isinstance(mec.context, ModelEvaluationContext)
20+
return mec
21+
22+
23+
def test_mec_model_dump_basic():
24+
m = NodeModel()
25+
mec = _make_nested_mec(m)
26+
27+
d = mec.model_dump()
28+
assert isinstance(d, dict)
29+
assert "fn" in d and "model" in d and "context" in d and "options" in d
30+
31+
s = mec.model_dump_json()
32+
parsed = json.loads(s)
33+
assert parsed["fn"] == d["fn"]
34+
# Also verify mode-specific dumps
35+
d_py = mec.model_dump(mode="python")
36+
assert isinstance(d_py, dict)
37+
d_json = mec.model_dump(mode="json")
38+
assert isinstance(d_json, dict)
39+
json.dumps(d_json)
40+
41+
42+
def test_mec_model_dump_diamond_graph():
43+
n0 = NodeModel()
44+
n1 = NodeModel(deps_model=[n0])
45+
n2 = NodeModel(deps_model=[n0])
46+
root = NodeModel(deps_model=[n1, n2])
47+
48+
mec = _make_nested_mec(root)
49+
50+
d = mec.model_dump()
51+
assert isinstance(d, dict)
52+
assert set(["fn", "model", "context", "options"]).issubset(d.keys())
53+
54+
s = mec.model_dump_json()
55+
json.loads(s)
56+
# verify mode dumps
57+
d_py = mec.model_dump(mode="python")
58+
assert isinstance(d_py, dict)
59+
d_json = mec.model_dump(mode="json")
60+
assert isinstance(d_json, dict)
61+
json.dumps(d_json)
62+
63+
64+
def test_mec_model_dump_with_multi_evaluator():
65+
m = NodeModel()
66+
_ = LoggingEvaluator() # ensure import/validation
67+
evaluator = MultiEvaluator(evaluators=[LoggingEvaluator(), GraphEvaluator()])
68+
69+
# Simulate how Flow builds evaluation context with a custom evaluator
70+
ctx = DateContext(date=date(2022, 1, 1))
71+
mec = ModelEvaluationContext(model=evaluator, context=m.__call__.get_evaluation_context(m, ctx))
72+
73+
d = mec.model_dump()
74+
assert isinstance(d, dict)
75+
assert "fn" in d and "model" in d and "context" in d
76+
s = mec.model_dump_json()
77+
json.loads(s)
78+
# verify mode dumps
79+
d_py = mec.model_dump(mode="python")
80+
assert isinstance(d_py, dict)
81+
d_json = mec.model_dump(mode="json")
82+
assert isinstance(d_json, dict)
83+
json.dumps(d_json)

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ dependencies = [
4444
"orjson",
4545
"pandas",
4646
"pyarrow",
47-
"pydantic>=2.6,<3",
47+
"pydantic>=2.35,<3",
4848
"smart_open",
4949
"tenacity",
5050
]

0 commit comments

Comments
 (0)