66import re
77from collections .abc import Mapping
88from functools import reduce
9- from typing import Any , Literal , Optional , Union
9+ from typing import Annotated , Any , Literal , Optional , Union , get_args , get_origin
1010
1111import msgspec
1212import voluptuous
@@ -67,11 +67,12 @@ def validate_schema(schema, obj, msg_prefix):
6767 raise Exception (f"{ msg_prefix } \n { str (exc )} \n { pprint .pformat (obj )} " )
6868
6969
70- def UnionTypes (* types ):
71- """Use `functools.reduce` to simulate `Union[*allowed_types]` on older
72- Python versions.
73- """
74- return reduce (lambda a , b : Union [a , b ], types )
70+ class OptionallyKeyedBy :
71+ """Metadata class for optionally_keyed_by fields in msgspec schemas."""
72+
73+ def __init__ (self , * fields , wrapped_type ):
74+ self .fields = fields
75+ self .wrapped_type = wrapped_type
7576
7677
7778def optionally_keyed_by (* arguments , use_msgspec = False ):
@@ -83,13 +84,15 @@ def optionally_keyed_by(*arguments, use_msgspec=False):
8384 use_msgspec: If True, return msgspec type hints; if False, return voluptuous validator
8485 """
8586 if use_msgspec :
86- # msgspec implementation - return type hints
87+ # msgspec implementation - use Annotated[Any, OptionallyKeyedBy]
8788 _type = arguments [- 1 ]
8889 if _type is object :
8990 return object
9091 fields = arguments [:- 1 ]
91- bykeys = [Literal [f"by-{ field } " ] for field in fields ]
92- return Union [_type , dict [UnionTypes (* bykeys ), dict [str , Any ]]]
92+ wrapper = OptionallyKeyedBy (* fields , wrapped_type = _type )
93+ # Annotating Any allows msgspec to accept any value without validation.
94+ # The actual validation then happens in Schema.__post_init__
95+ return Annotated [Any , wrapper ]
9396 else :
9497 # voluptuous implementation - return validator function
9598 schema = arguments [- 1 ]
@@ -280,6 +283,13 @@ def __getitem__(self, item):
280283 return self .schema [item ] # type: ignore
281284
282285
286+ def UnionTypes (* types ):
287+ """Use `functools.reduce` to simulate `Union[*allowed_types]` on older
288+ Python versions.
289+ """
290+ return reduce (lambda a , b : Union [a , b ], types )
291+
292+
283293class Schema (
284294 msgspec .Struct ,
285295 kw_only = True ,
@@ -307,6 +317,48 @@ class MySchema(Schema, forbid_unknown_fields=False, kw_only=True):
307317 foo: str
308318 """
309319
320+ def __post_init__ (self ):
321+ if taskgraph .fast :
322+ return
323+
324+ # Validate fields that use optionally_keyed_by. We need to validate this
325+ # manually because msgspec doesn't support union types with multiple
326+ # dicts. Any fields that use `optionally_keyed_by("foo", dict)` would
327+ # otherwise raise an exception.
328+ for field_name , field_type in self .__class__ .__annotations__ .items ():
329+ origin = get_origin (field_type )
330+ args = get_args (field_type )
331+
332+ if (
333+ origin is not Annotated
334+ or len (args ) < 2
335+ or not isinstance (args [1 ], OptionallyKeyedBy )
336+ ):
337+ # Not using `optionally_keyed_by`
338+ continue
339+
340+ keyed_by = args [1 ]
341+ obj = getattr (self , field_name )
342+ if not isinstance (obj , dict ) or len (obj ) != 1 :
343+ # Not using keyed by, validate direclty against wrapped type
344+ msgspec .convert (obj , keyed_by .wrapped_type )
345+ continue
346+
347+ key , keyed_by_dict = next (iter (obj .items ()))
348+ if not key .startswith ("by-" ):
349+ # Not using keyed by, validate direclty against wrapped type
350+ msgspec .convert (obj , keyed_by .wrapped_type )
351+ continue
352+
353+ # First validate the outer keyed-by dict
354+ bykeys = UnionTypes (* [Literal [f"by-{ field } " ] for field in keyed_by .fields ])
355+ msgspec .convert (obj , dict [bykeys , dict [str , Any ]])
356+
357+ # Next validate each inner value against the wrapped type
358+ for pattern , value in keyed_by_dict .items ():
359+ msgspec .convert (pattern , str )
360+ msgspec .convert (value , keyed_by .wrapped_type )
361+
310362 @classmethod
311363 def validate (cls , data ):
312364 """Validate data against this schema."""
0 commit comments