Coverage for /pythoncovmergedfiles/medio/medio/src/pydantic/pydantic/_internal/_decorators.py: 49%
227 statements
« prev ^ index » next coverage.py v7.2.3, created at 2023-04-27 07:38 +0000
« prev ^ index » next coverage.py v7.2.3, created at 2023-04-27 07:38 +0000
1"""
2Logic related to validators applied to models etc. via the `@validator` and `@root_validator` decorators.
3"""
4from __future__ import annotations as _annotations
6from dataclasses import dataclass, field
7from functools import partial, partialmethod
8from inspect import Parameter, Signature, isdatadescriptor, ismethoddescriptor, signature
9from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generic, TypeVar, Union, cast
11from pydantic_core import core_schema
12from typing_extensions import Literal, TypeAlias
14from ..errors import PydanticUserError
15from ..fields import ComputedFieldInfo
16from ._core_utils import get_type_ref
17from ._internal_dataclass import slots_dataclass
19if TYPE_CHECKING:
20 from ..decorators import FieldValidatorModes
22try:
23 from functools import cached_property # type: ignore
24except ImportError:
25 # python 3.7
26 cached_property = None
@slots_dataclass
class ValidatorDecoratorInfo:
    """
    A container for data from `@validator` so that we can access it
    while building the pydantic-core schema.
    """

    # Display name of the decorator; a ClassVar, so not a dataclass field.
    decorator_repr: ClassVar[str] = '@validator'

    # Field names given to the decorator.
    fields: tuple[str, ...]
    # Whether this is a 'before' or an 'after' validator.
    mode: Literal['before', 'after']
    # The `each_item` flag recorded from the decorator call.
    each_item: bool
    # The `always` flag recorded from the decorator call.
    always: bool
    # The `check_fields` flag; `None` presumably means "not specified" — TODO confirm.
    check_fields: bool | None
@slots_dataclass
class FieldValidatorDecoratorInfo:
    """
    A container for data from `@field_validator` so that we can access it
    while building the pydantic-core schema.
    """

    # Display name of the decorator; a ClassVar, so not a dataclass field.
    decorator_repr: ClassVar[str] = '@field_validator'

    # Field names given to the decorator.
    fields: tuple[str, ...]
    # Validator mode; one of `FieldValidatorModes` (only imported for type checking).
    mode: FieldValidatorModes
    # The `check_fields` flag; `None` presumably means "not specified" — TODO confirm.
    check_fields: bool | None
@slots_dataclass
class RootValidatorDecoratorInfo:
    """
    A container for data from `@root_validator` so that we can access it
    while building the pydantic-core schema.
    """

    # Display name of the decorator; a ClassVar, so not a dataclass field.
    decorator_repr: ClassVar[str] = '@root_validator'
    # Whether this is a 'before' or an 'after' root validator.
    mode: Literal['before', 'after']
@slots_dataclass
class FieldSerializerDecoratorInfo:
    """
    A container for data from `@field_serializer` so that we can access it
    while building the pydantic-core schema.
    """

    # Display name of the decorator; a ClassVar, so not a dataclass field.
    decorator_repr: ClassVar[str] = '@field_serializer'

    # Field names this serializer applies to. `DecoratorInfos.build` rejects two
    # differently-named serializers that share a field.
    fields: tuple[str, ...]
    # Serializer mode: 'plain' or 'wrap'.
    mode: Literal['plain', 'wrap']
    # Optional JSON return-type hint passed through from the decorator (None if absent).
    json_return_type: core_schema.JsonReturnTypes | None
    # `core_schema.WhenUsed` value, presumably controlling when the serializer runs — TODO confirm.
    when_used: core_schema.WhenUsed
    # The `check_fields` flag; `None` presumably means "not specified" — TODO confirm.
    check_fields: bool | None
@slots_dataclass
class ModelSerializerDecoratorInfo:
    """
    A container for data from `@model_serializer` so that we can access it
    while building the pydantic-core schema.
    """

    # Display name of the decorator; a ClassVar, so not a dataclass field.
    decorator_repr: ClassVar[str] = '@model_serializer'
    # Serializer mode: 'plain' or 'wrap'.
    mode: Literal['plain', 'wrap']
    # Optional JSON return-type hint passed through from the decorator (None if absent).
    json_return_type: core_schema.JsonReturnTypes | None
@slots_dataclass
class ModelValidatorDecoratorInfo:
    """
    A container for data from `@model_validator` so that we can access it
    while building the pydantic-core schema.
    """

    # Display name of the decorator; a ClassVar, so not a dataclass field.
    decorator_repr: ClassVar[str] = '@model_validator'
    # Validator mode: 'wrap', 'before' or 'after'.
    mode: Literal['wrap', 'before', 'after']
# Union of every decorator-metadata container. `PydanticDescriptorProxy.decorator_info`
# holds one of these, and `DecoratorInfos.build` dispatches on its concrete type.
DecoratorInfo = Union[
    ValidatorDecoratorInfo,
    FieldValidatorDecoratorInfo,
    RootValidatorDecoratorInfo,
    FieldSerializerDecoratorInfo,
    ModelSerializerDecoratorInfo,
    ModelValidatorDecoratorInfo,
    ComputedFieldInfo,
]

# Return type of the callable/property wrapped by `PydanticDescriptorProxy`.
ReturnType = TypeVar('ReturnType')
# Objects `PydanticDescriptorProxy` can wrap. At runtime this alias is just a string, so the
# subscripted `classmethod[...]` / `staticmethod[...]` forms are never evaluated (presumably
# because they are not subscriptable on every supported Python version).
DecoratedType: TypeAlias = (
    'Union[classmethod[Any, Any, ReturnType], staticmethod[Any, ReturnType], Callable[..., ReturnType], property]'
)
@dataclass  # can't use slots here since we set attributes on `__post_init__`
class PydanticDescriptorProxy(Generic[ReturnType]):
    """
    Wrap a classmethod, staticmethod, property or unbound function
    and act as a descriptor that allows us to detect decorated items
    from the class' attributes.

    This class' __get__ returns the wrapped item's __get__ result,
    which makes it transparent for classmethods and staticmethods.

    Attributes:
        wrapped: The decorated object being proxied.
        decorator_info: Metadata recorded by the decorator that created this proxy.
        shim: Optional callable applied to the function looked up on the class when the
            decorator is collected (see `Decorator.build`).
    """

    wrapped: DecoratedType[ReturnType]
    decorator_info: DecoratorInfo
    shim: Callable[[Callable[..., Any]], Callable[..., Any]] | None = None

    def __post_init__(self):
        # If the wrapped object offers `setter`/`deleter` (e.g. a `property`), expose same-named
        # callables on the proxy. Decorating with them updates `wrapped` and returns the proxy,
        # so the proxy stays in the class namespace.
        for attr in 'setter', 'deleter':
            if hasattr(self.wrapped, attr):
                f = partial(self._call_wrapped_attr, name=attr)
                setattr(self, attr, f)

    def _call_wrapped_attr(self, func: Callable[[Any], None], *, name: str) -> PydanticDescriptorProxy[ReturnType]:
        """Replace `wrapped` with `wrapped.<name>(func)` and return this proxy."""
        self.wrapped = getattr(self.wrapped, name)(func)
        return self

    def __get__(self, obj: object | None, obj_type: type[object] | None = None) -> PydanticDescriptorProxy[ReturnType]:
        """Delegate to the wrapped object's `__get__`, or return it unchanged if it is not a descriptor."""
        wrapped_get = getattr(self.wrapped, '__get__', None)
        if wrapped_get is None:
            # not a descriptor, e.g. a partial object
            return self.wrapped  # type: ignore[return-value]
        # Call outside any `try`: an AttributeError raised by the wrapped getter itself
        # must propagate rather than be mistaken for "not a descriptor".
        return wrapped_get(obj, obj_type)

    def __set_name__(self, instance: Any, name: str) -> None:
        """Forward `__set_name__` to the wrapped object when it defines one."""
        if hasattr(self.wrapped, '__set_name__'):
            self.wrapped.__set_name__(instance, name)

    def __getattr__(self, __name: str) -> Any:
        """Forward checks for __isabstractmethod__ and such"""
        if __name == 'wrapped':
            # `wrapped` is not set yet (e.g. `copy`/`pickle` create the instance via `__new__`
            # and probe attributes such as `__setstate__` before restoring state). Without this
            # guard, `self.wrapped` below would re-enter `__getattr__('wrapped')` endlessly and
            # raise RecursionError instead of AttributeError.
            raise AttributeError(__name)
        return getattr(self.wrapped, __name)
# Type variable bound to `DecoratorInfo`; parametrizes `Decorator` by the kind of metadata it carries.
DecoratorInfoType = TypeVar('DecoratorInfoType', bound=DecoratorInfo)
@slots_dataclass
class Decorator(Generic[DecoratorInfoType]):
    """
    Pair a decorator's metadata with the function it decorates.

    The metadata (`info`) is recorded when the decorator is applied. The concrete
    function only becomes available once the owning class exists. This container
    joins the two so they can be used while building the core schema.
    """

    # `get_type_ref` of the class this decorator was built against.
    cls_ref: str
    # Attribute name of the decorated object in the class namespace.
    cls_var_name: str
    # The attribute fetched from the class, after `shim` (if any) was applied.
    func: Callable[..., Any]
    # Optional transform applied to the fetched attribute to produce `func`.
    shim: Callable[[Any], Any] | None
    # Metadata recorded by the decorator.
    info: DecoratorInfoType

    @staticmethod
    def build(
        cls_: Any,
        *,
        cls_var_name: str,
        shim: Callable[[Any], Any] | None,
        info: DecoratorInfoType,
    ) -> Decorator[DecoratorInfoType]:
        """
        Create a `Decorator` by looking up `cls_var_name` on `cls_`.

        Args:
            cls_: The class holding the decorated attribute.
            cls_var_name: Name of the attribute to fetch from `cls_`.
            shim: Optional transform applied to the fetched attribute.
            info: Decorator metadata to attach.

        Returns:
            A new `Decorator` whose `func` is the (possibly shimmed) attribute.
        """
        raw_attr = getattr(cls_, cls_var_name)
        bound_func = raw_attr if shim is None else shim(raw_attr)
        return Decorator(
            cls_ref=get_type_ref(cls_),
            cls_var_name=cls_var_name,
            func=bound_func,
            shim=shim,
            info=info,
        )

    def bind_to_cls(self, cls: Any) -> Decorator[DecoratorInfoType]:
        """Rebuild this decorator against `cls`, reusing the stored name, shim and metadata."""
        name, transform, metadata = self.cls_var_name, self.shim, self.info
        return self.build(cls, cls_var_name=name, shim=transform, info=metadata)
@slots_dataclass
class DecoratorInfos:
    """
    All pydantic decorators collected for one class, grouped by decorator kind.

    Every mapping is keyed by the name in the class namespace. That is the
    function or attribute name, *not* the field name.
    """

    # TODO these all need to be renamed to plural
    validator: dict[str, Decorator[ValidatorDecoratorInfo]] = field(default_factory=dict)
    field_validator: dict[str, Decorator[FieldValidatorDecoratorInfo]] = field(default_factory=dict)
    root_validator: dict[str, Decorator[RootValidatorDecoratorInfo]] = field(default_factory=dict)
    field_serializer: dict[str, Decorator[FieldSerializerDecoratorInfo]] = field(default_factory=dict)
    model_serializer: dict[str, Decorator[ModelSerializerDecoratorInfo]] = field(default_factory=dict)
    model_validator: dict[str, Decorator[ModelValidatorDecoratorInfo]] = field(default_factory=dict)
    computed_fields: dict[str, Decorator[ComputedFieldInfo]] = field(default_factory=dict)

    @staticmethod
    def build(model_dc: type[Any]) -> DecoratorInfos:  # noqa: C901 (ignore complexity)
        """
        Collect the decorators defined on `model_dc` or inherited from its bases.

        Inherited decorators are gathered first. The direct bases are walked in
        reverse order, and each base's already-built `__pydantic_decorators__` is
        merged in. For a name supplied by several bases, the earliest base therefore
        wins, mirroring the MRO.

        `PydanticDescriptorProxy` instances in `model_dc`'s own namespace are applied
        last, so they override inherited decorators of the same name. Dicts keep
        insertion order and replacing a key does not move it, so an override keeps
        the position of the function it replaces.

        Every collected decorator is bound to `model_dc`. As a side effect, each
        proxy found in `model_dc`'s own namespace is replaced on the class by the
        object it wraps.

        Args:
            model_dc: The class (a BaseModel or dataclass) whose decorators are collected.

        Returns:
            A new `DecoratorInfos` holding the collected decorators.

        Raises:
            PydanticUserError: If two differently-named field serializers target the same field.
        """
        # reminder: dicts are ordered and replacement does not alter the order
        res = DecoratorInfos()
        for base in model_dc.__bases__[::-1]:
            existing = cast(Union[DecoratorInfos, None], getattr(base, '__pydantic_decorators__', None))
            if existing is not None:
                res.validator.update({k: v.bind_to_cls(model_dc) for k, v in existing.validator.items()})
                res.field_validator.update({k: v.bind_to_cls(model_dc) for k, v in existing.field_validator.items()})
                res.root_validator.update({k: v.bind_to_cls(model_dc) for k, v in existing.root_validator.items()})
                res.field_serializer.update({k: v.bind_to_cls(model_dc) for k, v in existing.field_serializer.items()})
                res.model_serializer.update({k: v.bind_to_cls(model_dc) for k, v in existing.model_serializer.items()})
                res.model_validator.update({k: v.bind_to_cls(model_dc) for k, v in existing.model_validator.items()})
                res.computed_fields.update({k: v.bind_to_cls(model_dc) for k, v in existing.computed_fields.items()})

        for var_name, var_value in vars(model_dc).items():
            if isinstance(var_value, PydanticDescriptorProxy):
                info = var_value.decorator_info
                if isinstance(info, ValidatorDecoratorInfo):
                    res.validator[var_name] = Decorator.build(
                        model_dc, cls_var_name=var_name, shim=var_value.shim, info=info
                    )
                elif isinstance(info, FieldValidatorDecoratorInfo):
                    res.field_validator[var_name] = Decorator.build(
                        model_dc, cls_var_name=var_name, shim=var_value.shim, info=info
                    )
                elif isinstance(info, RootValidatorDecoratorInfo):
                    res.root_validator[var_name] = Decorator.build(
                        model_dc, cls_var_name=var_name, shim=var_value.shim, info=info
                    )
                elif isinstance(info, FieldSerializerDecoratorInfo):
                    # check whether a serializer function is already registered for fields
                    for field_serializer_decorator in res.field_serializer.values():
                        # check that each field has at most one serializer function.
                        # serializer functions for the same field in subclasses are allowed,
                        # and are treated as overrides (only when they reuse the same name)
                        if field_serializer_decorator.cls_var_name == var_name:
                            continue
                        for f in info.fields:
                            if f in field_serializer_decorator.info.fields:
                                raise PydanticUserError(
                                    'Multiple field serializer functions were defined '
                                    f'for field {f!r}, this is not allowed.',
                                    code='multiple-field-serializers',
                                )
                    res.field_serializer[var_name] = Decorator.build(
                        model_dc, cls_var_name=var_name, shim=var_value.shim, info=info
                    )
                elif isinstance(info, ModelValidatorDecoratorInfo):
                    res.model_validator[var_name] = Decorator.build(
                        model_dc, cls_var_name=var_name, shim=var_value.shim, info=info
                    )
                elif isinstance(info, ModelSerializerDecoratorInfo):
                    res.model_serializer[var_name] = Decorator.build(
                        model_dc, cls_var_name=var_name, shim=var_value.shim, info=info
                    )
                else:
                    # `DecoratorInfo` is a closed union, so the only remaining member is a
                    # computed field. (The original bare `isinstance(...)` expression was a
                    # no-op that also tested the proxy instead of its info.)
                    assert isinstance(info, ComputedFieldInfo), f'unexpected decorator info: {info!r}'
                    res.computed_fields[var_name] = Decorator.build(
                        model_dc, cls_var_name=var_name, shim=None, info=info
                    )
                # Replace the proxy on the class with the object it wraps.
                setattr(model_dc, var_name, var_value.wrapped)
        return res
def inspect_validator(validator: Callable[..., Any], mode: FieldValidatorModes) -> bool:
    """
    Look at a field or model validator function and determine whether it takes an info argument.

    An error is raised if the function has an invalid signature.

    Args:
        validator: The validator function to inspect.
        mode: The proposed validator mode.

    Returns:
        Whether the validator takes an info argument.

    Raises:
        PydanticUserError: If the number of positional parameters does not fit `mode`.
    """
    sig = signature(validator)
    n_positional = count_positional_params(sig)
    if mode == 'wrap':
        # wrap validators: 2 positional params without info, 3 with info
        if n_positional == 3:
            return True
        elif n_positional == 2:
            return False
    else:
        assert mode in {'before', 'after', 'plain'}, f"invalid mode: {mode!r}, expected 'before', 'after' or 'plain'"
        # other modes: 1 positional param without info, 2 with info
        if n_positional == 2:
            return True
        elif n_positional == 1:
            return False

    raise PydanticUserError(
        f'Unrecognized field_validator function signature for {validator} with `mode={mode}`:{sig}',
        code='field-validator-signature',
    )
def inspect_field_serializer(serializer: Callable[..., Any], mode: Literal['plain', 'wrap']) -> tuple[bool, bool]:
    """
    Inspect a field serializer's signature.

    A serializer whose first parameter is named ``self`` is reported as a field
    serializer (a method). That leading parameter is then left out when deciding
    whether an info argument is expected.

    Args:
        serializer: The serializer function to inspect.
        mode: The serializer mode, either 'plain' or 'wrap'.

    Returns:
        Tuple of (is_field_serializer, info_arg).

    Raises:
        PydanticUserError: If the signature does not match `mode`.
    """
    sig = signature(serializer)
    params = list(sig.parameters.values())
    is_field_serializer = bool(params) and params[0].name == 'self'

    # Discount the leading `self` so only the serializer's own inputs are counted.
    own_positional = count_positional_params(sig) - (1 if is_field_serializer else 0)
    info_arg = _serializer_info_arg(mode, own_positional)

    if info_arg is not None:
        return is_field_serializer, info_arg
    raise PydanticUserError(
        f'Unrecognized field_serializer function signature for {serializer} with `mode={mode}`:{sig}',
        code='field-serializer-signature',
    )
def inspect_annotated_serializer(serializer: Callable[..., Any], mode: Literal['plain', 'wrap']) -> bool:
    """
    Decide whether a serializer used via `Annotated` expects an info argument.

    Args:
        serializer: The serializer function to check.
        mode: The serializer mode, either 'plain' or 'wrap'.

    Returns:
        info_arg

    Raises:
        PydanticUserError: If the number of positional parameters does not fit `mode`.
    """
    sig = signature(serializer)
    positional_count = count_positional_params(sig)
    takes_info = _serializer_info_arg(mode, positional_count)
    if takes_info is not None:
        return takes_info
    raise PydanticUserError(
        f'Unrecognized field_serializer function signature for {serializer} with `mode={mode}`:{sig}',
        code='field-serializer-signature',
    )
def inspect_model_serializer(serializer: Callable[..., Any], mode: Literal['plain', 'wrap']) -> bool:
    """
    Decide whether a model serializer expects an info argument.

    The serializer must be a plain instance method: it may not be a
    staticmethod/classmethod, and its first parameter must be named ``self``.

    Args:
        serializer: The serializer function to check.
        mode: The serializer mode, either 'plain' or 'wrap'.

    Returns:
        `info_arg` - whether the function expects an info argument

    Raises:
        PydanticUserError: If the serializer is not an instance method, or its
            signature does not fit `mode`.
    """
    wrapped_in_descriptor = isinstance(serializer, (staticmethod, classmethod))
    # `or` short-circuits, so the signature is only inspected for undecorated callables.
    if wrapped_in_descriptor or not is_instance_method_from_sig(serializer):
        raise PydanticUserError(
            '`@model_serializer` must be applied to instance methods', code='model-serializer-instance-method'
        )

    sig = signature(serializer)
    takes_info = _serializer_info_arg(mode, count_positional_params(sig))
    if takes_info is not None:
        return takes_info
    raise PydanticUserError(
        f'Unrecognized model_serializer function signature for {serializer} with `mode={mode}`:{sig}',
        code='model-serializer-signature',
    )
426def _serializer_info_arg(mode: Literal['plain', 'wrap'], n_positional: int) -> bool | None:
427 if mode == 'plain':
428 if n_positional == 1:
429 # (__input_value: Any) -> Any
430 return False
431 elif n_positional == 2:
432 # (__model: Any, __input_value: Any) -> Any
433 return True
434 else:
435 assert mode == 'wrap', f"invalid mode: {mode!r}, expected 'plain' or 'wrap'"
436 if n_positional == 2:
437 # (__input_value: Any, __serializer: SerializerFunctionWrapHandler) -> Any
438 return False
439 elif n_positional == 3:
440 # (__input_value: Any, __serializer: SerializerFunctionWrapHandler, __info: SerializationInfo) -> Any
441 return True
443 return None
# Callables accepted by the signature helpers below: classmethods, staticmethods,
# partialmethods or plain callables. At runtime this alias is just a string.
AnyDecoratorCallable: TypeAlias = (
    'Union[classmethod[Any, Any, Any], staticmethod[Any, Any], partialmethod[Any], Callable[..., Any]]'
)
def is_instance_method_from_sig(function: AnyDecoratorCallable) -> bool:
    """Return True if `function`, once unwrapped, has a first parameter named ``self``."""
    parameters = signature(unwrap_wrapped_function(function)).parameters
    # `parameters` is an ordered name -> Parameter mapping; the first key is the first name.
    first_name = next(iter(parameters), None)
    return first_name == 'self'
def ensure_classmethod_based_on_signature(function: AnyDecoratorCallable) -> Any:
    """
    Wrap `function` in `classmethod` if its first parameter is named ``cls``.

    A function that is already a classmethod (possibly under partial/partialmethod
    layers) is returned unchanged, as is anything whose first parameter is not ``cls``.
    """
    already_classmethod = isinstance(
        unwrap_wrapped_function(function, unwrap_class_static_method=False), classmethod
    )
    if already_classmethod or not _is_classmethod_from_sig(function):
        return function
    return classmethod(function)  # type: ignore[arg-type]
def _is_classmethod_from_sig(function: AnyDecoratorCallable) -> bool:
    """Return True if `function`, once unwrapped, has a first parameter named ``cls``."""
    parameters = signature(unwrap_wrapped_function(function)).parameters
    # `parameters` is an ordered name -> Parameter mapping; the first key is the first name.
    first_name = next(iter(parameters), None)
    return first_name == 'cls'
def unwrap_wrapped_function(
    func: Any,
    *,
    unwrap_class_static_method: bool = True,
) -> Any:
    """
    Recursively unwraps a wrapped function until the underlying function is reached.
    This handles functools.partial, functools.partialmethod, staticmethod and classmethod.

    Args:
        func: The function to unwrap.
        unwrap_class_static_method: If True (default), also unwrap classmethod and staticmethod
            decorators. If False, only unwrap partial and partialmethod decorators.

    Returns:
        The underlying function of the wrapped function. Anything that is not one of the
        handled wrapper types is returned unchanged.
    """
    # Loop because wrappers can be nested, e.g. `staticmethod(partial(f, ...))`.
    # (No local type tuple is built, so the builtin `all` is no longer shadowed.)
    while True:
        if unwrap_class_static_method and isinstance(func, (classmethod, staticmethod)):
            func = func.__func__
        elif isinstance(func, (partial, partialmethod)):
            func = func.func
        else:
            return func
def count_positional_params(sig: Signature) -> int:
    """Count the parameters of `sig` that can be supplied positionally."""
    # Booleans sum as 0/1, so this is the number of positional-capable parameters.
    return sum(map(can_be_positional, sig.parameters.values()))
def can_be_positional(param: Parameter) -> bool:
    """Return True if `param` is positional-only or positional-or-keyword."""
    kind = param.kind
    return kind is Parameter.POSITIONAL_ONLY or kind is Parameter.POSITIONAL_OR_KEYWORD
def ensure_property(f: Any) -> Any:
    """
    Ensure that `f` is usable as a descriptor, wrapping it in a `property` if needed.

    Objects that `inspect` already classifies as method or data descriptors are
    returned unchanged; examples are `property`, `functools.cached_property` and
    `staticmethod`. Anything else, typically a plain function, is wrapped with
    `property`. A `cached_property` is never created here.

    Args:
        f: The object to check.

    Returns:
        `f` itself if it is already a descriptor, otherwise `property(f)`.
    """
    if ismethoddescriptor(f) or isdatadescriptor(f):
        return f
    return property(f)