Skip to content
179 changes: 164 additions & 15 deletions ccflow/callable.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@
"""

import abc
import inspect
import logging
from functools import lru_cache, wraps
from inspect import Signature, isclass, signature
from typing import Any, ClassVar, Dict, Generic, List, Optional, Tuple, Type, TypeVar, Union, get_args, get_origin
from typing import Any, Callable, ClassVar, Dict, Generic, List, Optional, Tuple, Type, TypeVar, Union, get_args, get_origin

from pydantic import BaseModel as PydanticBaseModel, ConfigDict, Field, InstanceOf, PrivateAttr, TypeAdapter, field_validator, model_validator
from typing_extensions import override
Expand All @@ -27,6 +28,7 @@
ResultBase,
ResultType,
)
from .local_persistence import create_ccflow_model
from .validators import str_to_log_level

__all__ = (
Expand Down Expand Up @@ -268,14 +270,31 @@ def get_evaluation_context(model: CallableModelType, context: ContextType, as_di
def wrapper(model, context=Signature.empty, *, _options: Optional[FlowOptions] = None, **kwargs):
if not isinstance(model, CallableModel):
raise TypeError(f"Can only decorate methods on CallableModels (not {type(model)}) with the flow decorator.")
if (not isclass(model.context_type) or not issubclass(model.context_type, ContextBase)) and not (
get_origin(model.context_type) is Union and type(None) in get_args(model.context_type)
):
raise TypeError(f"Context type {model.context_type} must be a subclass of ContextBase")
if (not isclass(model.result_type) or not issubclass(model.result_type, ResultBase)) and not (
get_origin(model.result_type) is Union and all(isclass(t) and issubclass(t, ResultBase) for t in get_args(model.result_type))

# Check if this is an auto_context decorated method
has_auto_context = hasattr(fn, "__auto_context__")
if has_auto_context:
method_context_type = fn.__auto_context__
else:
method_context_type = model.context_type

# Validate context type (skip for auto contexts which are always valid ContextBase subclasses)
if not has_auto_context:
if (not isclass(model.context_type) or not issubclass(model.context_type, ContextBase)) and not (
get_origin(model.context_type) is Union and type(None) in get_args(model.context_type)
):
raise TypeError(f"Context type {model.context_type} must be a subclass of ContextBase")

# Validate result type - use __result_type__ for auto contexts if available
if has_auto_context and hasattr(fn, "__result_type__"):
method_result_type = fn.__result_type__
else:
method_result_type = model.result_type
if (not isclass(method_result_type) or not issubclass(method_result_type, ResultBase)) and not (
get_origin(method_result_type) is Union and all(isclass(t) and issubclass(t, ResultBase) for t in get_args(method_result_type))
):
raise TypeError(f"Result type {model.result_type} must be a subclass of ResultBase")
raise TypeError(f"Result type {method_result_type} must be a subclass of ResultBase")

if self._deps and fn.__name__ != "__deps__":
raise ValueError("Can only apply Flow.deps decorator to __deps__")
if context is Signature.empty:
Expand All @@ -285,18 +304,18 @@ def wrapper(model, context=Signature.empty, *, _options: Optional[FlowOptions] =
context = kwargs
else:
raise TypeError(
f"{fn.__name__}() missing 1 required positional argument: 'context' of type {model.context_type}, or kwargs to construct it"
f"{fn.__name__}() missing 1 required positional argument: 'context' of type {method_context_type}, or kwargs to construct it"
)
elif kwargs: # Kwargs passed in as well as context. Not allowed
raise TypeError(f"{fn.__name__}() was passed a context and got an unexpected keyword argument '{next(iter(kwargs.keys()))}'")

# Type coercion on input. We do this here (rather than relying on ModelEvaluationContext) as it produces a nicer traceback/error message
if not isinstance(context, model.context_type):
if get_origin(model.context_type) is Union and type(None) in get_args(model.context_type):
model_context_type = [t for t in get_args(model.context_type) if t is not type(None)][0]
if not isinstance(context, method_context_type):
if get_origin(method_context_type) is Union and type(None) in get_args(method_context_type):
coerce_context_type = [t for t in get_args(method_context_type) if t is not type(None)][0]
else:
model_context_type = model.context_type
context = model_context_type.model_validate(context)
coerce_context_type = method_context_type
context = coerce_context_type.model_validate(context)

if fn != getattr(model.__class__, fn.__name__).__wrapped__:
# This happens when super().__call__ is used when implementing a CallableModel that derives from another one.
Expand All @@ -313,6 +332,13 @@ def wrapper(model, context=Signature.empty, *, _options: Optional[FlowOptions] =
wrap.get_evaluator = self.get_evaluator
wrap.get_options = self.get_options
wrap.get_evaluation_context = get_evaluation_context

# Preserve auto context attributes for introspection
if hasattr(fn, "__auto_context__"):
wrap.__auto_context__ = fn.__auto_context__
if hasattr(fn, "__result_type__"):
wrap.__result_type__ = fn.__result_type__

return wrap


Expand Down Expand Up @@ -391,7 +417,58 @@ def __exit__(self, exc_type, exc_value, exc_tb):
class Flow(PydanticBaseModel):
@staticmethod
def call(*args, **kwargs):
"""Decorator for methods on callable models"""
"""Decorator for methods on callable models.

Args:
auto_context: Controls automatic context class generation from the function
signature. Accepts three types of values:
- False (default): No auto-generation, use traditional context parameter
- True: Auto-generate context class with no parent
- ContextBase subclass: Auto-generate context class inheriting from this parent
**kwargs: Additional FlowOptions parameters (log_level, verbose, validate_result,
cacheable, evaluator, volatile).

Basic Example:
class MyModel(CallableModel):
@Flow.call
def __call__(self, context: MyContext) -> MyResult:
return MyResult(value=context.x)

Auto Context Example:
class MyModel(CallableModel):
@Flow.call(auto_context=True)
def __call__(self, *, x: int, y: str = "default") -> MyResult:
return MyResult(value=f"{x}-{y}")

model = MyModel()
model(x=42) # Call with kwargs directly

With Parent Context:
class MyModel(CallableModel):
@Flow.call(auto_context=DateContext)
def __call__(self, *, date: date, extra: int = 0) -> MyResult:
return MyResult(value=date.day + extra)

# The generated context inherits from DateContext, so it's compatible
# with infrastructure expecting DateContext instances.
"""
# Extract auto_context option (not part of FlowOptions)
# Can be: False, True, or a ContextBase subclass
auto_context = kwargs.pop("auto_context", False)

# Determine if auto_context is enabled and extract parent class if provided
if auto_context is False:
auto_context_enabled = False
context_parent = None
elif auto_context is True:
auto_context_enabled = True
context_parent = None
elif isclass(auto_context) and issubclass(auto_context, ContextBase):
auto_context_enabled = True
context_parent = auto_context
else:
raise TypeError(f"auto_context must be False, True, or a ContextBase subclass, got {auto_context!r}")

if len(args) == 1 and callable(args[0]):
# No arguments to decorator, this is the decorator
fn = args[0]
Expand All @@ -400,6 +477,14 @@ def call(*args, **kwargs):
else:
# Arguments to decorator, this is just returning the decorator
# Note that the code below is executed only once
if auto_context_enabled:
# Return a decorator that first applies auto_context, then FlowOptions
def auto_context_decorator(fn):
wrapped = _apply_auto_context(fn, parent=context_parent)
# FlowOptions.__call__ already applies wraps, so we just return its result
return FlowOptions(**kwargs)(wrapped)

return auto_context_decorator
return FlowOptions(**kwargs)

@staticmethod
Expand Down Expand Up @@ -754,3 +839,67 @@ def _validate_callable_model_generic_type(cls, m, handler, info):


CallableModelGenericType = CallableModelGeneric


# *****************************************************************************
# Auto Context (internal helper for Flow.call(auto_context=True))
# *****************************************************************************


def _apply_auto_context(func: Callable, *, parent: Type[ContextBase] = None) -> Callable:
"""Internal function that creates an auto context class from function parameters.

This function extracts the parameters from a function signature and creates
a ContextBase subclass whose fields correspond to those parameters.
The decorated function is then wrapped to accept the context object and
unpack it into keyword arguments.

Used internally by Flow.call(auto_context=...).

Example:
class MyCallable(CallableModel):
@Flow.call(auto_context=True)
def __call__(self, *, x: int, y: str = "default") -> GenericResult:
return GenericResult(value=f"{x}-{y}")

model = MyCallable()
model(x=42, y="hello") # Works with kwargs
"""
sig = signature(func)
base_class = parent or ContextBase

# Validate parent fields are in function signature
if parent is not None:
parent_fields = set(parent.model_fields.keys()) - set(ContextBase.model_fields.keys())
sig_params = set(sig.parameters.keys()) - {"self"}
missing = parent_fields - sig_params
if missing:
raise TypeError(f"Parent context fields {missing} must be included in function signature")

# Build fields from parameters (skip 'self'), pydantic validates types
fields = {}
for name, param in sig.parameters.items():
if name == "self":
continue
default = ... if param.default is inspect.Parameter.empty else param.default
fields[name] = (param.annotation, default)

# Create auto context class
auto_context_class = create_ccflow_model(f"{func.__qualname__}_AutoContext", __base__=base_class, **fields)

@wraps(func)
def wrapper(self, context):
fn_kwargs = {name: getattr(context, name) for name in fields}
return func(self, **fn_kwargs)

# Must set __signature__ so CallableModel validation sees 'context' parameter
wrapper.__signature__ = inspect.Signature(
parameters=[
inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD),
inspect.Parameter("context", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=auto_context_class),
],
return_annotation=sig.return_annotation,
)
wrapper.__auto_context__ = auto_context_class
wrapper.__result_type__ = sig.return_annotation
return wrapper
Loading
Loading