Skip to content

Commit 1226ac0

Browse files
committed
fix: support optionally_keyed_by with underlying dict
1 parent e23b624 commit 1226ac0

2 files changed

Lines changed: 70 additions & 10 deletions

File tree

src/taskgraph/util/schema.py

Lines changed: 61 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import re
77
from collections.abc import Mapping
88
from functools import reduce
9-
from typing import Any, Literal, Optional, Union
9+
from typing import Annotated, Any, Literal, Optional, Union, get_args, get_origin
1010

1111
import msgspec
1212
import voluptuous
@@ -67,11 +67,12 @@ def validate_schema(schema, obj, msg_prefix):
6767
raise Exception(f"{msg_prefix}\n{str(exc)}\n{pprint.pformat(obj)}")
6868

6969

70-
def UnionTypes(*types):
71-
"""Use `functools.reduce` to simulate `Union[*allowed_types]` on older
72-
Python versions.
73-
"""
74-
return reduce(lambda a, b: Union[a, b], types)
70+
class OptionallyKeyedBy:
    """Metadata marker for ``optionally_keyed_by`` fields in msgspec schemas.

    An instance is attached to a field annotation as ``Annotated`` metadata.
    It records which ``by-<field>`` keys are permitted and the type the
    un-keyed value (or each keyed alternative) must satisfy.

    Attributes are documented inline in ``__init__``.
    """

    def __init__(self, *fields: str, wrapped_type: Any) -> None:
        # Names of the fields the value may be keyed by (without the "by-"
        # prefix), kept as the tuple captured from the positional arguments.
        self.fields = fields
        # The type each final value is validated against.
        self.wrapped_type = wrapped_type

    def __repr__(self) -> str:
        # A readable repr makes the marker identifiable when it shows up
        # inside ``Annotated[...]`` reprs or debugging output.
        parts = [repr(field) for field in self.fields]
        parts.append(f"wrapped_type={self.wrapped_type!r}")
        return f"{type(self).__name__}({', '.join(parts)})"
7576

7677

7778
def optionally_keyed_by(*arguments, use_msgspec=False):
@@ -83,13 +84,15 @@ def optionally_keyed_by(*arguments, use_msgspec=False):
8384
use_msgspec: If True, return msgspec type hints; if False, return voluptuous validator
8485
"""
8586
if use_msgspec:
86-
# msgspec implementation - return type hints
87+
# msgspec implementation - use Annotated[Any, OptionallyKeyedBy]
8788
_type = arguments[-1]
8889
if _type is object:
8990
return object
9091
fields = arguments[:-1]
91-
bykeys = [Literal[f"by-{field}"] for field in fields]
92-
return Union[_type, dict[UnionTypes(*bykeys), dict[str, Any]]]
92+
wrapper = OptionallyKeyedBy(*fields, wrapped_type=_type)
93+
# Annotating Any allows msgspec to accept any value without validation.
94+
# The actual validation then happens in Schema.__post_init__
95+
return Annotated[Any, wrapper]
9396
else:
9497
# voluptuous implementation - return validator function
9598
schema = arguments[-1]
@@ -280,6 +283,13 @@ def __getitem__(self, item):
280283
return self.schema[item] # type: ignore
281284

282285

286+
def UnionTypes(*types):
    """Return the ``Union`` of all given types.

    Equivalent to ``Union[*types]``, spelled so that it also works on Python
    versions that do not allow star-unpacking inside a subscript.

    Args:
        *types: One or more types (or typing constructs such as ``Literal``).

    Returns:
        ``Union[types[0], types[1], ...]``; a single type is returned as-is.

    Raises:
        TypeError: If no types are given.
    """
    if not types:
        raise TypeError("UnionTypes() requires at least one type")
    # Subscripting ``Union`` with a tuple is the same as listing the members
    # individually, so no pairwise reduction is needed.
    return Union[tuple(types)]
291+
292+
283293
class Schema(
284294
msgspec.Struct,
285295
kw_only=True,
@@ -307,6 +317,48 @@ class MySchema(Schema, forbid_unknown_fields=False, kw_only=True):
307317
foo: str
308318
"""
309319

320+
def __post_init__(self):
    """Validate every field declared with ``optionally_keyed_by``.

    Such fields are annotated as ``Annotated[Any, OptionallyKeyedBy(...)]``,
    so msgspec itself accepts any value for them. The real check happens
    here: the value must either match the wrapped type directly, or be a
    single-entry ``{"by-<field>": {pattern: value, ...}}`` dict whose key
    names one of the allowed fields and whose inner values each match the
    wrapped type.

    Validation is skipped entirely when ``taskgraph.fast`` is set.

    Raises:
        msgspec.ValidationError: If a field's value matches neither form.
    """
    if taskgraph.fast:
        return

    # We need to validate this manually because msgspec doesn't support
    # union types with multiple dicts. Any fields that use
    # `optionally_keyed_by("foo", dict)` would otherwise raise an exception.
    for field_name, field_type in self.__class__.__annotations__.items():
        origin = get_origin(field_type)
        args = get_args(field_type)

        if (
            origin is not Annotated
            or len(args) < 2
            or not isinstance(args[1], OptionallyKeyedBy)
        ):
            # Not using `optionally_keyed_by`
            continue

        keyed_by = args[1]
        obj = getattr(self, field_name)

        # Only a single-entry dict can be a keyed-by wrapper.
        key = next(iter(obj)) if isinstance(obj, dict) and len(obj) == 1 else None

        # The key must be checked to be a string before calling
        # `startswith`; input such as `{1: "a"}` would otherwise raise
        # AttributeError instead of a validation error.
        if not (isinstance(key, str) and key.startswith("by-")):
            # Not using keyed by, validate directly against wrapped type
            msgspec.convert(obj, keyed_by.wrapped_type)
            continue

        # First validate the outer keyed-by dict
        bykeys = UnionTypes(*[Literal[f"by-{field}"] for field in keyed_by.fields])
        msgspec.convert(obj, dict[bykeys, dict[str, Any]])

        # Next validate each inner value against the wrapped type.
        # NOTE(review): inner values are checked directly against the
        # wrapped type, so a nested `by-*` dict is only accepted if the
        # wrapped type itself allows it -- confirm this is intended.
        for pattern, value in obj[key].items():
            msgspec.convert(pattern, str)
            msgspec.convert(value, keyed_by.wrapped_type)
361+
310362
@classmethod
311363
def validate(cls, data):
312364
"""Validate data against this schema."""

test/test_util_schema.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,9 @@ class TestSchema(Schema):
290290
TestSchema.validate({"field": "baz"})
291291
TestSchema.validate({"field": {"by-foo": {"a": "b", "c": "d"}}})
292292

293+
with pytest.raises(msgspec.ValidationError):
294+
TestSchema.validate({"field": 1})
295+
293296
with pytest.raises(msgspec.ValidationError):
294297
TestSchema.validate({"field": {"by-bar": "a"}})
295298

@@ -299,6 +302,9 @@ class TestSchema(Schema):
299302
with pytest.raises(msgspec.ValidationError):
300303
TestSchema.validate({"field": {"by-bar": {"a": "b"}}})
301304

305+
with pytest.raises(msgspec.ValidationError):
306+
TestSchema.validate({"field": {"by-foo": {"a": 1, "c": "d"}}})
307+
302308

303309
def test_optionally_keyed_by_mulitple_keys():
304310
class TestSchema(Schema):
@@ -323,14 +329,16 @@ def test_optionally_keyed_by_object_passthrough():
323329
assert msgspec.convert({"arbitrary": "dict"}, typ) == {"arbitrary": "dict"}
324330

325331

326-
@pytest.mark.xfail
327332
def test_optionally_keyed_by_dict():
328333
class TestSchema(Schema):
329334
field: optionally_keyed_by("foo", dict[str, str], use_msgspec=True) # type: ignore
330335

331336
TestSchema.validate({"field": {"by-foo": {"a": {"x": "y"}}}})
332337
TestSchema.validate({"field": {"a": "b"}})
333338

339+
with pytest.raises(msgspec.ValidationError):
340+
TestSchema.validate({"field": {"a": 1}})
341+
334342
with pytest.raises(msgspec.ValidationError):
335343
TestSchema.validate({"field": {"by-foo": {"a": {"x": 1}}}})
336344

0 commit comments

Comments
 (0)