@@ -391,8 +391,8 @@ def injection_wrapper(*args: t.Any, **kwargs: t.Any) -> T: # noqa: ANN401
391391class _ParametersInjection (t .Generic [T ]):
392392 __slots__ = ("_params" ,)
393393
394- def __init__ (self , ** kwargs : Binding ) -> None :
395- self ._params = kwargs
394+ def __init__ (self , * , parameters : dict [ str , Binding ] ) -> None :
395+ self ._params = parameters
396396
397397 @staticmethod
398398 def _aggregate_sync_stack (
@@ -431,7 +431,8 @@ async def _aggregate_async_stack(
431431 kwargs .update (executed_kwargs )
432432
433433 def __call__ (
434- self , func : t .Callable [..., t .Union [t .Awaitable [T ], T ]]
434+ self ,
435+ func : t .Callable [..., t .Union [t .Awaitable [T ], T ]],
435436 ) -> t .Callable [..., t .Union [t .Awaitable [T ], T ]]:
436437 arg_names = inspect .getfullargspec (func ).args
437438 params_to_provide = self ._params
@@ -626,7 +627,11 @@ def param(name: str, cls: t.Optional[Binding] = None) -> t.Callable:
626627 return _ParameterInjection (name , cls )
627628
628629
629- def params (** args_to_classes : Binding ) -> t .Callable :
630+ def params (
631+ method_name : t .Union [str , _MISSING ] = _MISSING ,
632+ / ,
633+ ** args_to_classes : Binding ,
634+ ) -> t .Callable :
630635 """
631636 Return a decorator which injects args into a function.
632637
@@ -635,8 +640,30 @@ def params(**args_to_classes: Binding) -> t.Callable:
635640 @inject.params(cache=RedisCache, db=DbInterface)
636641 def sign_up(name, email, cache, db):
637642 pass
643+
644+ Raises:
645+ ValueError: on invalid arguments
646+
638647 """
639- return _ParametersInjection (** args_to_classes )
648+ if not args_to_classes :
649+ raise ValueError ("Params kwargs can't be empty" )
650+
651+ def params_decorator (cls_or_func : t .Callable [..., T ]) -> t .Callable [..., T ]:
652+ fn , cls , m_name = _parse_cls_or_fn (cls_or_func , method_name )
653+
654+ wrapper : _ParametersInjection [T ] = _ParametersInjection (
655+ parameters = args_to_classes
656+ )
657+ wrapped = wrapper (fn )
658+
659+ if not cls :
660+ return wrapped
661+
662+ setattr (cls_or_func , m_name , wrapped )
663+ return cls_or_func
664+
665+ return params_decorator
666+ # return _ParametersInjection(parameters=args_to_classes)
640667
641668
642669# NOTE(pyctrl): only since 3.12
@@ -654,11 +681,14 @@ def autoparams(fn: t.Callable) -> t.Callable: ...
654681
655682
656683@t .overload
657- def autoparams (* selected : str ) -> t .Callable : ...
684+ def autoparams (
685+ * selected : str ,
686+ method_name : t .Union [str , _MISSING ] = _MISSING ,
687+ ) -> t .Callable : ...
658688
659689
660690@t .no_type_check
661- def autoparams (* selected : str ):
691+ def autoparams (* selected : str , method_name : t . Union [ str , _MISSING ] = _MISSING ):
662692 """
663693 Return a decorator injecting args based on function type hints, only since 3.5.
664694
@@ -679,24 +709,39 @@ def sign_up(name, email, cache: RedisCache, db: DbInterface):
679709 """
680710 only_these : set [str ] = set ()
681711
682- def autoparams_decorator (fn : t .Callable [..., T ]) -> t .Callable [..., T ]:
683- if inspect .isclass (fn ):
684- types = t .get_type_hints (fn .__init__ )
685- else :
686- types = t .get_type_hints (fn )
687-
688- # Skip the return annotation.
689- types = {name : typ for name , typ in types .items () if name != _RETURN }
712+ def autoparams_decorator (cls_or_func : t .Callable [..., T ]) -> t .Callable [..., T ]:
713+ # nonlocal method_name
714+ # is_class = inspect.isclass(cls_or_func)
715+ # if is_class:
716+ # if method_name is _MISSING:
717+ # method_name = "__init__"
718+ # fn = getattr(cls_or_func, method_name)
719+ # elif method_name is not _MISSING:
720+ # raise TypeError("You can't provide method name with function")
721+ # else:
722+ # fn = cls_or_func
723+
724+ fn , cls , m_name = _parse_cls_or_fn (cls_or_func , method_name )
725+ type_hints = t .get_type_hints (fn )
726+
727+ allowlist = set (only_these or type_hints )
728+ allowlist .discard (_RETURN ) # Skip the return annotation.
729+
730+ parameters = {
731+ # Convert Union types into single types, i.e. Union[A, None] => A.
732+ name : _unwrap_union_arg (typ )
733+ for name , typ in type_hints .items ()
734+ if name in allowlist
735+ }
690736
691- # Convert Union types into single types, i.e. Union[A, None ] => A.
692- types = { name : _unwrap_union_arg ( typ ) for name , typ in types . items ()}
737+ wrapper : _ParametersInjection [ T ] = _ParametersInjection ( parameters = parameters )
738+ wrapped = wrapper ( fn )
693739
694- # Filter types if selected args present.
695- if only_these :
696- types = {name : typ for name , typ in types .items () if name in only_these }
740+ if not cls :
741+ return wrapped
697742
698- wrapper : _ParametersInjection [ T ] = _ParametersInjection ( ** types )
699- return wrapper ( fn )
743+ setattr ( cls_or_func , m_name , wrapped )
744+ return cls_or_func
700745
701746 target = selected [0 ] if selected else None
702747 if len (selected ) == 1 and callable (target ):
@@ -770,3 +815,20 @@ def _unwrap_cls_annotation(cls: type, attr_name: str) -> type:
770815 raise InjectorException (msg ) from None
771816
772817 return _unwrap_union_arg (attr_type )
818+
819+
820+ def _parse_cls_or_fn (
821+ target : t .Callable ,
822+ method_name : t .Union [str , _MISSING ],
823+ ) -> tuple [t .Callable , t .Optional [type ], t .Optional [str ]]:
824+ is_class = inspect .isclass (target )
825+ if is_class :
826+ if method_name is _MISSING :
827+ method_name = "__init__"
828+ fn = getattr (target , method_name )
829+ return fn , target , method_name
830+
831+ if method_name is not _MISSING :
832+ raise ValueError ("Providing 'method_name' argument is forbidden for functions" )
833+
834+ return target , None , None
0 commit comments