1
1
from functools import wraps
2
- from typing import TYPE_CHECKING , Any , Dict , List , Mapping , Tuple , TypeVar , cast , get_type_hints
2
+ from typing import TYPE_CHECKING , Any , Dict , List , Mapping , Tuple , Type , TypeVar , Union , cast , get_type_hints
3
3
4
4
from . import validator
5
5
from .errors import ConfigError
12
12
from .typing import AnyCallable
13
13
14
14
Callable = TypeVar ('Callable' , bound = AnyCallable )
15
+ ConfigType = Union [None , Type [Any ], Dict [str , Any ]]
15
16
16
17
17
- def validate_arguments (func : 'Callable' = None , ** config_params : Any ) -> 'Callable' :
18
+ def validate_arguments (func : 'Callable' = None , * , config : 'ConfigType' = None ) -> 'Callable' :
18
19
"""
19
20
Decorator to validate the arguments passed to a function.
20
21
"""
21
22
22
23
def validate (_func : 'Callable' ) -> 'Callable' :
23
- vd = ValidatedFunction (_func , ** config_params )
24
+ vd = ValidatedFunction (_func , config )
24
25
25
26
@wraps (_func )
26
27
def wrapper_function (* args : Any , ** kwargs : Any ) -> Any :
@@ -43,7 +44,7 @@ def wrapper_function(*args: Any, **kwargs: Any) -> Any:
43
44
44
45
45
46
class ValidatedFunction :
46
- def __init__ (self , function : 'Callable' , ** config_params : Any ):
47
+ def __init__ (self , function : 'Callable' , config : 'ConfigType' ):
47
48
from inspect import signature , Parameter
48
49
49
50
parameters : Mapping [str , Parameter ] = signature (function ).parameters
@@ -107,7 +108,7 @@ def __init__(self, function: 'Callable', **config_params: Any):
107
108
# same with kwargs
108
109
fields [self .v_kwargs_name ] = Dict [Any , Any ], None
109
110
110
- self .create_model (fields , takes_args , takes_kwargs , ** config_params )
111
+ self .create_model (fields , takes_args , takes_kwargs , config )
111
112
112
113
def call (self , * args : Any , ** kwargs : Any ) -> Any :
113
114
values = self .build_values (args , kwargs )
@@ -177,16 +178,17 @@ def execute(self, m: BaseModel) -> Any:
177
178
else :
178
179
return self .raw_function (** d )
179
180
180
- def create_model (self , fields : Dict [str , Any ], takes_args : bool , takes_kwargs : bool , ** config_params : Any ) -> None :
181
+ def create_model (self , fields : Dict [str , Any ], takes_args : bool , takes_kwargs : bool , config : 'ConfigType' ) -> None :
181
182
pos_args = len (self .arg_mapping )
182
183
183
- if TYPE_CHECKING :
184
+ class CustomConfig :
185
+ pass
184
186
185
- class CustomConfig :
186
- pass
187
-
188
- else :
189
- CustomConfig = type ( 'Config' , (), config_params )
187
+ if not TYPE_CHECKING :
188
+ if isinstance ( config , dict ):
189
+ CustomConfig = type ( 'Config' , (), config )
190
+ elif config is not None :
191
+ CustomConfig = config
190
192
191
193
class DecoratorBaseModel (BaseModel ):
192
194
@validator (self .v_args_name , check_fields = False , allow_reuse = True )
0 commit comments