Skip to content

Commit 50fa417

Browse files
committed
Support class decoration for params and autoparams
1 parent 87c6496 commit 50fa417

File tree

3 files changed

+116
-22
lines changed

3 files changed

+116
-22
lines changed

src/inject/__init__.py

Lines changed: 84 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -391,8 +391,8 @@ def injection_wrapper(*args: t.Any, **kwargs: t.Any) -> T: # noqa: ANN401
391391
class _ParametersInjection(t.Generic[T]):
392392
__slots__ = ("_params",)
393393

394-
def __init__(self, **kwargs: Binding) -> None:
395-
self._params = kwargs
394+
def __init__(self, *, parameters: dict[str, Binding]) -> None:
395+
self._params = parameters
396396

397397
@staticmethod
398398
def _aggregate_sync_stack(
@@ -431,7 +431,8 @@ async def _aggregate_async_stack(
431431
kwargs.update(executed_kwargs)
432432

433433
def __call__(
434-
self, func: t.Callable[..., t.Union[t.Awaitable[T], T]]
434+
self,
435+
func: t.Callable[..., t.Union[t.Awaitable[T], T]],
435436
) -> t.Callable[..., t.Union[t.Awaitable[T], T]]:
436437
arg_names = inspect.getfullargspec(func).args
437438
params_to_provide = self._params
@@ -626,7 +627,11 @@ def param(name: str, cls: t.Optional[Binding] = None) -> t.Callable:
626627
return _ParameterInjection(name, cls)
627628

628629

629-
def params(**args_to_classes: Binding) -> t.Callable:
630+
def params(
631+
method_name: t.Union[str, _MISSING] = _MISSING,
632+
/,
633+
**args_to_classes: Binding,
634+
) -> t.Callable:
630635
"""
631636
Return a decorator which injects args into a function.
632637
@@ -635,8 +640,30 @@ def params(**args_to_classes: Binding) -> t.Callable:
635640
@inject.params(cache=RedisCache, db=DbInterface)
636641
def sign_up(name, email, cache, db):
637642
pass
643+
644+
Raises:
645+
ValueError: on invalid arguments
646+
638647
"""
639-
return _ParametersInjection(**args_to_classes)
648+
if not args_to_classes:
649+
raise ValueError("Params kwargs can't be empty")
650+
651+
def params_decorator(cls_or_func: t.Callable[..., T]) -> t.Callable[..., T]:
652+
fn, cls, m_name = _parse_cls_or_fn(cls_or_func, method_name)
653+
654+
wrapper: _ParametersInjection[T] = _ParametersInjection(
655+
parameters=args_to_classes
656+
)
657+
wrapped = wrapper(fn)
658+
659+
if not cls:
660+
return wrapped
661+
662+
setattr(cls_or_func, m_name, wrapped)
663+
return cls_or_func
664+
665+
return params_decorator
666+
# return _ParametersInjection(parameters=args_to_classes)
640667

641668

642669
# NOTE(pyctrl): only since 3.12
@@ -654,11 +681,14 @@ def autoparams(fn: t.Callable) -> t.Callable: ...
654681

655682

656683
@t.overload
657-
def autoparams(*selected: str) -> t.Callable: ...
684+
def autoparams(
685+
*selected: str,
686+
method_name: t.Union[str, _MISSING] = _MISSING,
687+
) -> t.Callable: ...
658688

659689

660690
@t.no_type_check
661-
def autoparams(*selected: str):
691+
def autoparams(*selected: str, method_name: t.Union[str, _MISSING] = _MISSING):
662692
"""
663693
Return a decorator injecting args based on function type hints, only since 3.5.
664694
@@ -679,24 +709,39 @@ def sign_up(name, email, cache: RedisCache, db: DbInterface):
679709
"""
680710
only_these: set[str] = set()
681711

682-
def autoparams_decorator(fn: t.Callable[..., T]) -> t.Callable[..., T]:
683-
if inspect.isclass(fn):
684-
types = t.get_type_hints(fn.__init__)
685-
else:
686-
types = t.get_type_hints(fn)
687-
688-
# Skip the return annotation.
689-
types = {name: typ for name, typ in types.items() if name != _RETURN}
712+
def autoparams_decorator(cls_or_func: t.Callable[..., T]) -> t.Callable[..., T]:
713+
# nonlocal method_name
714+
# is_class = inspect.isclass(cls_or_func)
715+
# if is_class:
716+
# if method_name is _MISSING:
717+
# method_name = "__init__"
718+
# fn = getattr(cls_or_func, method_name)
719+
# elif method_name is not _MISSING:
720+
# raise TypeError("You can't provide method name with function")
721+
# else:
722+
# fn = cls_or_func
723+
724+
fn, cls, m_name = _parse_cls_or_fn(cls_or_func, method_name)
725+
type_hints = t.get_type_hints(fn)
726+
727+
allowlist = set(only_these or type_hints)
728+
allowlist.discard(_RETURN) # Skip the return annotation.
729+
730+
parameters = {
731+
# Convert Union types into single types, i.e. Union[A, None] => A.
732+
name: _unwrap_union_arg(typ)
733+
for name, typ in type_hints.items()
734+
if name in allowlist
735+
}
690736

691-
# Convert Union types into single types, i.e. Union[A, None] => A.
692-
types = {name: _unwrap_union_arg(typ) for name, typ in types.items()}
737+
wrapper: _ParametersInjection[T] = _ParametersInjection(parameters=parameters)
738+
wrapped = wrapper(fn)
693739

694-
# Filter types if selected args present.
695-
if only_these:
696-
types = {name: typ for name, typ in types.items() if name in only_these}
740+
if not cls:
741+
return wrapped
697742

698-
wrapper: _ParametersInjection[T] = _ParametersInjection(**types)
699-
return wrapper(fn)
743+
setattr(cls_or_func, m_name, wrapped)
744+
return cls_or_func
700745

701746
target = selected[0] if selected else None
702747
if len(selected) == 1 and callable(target):
@@ -770,3 +815,20 @@ def _unwrap_cls_annotation(cls: type, attr_name: str) -> type:
770815
raise InjectorException(msg) from None
771816

772817
return _unwrap_union_arg(attr_type)
818+
819+
820+
def _parse_cls_or_fn(
821+
target: t.Callable,
822+
method_name: t.Union[str, _MISSING],
823+
) -> tuple[t.Callable, t.Optional[type], t.Optional[str]]:
824+
is_class = inspect.isclass(target)
825+
if is_class:
826+
if method_name is _MISSING:
827+
method_name = "__init__"
828+
fn = getattr(target, method_name)
829+
return fn, target, method_name
830+
831+
if method_name is not _MISSING:
832+
raise ValueError("Providing 'method_name' argument is forbidden for functions")
833+
834+
return target, None, None

tests/test_autoparams.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,3 +242,19 @@ def config(binder):
242242

243243
self.assertRaises(TypeError, test_func)
244244
self.assertRaises(TypeError, test_func, a=1, c=3)
245+
246+
def test_autoparams_on_cls(self):
247+
@inject.autoparams
248+
@inject.autoparams(method_name="method")
249+
class MyClass:
250+
def __init__(self, val: int):
251+
self.val = val
252+
253+
def method(self, val: int):
254+
return val
255+
256+
inject.configure(lambda binder: binder.bind(int, 123))
257+
obj = MyClass()
258+
259+
assert obj.val == 123
260+
assert obj.method() == 123

tests/test_params.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,3 +148,19 @@ async def test_func(val): # noqa: RUF029
148148
assert self.run_async(test_func()) == 123
149149
assert self.run_async(test_func(321)) == 321
150150
assert self.run_async(test_func(val=42)) == 42
151+
152+
def test_params_on_cls(self):
153+
@inject.params(val=int)
154+
@inject.params("method", val=int)
155+
class MyClass:
156+
def __init__(self, val: int):
157+
self.val = val
158+
159+
def method(self, val: int):
160+
return val
161+
162+
inject.configure(lambda binder: binder.bind(int, 123))
163+
obj = MyClass()
164+
165+
assert obj.val == 123
166+
assert obj.method() == 123

0 commit comments

Comments
 (0)