Skip to content
3 changes: 3 additions & 0 deletions ccflow/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -829,6 +829,9 @@ class ContextBase(ResultBase):

@model_validator(mode="wrap")
def _context_validator(cls, v, handler, info):
if v is None:
return handler({})

# Add deepcopy for v2 because it doesn't support copy_on_model_validation
v = copy.deepcopy(v)

Expand Down
11 changes: 2 additions & 9 deletions ccflow/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,15 +90,8 @@

_SEPARATOR = ","


class NullContext(ContextBase):
"""A Null Context that is used when no context is provided."""

@model_validator(mode="wrap")
def _validate_none(cls, v, handler, info):
v = v or {}
return handler(v)

# Starting 0.8.0 Nullcontext is an alias to ContextBase
NullContext = ContextBase

C = TypeVar("C", bound=Hashable)

Expand Down
12 changes: 8 additions & 4 deletions ccflow/tests/test_callable.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ class MyExtendedContext(MyContext):
c: bool


class MyOtherContext(ContextBase):
a: int


class ListContext(ContextBase):
ll: List[str] = []

Expand Down Expand Up @@ -152,7 +156,7 @@ class BadModelMismatchedContextAndCall(CallableModel):

@property
def context_type(self):
return NullContext
return MyOtherContext

@property
def result_type(self):
Expand All @@ -163,7 +167,7 @@ def __call__(self, context: MyContext) -> MyResult:
return context


class BadModelGenericMismatchedContextAndCall(CallableModelGenericType[NullContext, MyResult]):
class BadModelGenericMismatchedContextAndCall(CallableModelGenericType[MyOtherContext, MyResult]):
"""Model with mismatched context_type and __call__ annotation"""

@Flow.call
Expand Down Expand Up @@ -460,7 +464,7 @@ def test_types(self):
error = "__call__ method must take a single argument, named 'context'"
self.assertRaisesRegex(ValueError, error, BadModelDoubleContextArg)

error = "The context_type <class 'ccflow.context.NullContext'> must match the type of the context accepted by __call__ <class 'ccflow.tests.test_callable.MyContext'>"
error = "The context_type <class 'ccflow.tests.test_callable.MyOtherContext'> must match the type of the context accepted by __call__ <class 'ccflow.tests.test_callable.MyContext'>"
self.assertRaisesRegex(ValueError, error, BadModelMismatchedContextAndCall)

error = "The result_type <class 'ccflow.result.generic.GenericResult'> must match the return type of __call__ <class 'ccflow.tests.test_callable.MyResult'>"
Expand Down Expand Up @@ -642,7 +646,7 @@ def __call__(self, context: NullContext) -> GenericResult[float]:
MyCallable()

def test_types_generic(self):
error = "Context type annotation <class 'ccflow.tests.test_callable.MyContext'> on __call__ does not match context_type <class 'ccflow.context.NullContext'> defined by CallableModelGenericType"
error = "Context type annotation <class 'ccflow.tests.test_callable.MyContext'> on __call__ does not match context_type <class 'ccflow.tests.test_callable.MyOtherContext'> defined by CallableModelGenericType"
self.assertRaisesRegex(TypeError, error, BadModelGenericMismatchedContextAndCall)

error = "Return type annotation <class 'ccflow.tests.test_callable.MyResult'> on __call__ does not match result_type <class 'ccflow.result.generic.GenericResult'> defined by CallableModelGenericType"
Expand Down
21 changes: 19 additions & 2 deletions ccflow/tests/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,11 @@
from ccflow.result import GenericResult


class MyDefaultContext(ContextBase):
b: float = 3.14
c: bool = False


class TestContexts(TestCase):
def test_null_context(self):
n1 = NullContext()
Expand All @@ -38,11 +43,25 @@ def test_null_context(self):
self.assertEqual(hash(n1), hash(n2))

def test_null_context_validation(self):
# Context creation is based on two main assumptions:
# 1. If there is enough information to create a context, it should be created.
# 2. Since NullContext has no required fields, it can be created from None,
# empty containers ({} or []), or any other context.
self.assertEqual(NullContext.model_validate([]), NullContext())
self.assertEqual(NullContext.model_validate({}), NullContext())
self.assertEqual(NullContext.model_validate(None), NullContext())
self.assertIsInstance(NullContext.model_validate(DateContext(date="0d")), NullContext)
self.assertRaises(ValueError, NullContext.model_validate, [True])

def test_context_with_defaults(self):
# Contexts may define default values. Extending the assumptions above:
# Any context inherits the behavior from NullContext, and can be
# created as long as all required fields (if any) are satisfied.
self.assertEqual(TypeAdapter(MyDefaultContext).validate_python(None), MyDefaultContext(b=3.14, c=False))
self.assertEqual(TypeAdapter(MyDefaultContext).validate_python({}), MyDefaultContext(b=3.14, c=False))
self.assertEqual(TypeAdapter(MyDefaultContext).validate_python([]), MyDefaultContext(b=3.14, c=False))
self.assertEqual(TypeAdapter(MyDefaultContext).validate_python({"b": 10.0}), MyDefaultContext(b=10.0, c=False))

def test_date_validation(self):
c = DateContext(date=date.today())
self.assertEqual(DateContext(date=str(date.today())), c)
Expand Down Expand Up @@ -228,8 +247,6 @@ def setUp(self):
for name, obj in inspect.getmembers(ctx, inspect.isclass)
if obj.__module__ == ctx.__name__ and issubclass(obj, ContextBase) and not getattr(obj, "__deprecated__", False)
}
# TODO - remove NullContext until we fix the inheritance
self.classes.pop("NullContext")

def test_field_ordering(self):
"""Test that complex contexts have fields in the same order as the basic contexts they are composed of."""
Expand Down