1"""Logic related to validators applied to models etc. via the `@field_validator` and `@model_validator` decorators."""
2
3from __future__ import annotations as _annotations
4
5import types
6from collections import deque
7from collections.abc import Iterable
8from copy import copy
9from dataclasses import dataclass, field
10from functools import cached_property, partial, partialmethod
11from inspect import Parameter, Signature, isdatadescriptor, ismethoddescriptor
12from itertools import islice
13from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generic, Literal, TypeVar, Union
14
15from pydantic_core import PydanticUndefined, PydanticUndefinedType, core_schema
16from typing_extensions import Self, TypeAlias, is_typeddict
17
18from ..errors import PydanticUserError
19from ._core_utils import get_type_ref
20from ._internal_dataclass import slots_true
21from ._namespace_utils import GlobalsNamespace, MappingNamespace
22from ._typing_extra import get_function_type_hints, signature_no_eval
23from ._utils import can_be_positional
24
25if TYPE_CHECKING:
26 from ..fields import ComputedFieldInfo
27 from ..functional_validators import FieldValidatorModes
28 from ._config import ConfigWrapper
29
30
@dataclass(**slots_true)
class ValidatorDecoratorInfo:
    """A container for data from `@validator` so that we can access it
    while building the pydantic-core schema.

    Attributes:
        decorator_repr: A class variable representing the decorator string, '@validator'.
        fields: A tuple of field names the validator should be called on.
        mode: The proposed validator mode, either `'before'` or `'after'`.
        each_item: For complex objects (sets, lists etc.) whether to validate individual
            elements rather than the whole object.
        always: Whether this method and other validators should be called even if the value is missing.
        check_fields: Whether to check that the fields actually exist on the model (may be `None`).
    """

    # Not a dataclass field (ClassVar); only used when describing the decorator.
    decorator_repr: ClassVar[str] = '@validator'

    fields: tuple[str, ...]
    mode: Literal['before', 'after']
    each_item: bool
    always: bool
    check_fields: bool | None
53
54
@dataclass(**slots_true)
class FieldValidatorDecoratorInfo:
    """A container for data from `@field_validator` so that we can access it
    while building the pydantic-core schema.

    Attributes:
        decorator_repr: A class variable representing the decorator string, '@field_validator'.
        fields: A tuple of field names the validator should be called on.
        mode: The proposed validator mode.
        check_fields: Whether to check that the fields actually exist on the model (may be `None`).
        json_schema_input_type: The input type of the function. This is only used to generate
            the appropriate JSON Schema (in validation mode) and can only be specified
            when `mode` is either `'before'`, `'plain'` or `'wrap'`.
    """

    # Not a dataclass field (ClassVar); only used when describing the decorator.
    decorator_repr: ClassVar[str] = '@field_validator'

    fields: tuple[str, ...]
    mode: FieldValidatorModes
    check_fields: bool | None
    json_schema_input_type: Any
76
77
@dataclass(**slots_true)
class RootValidatorDecoratorInfo:
    """A container for data from `@root_validator` so that we can access it
    while building the pydantic-core schema.

    Attributes:
        decorator_repr: A class variable representing the decorator string, '@root_validator'.
        mode: The proposed validator mode, either `'before'` or `'after'`.
    """

    # Not a dataclass field (ClassVar); only used when describing the decorator.
    decorator_repr: ClassVar[str] = '@root_validator'
    mode: Literal['before', 'after']
90
91
@dataclass(**slots_true)
class FieldSerializerDecoratorInfo:
    """A container for data from `@field_serializer` so that we can access it
    while building the pydantic-core schema.

    Attributes:
        decorator_repr: A class variable representing the decorator string, '@field_serializer'.
        fields: A tuple of field names the serializer should be called on.
        mode: The proposed serializer mode, either `'plain'` or `'wrap'`.
        return_type: The type of the serializer's return value.
        when_used: The serialization condition. Accepts a string with values `'always'`, `'unless-none'`, `'json'`,
            and `'json-unless-none'`.
        check_fields: Whether to check that the fields actually exist on the model (may be `None`).
    """

    # Not a dataclass field (ClassVar); only used when describing the decorator.
    decorator_repr: ClassVar[str] = '@field_serializer'
    fields: tuple[str, ...]
    mode: Literal['plain', 'wrap']
    return_type: Any
    when_used: core_schema.WhenUsed
    check_fields: bool | None
113
114
@dataclass(**slots_true)
class ModelSerializerDecoratorInfo:
    """A container for data from `@model_serializer` so that we can access it
    while building the pydantic-core schema.

    Attributes:
        decorator_repr: A class variable representing the decorator string, '@model_serializer'.
        mode: The proposed serializer mode, either `'plain'` or `'wrap'`.
        return_type: The type of the serializer's return value.
        when_used: The serialization condition. Accepts a string with values `'always'`, `'unless-none'`, `'json'`,
            and `'json-unless-none'`.
    """

    # Not a dataclass field (ClassVar); only used when describing the decorator.
    decorator_repr: ClassVar[str] = '@model_serializer'
    mode: Literal['plain', 'wrap']
    return_type: Any
    when_used: core_schema.WhenUsed
132
133
@dataclass(**slots_true)
class ModelValidatorDecoratorInfo:
    """A container for data from `@model_validator` so that we can access it
    while building the pydantic-core schema.

    Attributes:
        decorator_repr: A class variable representing the decorator string, '@model_validator'.
        mode: The proposed validator mode, one of `'wrap'`, `'before'` or `'after'`.
    """

    # Not a dataclass field (ClassVar); only used when describing the decorator.
    decorator_repr: ClassVar[str] = '@model_validator'
    mode: Literal['wrap', 'before', 'after']
146
147
# Union of every decorator-info container a `PydanticDescriptorProxy` can carry.
# Written as a string, presumably because `ComputedFieldInfo` is only imported under `TYPE_CHECKING`.
DecoratorInfo: TypeAlias = """Union[
    ValidatorDecoratorInfo,
    FieldValidatorDecoratorInfo,
    RootValidatorDecoratorInfo,
    FieldSerializerDecoratorInfo,
    ModelSerializerDecoratorInfo,
    ModelValidatorDecoratorInfo,
    ComputedFieldInfo,
]"""

# Return type of the decorated callable; parametrizes `PydanticDescriptorProxy`.
ReturnType = TypeVar('ReturnType')
# The kinds of objects `PydanticDescriptorProxy.wrapped` may hold.
DecoratedType: TypeAlias = (
    'Union[classmethod[Any, Any, ReturnType], staticmethod[Any, ReturnType], Callable[..., ReturnType], property]'
)
162
163
@dataclass  # can't use slots here since we set attributes on `__post_init__`
class PydanticDescriptorProxy(Generic[ReturnType]):
    """Wrap a classmethod, staticmethod, property or unbound function
    and act as a descriptor that allows us to detect decorated items
    from the class' attributes.

    This class' __get__ returns the wrapped item's __get__ result,
    which makes it transparent for classmethods and staticmethods.

    Attributes:
        wrapped: The decorator that has to be wrapped.
        decorator_info: The decorator info.
        shim: A wrapper function to wrap V1 style function.
    """

    wrapped: DecoratedType[ReturnType]
    decorator_info: DecoratorInfo
    shim: Callable[[Callable[..., Any]], Callable[..., Any]] | None = None

    def __post_init__(self):
        # If the wrapped object exposes `setter`/`deleter` (e.g. a `property`), expose same-named
        # attributes on the proxy. They apply the call to `wrapped` and return the proxy itself,
        # so `@<name>.setter` in a class body keeps the proxy bound to that name.
        for attr in 'setter', 'deleter':
            if hasattr(self.wrapped, attr):
                f = partial(self._call_wrapped_attr, name=attr)
                setattr(self, attr, f)

    def _call_wrapped_attr(self, func: Callable[[Any], None], *, name: str) -> PydanticDescriptorProxy[ReturnType]:
        """Replace `wrapped` with `wrapped.<name>(func)` (e.g. `property.setter(func)`).

        Args:
            func: The function passed to the wrapped object's `name` method.
            name: The attribute to call on the wrapped object (`'setter'` or `'deleter'`).

        Returns:
            This proxy.
        """
        self.wrapped = getattr(self.wrapped, name)(func)
        if isinstance(self.wrapped, property):
            # update ComputedFieldInfo.wrapped_property
            from ..fields import ComputedFieldInfo

            if isinstance(self.decorator_info, ComputedFieldInfo):
                self.decorator_info.wrapped_property = self.wrapped
        return self

    def __get__(self, obj: object | None, obj_type: type[object] | None = None) -> PydanticDescriptorProxy[ReturnType]:
        """Delegate to the wrapped object's `__get__`.

        If an `AttributeError` is raised, the wrapped object is returned unchanged. This covers
        non-descriptors (e.g. a `partial` object). NOTE(review): it also swallows an
        `AttributeError` raised *inside* the wrapped `__get__`.
        """
        try:
            return self.wrapped.__get__(obj, obj_type)  # pyright: ignore[reportReturnType]
        except AttributeError:
            # not a descriptor, e.g. a partial object
            return self.wrapped  # type: ignore[return-value]

    def __set_name__(self, instance: Any, name: str) -> None:
        """Forward `__set_name__` to the wrapped object when it supports it."""
        if hasattr(self.wrapped, '__set_name__'):
            self.wrapped.__set_name__(instance, name)  # pyright: ignore[reportFunctionMemberAccess]

    def __getattr__(self, name: str, /) -> Any:
        """Forward checks for __isabstractmethod__ and such.

        Only called when normal attribute lookup fails; the lookup is delegated to `wrapped`.

        Raises:
            AttributeError: If `wrapped` has not been set on this instance, or does not have `name`.
        """
        try:
            wrapped = object.__getattribute__(self, 'wrapped')
        except AttributeError:
            # `wrapped` is missing, e.g. on an instance created via `__new__` by `copy`/`pickle`
            # whose state has not been restored yet. Using `self.wrapped` here would re-enter
            # `__getattr__('wrapped')` and recurse until `RecursionError`.
            raise AttributeError(f'{type(self).__name__!r} object has no attribute {name!r}') from None
        return getattr(wrapped, name)
213
214
# Type parameter of `Decorator`: the specific decorator-info container it carries.
DecoratorInfoType = TypeVar('DecoratorInfoType', bound=DecoratorInfo)
216
217
@dataclass(**slots_true)
class Decorator(Generic[DecoratorInfoType]):
    """Pair a decorator's metadata with the function it decorates, resolved on a concrete class.

    The metadata is known as soon as the decorator runs. The bound function only becomes
    available once the class exists, so the two are combined here for core-schema building.

    Attributes:
        cls_ref: The class ref.
        cls_var_name: The decorated function name.
        func: The decorated function.
        shim: A wrapper function to wrap V1 style function.
        info: The decorator info.
    """

    cls_ref: str
    cls_var_name: str
    func: Callable[..., Any]
    shim: Callable[[Any], Any] | None
    info: DecoratorInfoType

    @staticmethod
    def build(
        cls_: Any,
        *,
        cls_var_name: str,
        shim: Callable[[Any], Any] | None,
        info: DecoratorInfoType,
    ) -> Decorator[DecoratorInfoType]:
        """Resolve `cls_var_name` on `cls_` and wrap it in a new `Decorator`.

        Args:
            cls_: The class.
            cls_var_name: The decorated function name.
            shim: A wrapper function to wrap V1 style function, applied before unwrapping.
            info: The decorator info.

        Returns:
            The new decorator instance.
        """
        resolved = get_attribute_from_bases(cls_, cls_var_name)
        if shim is not None:
            resolved = shim(resolved)
        # Peel property/staticmethod/classmethod layers but keep partials intact.
        resolved = unwrap_wrapped_function(resolved, unwrap_partial=False)

        if not callable(resolved):
            # TODO most likely this branch can be removed when we drop support for Python 3.12:
            # This branch will get hit for classmethod properties
            raw_entry = get_attribute_from_base_dicts(cls_, cls_var_name)  # prevents the binding call to `__get__`
            if isinstance(raw_entry, PydanticDescriptorProxy):
                resolved = unwrap_wrapped_function(raw_entry.wrapped)

        return Decorator(
            cls_ref=get_type_ref(cls_),
            cls_var_name=cls_var_name,
            func=resolved,
            shim=shim,
            info=info,
        )

    def bind_to_cls(self, cls: Any) -> Decorator[DecoratorInfoType]:
        """Re-resolve this decorator against another class, e.g. a subclass.

        Args:
            cls: the class.

        Returns:
            A new decorator instance holding a shallow copy of `info`.
        """
        info_copy = copy(self.info)
        return self.build(cls, cls_var_name=self.cls_var_name, shim=self.shim, info=info_copy)
291
292
293def get_bases(tp: type[Any]) -> tuple[type[Any], ...]:
294 """Get the base classes of a class or typeddict.
295
296 Args:
297 tp: The type or class to get the bases.
298
299 Returns:
300 The base classes.
301 """
302 if is_typeddict(tp):
303 return tp.__orig_bases__ # type: ignore
304 try:
305 return tp.__bases__
306 except AttributeError:
307 return ()
308
309
310def mro(tp: type[Any]) -> tuple[type[Any], ...]:
311 """Calculate the Method Resolution Order of bases using the C3 algorithm.
312
313 See https://www.python.org/download/releases/2.3/mro/
314 """
315 # try to use the existing mro, for performance mainly
316 # but also because it helps verify the implementation below
317 if not is_typeddict(tp):
318 try:
319 return tp.__mro__
320 except AttributeError:
321 # GenericAlias and some other cases
322 pass
323
324 bases = get_bases(tp)
325 return (tp,) + mro_for_bases(bases)
326
327
def mro_for_bases(bases: tuple[type[Any], ...]) -> tuple[type[Any], ...]:
    """Compute the C3 linearization of a tuple of base classes.

    The result is the MRO a class with these `bases` would have, without the class itself.

    Args:
        bases: The direct base classes, in declaration order.

    Returns:
        The merged linearization of `bases`.

    Raises:
        TypeError: If no consistent C3 linearization exists.
    """

    def merge_seqs(seqs: list[deque[type[Any]]]) -> Iterable[type[Any]]:
        while True:
            non_empty = [seq for seq in seqs if seq]
            if not non_empty:
                # Nothing left to process, we're done.
                return
            candidate: type[Any] | None = None
            for seq in non_empty:  # Find merge candidates among seq heads.
                head = seq[0]
                # A head is acceptable only if it does not appear in the tail of any sequence.
                if not any(head in islice(s, 1, None) for s in non_empty):
                    candidate = head
                    break
            # Compare against `None` rather than testing truthiness: a class whose metaclass
            # defines `__bool__`/`__len__` can be falsy and must still be accepted.
            if candidate is None:
                raise TypeError('Inconsistent hierarchy, no C3 MRO is possible')
            yield candidate
            for seq in non_empty:
                # Remove candidate.
                if seq[0] == candidate:
                    seq.popleft()

    seqs = [deque(mro(base)) for base in bases] + [deque(bases)]
    return tuple(merge_seqs(seqs))
354
355
# Unique marker for "not present" in `__dict__.get` lookups, distinct from a stored `None`.
_sentinel = object()
357
358
def get_attribute_from_bases(tp: type[Any] | tuple[type[Any], ...], name: str) -> Any:
    """Fetch `name` from the first class in the MRO that defines it, as if accessed on the class.

    A plain `getattr` is tried first. If it fails, the MRO is walked explicitly. This supports
    TypedDict, which lacks a real `__mro__` but can have a virtual one built from its bases.

    Args:
        tp: The type or class to search for the attribute. If a tuple, this is treated as a set of base classes.
        name: The name of the attribute to retrieve.

    Returns:
        Any: The attribute value, if found. Descriptors found in a class `__dict__` are bound
        via `__get__(None, tp)`.

    Raises:
        AttributeError: If the attribute is not found in any class in the MRO.
    """
    if not isinstance(tp, tuple):
        try:
            return getattr(tp, name)
        except AttributeError:
            return get_attribute_from_bases(mro(tp), name)

    for klass in mro_for_bases(tp):
        found = klass.__dict__.get(name, _sentinel)
        if found is _sentinel:
            continue
        binder = getattr(found, '__get__', None)
        return found if binder is None else binder(None, tp)
    raise AttributeError(f'{name} not found in {tp}')
392
393
def get_attribute_from_base_dicts(tp: type[Any], name: str) -> Any:
    """Get an attribute out of the `__dict__` following the MRO.
    This prevents the call to `__get__` on the descriptor, and allows
    us to get the original function for classmethod properties.

    The MRO is walked from `tp` towards its most generic base, so the most specific
    definition wins, as with normal attribute lookup. Previously the walk was reversed,
    which let a base class definition shadow a subclass override.

    Args:
        tp: The type or class to search for the attribute.
        name: The name of the attribute to retrieve.

    Returns:
        Any: The raw (unbound) attribute value, if found.

    Raises:
        KeyError: If the attribute is not found in any class's `__dict__` in the MRO.
    """
    for base in mro(tp):
        # Read the raw namespace entry so descriptors are neither bound nor invoked.
        if name in base.__dict__:
            return base.__dict__[name]
    return tp.__dict__[name]  # raise the error
413
414
@dataclass(**slots_true)
class DecoratorInfos:
    """Mapping of name in the class namespace to decorator info.

    note that the name in the class namespace is the function or attribute name
    not the field name!
    """

    validators: dict[str, Decorator[ValidatorDecoratorInfo]] = field(default_factory=dict)
    field_validators: dict[str, Decorator[FieldValidatorDecoratorInfo]] = field(default_factory=dict)
    root_validators: dict[str, Decorator[RootValidatorDecoratorInfo]] = field(default_factory=dict)
    field_serializers: dict[str, Decorator[FieldSerializerDecoratorInfo]] = field(default_factory=dict)
    model_serializers: dict[str, Decorator[ModelSerializerDecoratorInfo]] = field(default_factory=dict)
    model_validators: dict[str, Decorator[ModelValidatorDecoratorInfo]] = field(default_factory=dict)
    computed_fields: dict[str, Decorator[ComputedFieldInfo]] = field(default_factory=dict)

    @classmethod
    def build(
        cls,
        typ: type[Any],
        # Default to `True` for backwards compatibility:
        replace_wrapped_methods: bool = True,
    ) -> Self:
        """Build a `DecoratorInfos` instance for the given model, dataclass or `TypedDict` type.

        Decorators from parent classes are included, including "bare" classes (e.g. if `typ`
        is a Pydantic model, non Pydantic parent model classes are also taken into account).
        The collection of the decorators happens by respecting the MRO.

        If one of the bases has an `__pydantic_decorators__` attribute set, it is assumed to be
        a `DecoratorInfos` instance and is used as-is. The `__pydantic_decorators__` attribute
        is *not* being set on the provided `typ`.

        Args:
            typ: The model, dataclass or `TypedDict` type to use when building the `DecoratorInfos` instance.
            replace_wrapped_methods: Whether to replace the decorator's wrapped methods on `typ`.
                This is useful e.g. for field validators which are initially class methods. This should
                only be set to `True` if `typ` is a Pydantic model or dataclass (otherwise this results
                in mutations of classes Pydantic doesn't "own").

        Raises:
            PydanticUserError: If several field serializers target the same field.
        """
        # The per-kind mappings, merged in this order for every class visited.
        categories = (
            'validators',
            'field_validators',
            'root_validators',
            'field_serializers',
            'model_serializers',
            'model_validators',
            'computed_fields',
        )
        # reminder: dicts are ordered and replacement does not alter the order
        merged = cls()

        # Visit parents from most generic to most specific, skipping `typ` itself and the
        # final MRO entry (`object`/`TypedDict`), so that later (more specific) entries win.
        for base in reversed(mro(typ)[1:-1]):
            inherited: DecoratorInfos | None = base.__dict__.get('__pydantic_decorators__')
            if inherited is None:
                inherited, _ = _decorator_infos_for_class(base, collect_to_replace=False)
            for category in categories:
                rebound = {key: dec.bind_to_cls(typ) for key, dec in getattr(inherited, category).items()}
                getattr(merged, category).update(rebound)

        own, to_replace = _decorator_infos_for_class(typ, collect_to_replace=True)
        for category in categories:
            getattr(merged, category).update(getattr(own, category))

        if replace_wrapped_methods and to_replace:
            for attr_name, replacement in to_replace:
                setattr(typ, attr_name, replacement)

        merged._validate()
        return merged

    def _validate(self) -> None:
        """Reject configurations where two field serializer entries target the same field name."""
        claimed: set[str] = set()
        for serializer in self.field_serializers.values():
            for target in serializer.info.fields:
                if target in claimed:
                    raise PydanticUserError(
                        f'Multiple field serializer functions were defined for field {target!r}, this is not allowed.',
                        code='multiple-field-serializers',
                    )
                claimed.add(target)

    def update_from_config(self, config_wrapper: ConfigWrapper) -> None:
        """Update the decorator infos from the configuration of the class they are attached to."""
        for attr_name, computed in self.computed_fields.items():
            computed.info._update_from_config(config_wrapper, attr_name)
503
504
def _decorator_infos_for_class(
    typ: type[Any],
    *,
    collect_to_replace: bool,
) -> tuple[DecoratorInfos, list[tuple[str, Any]]]:
    """Collect a `DecoratorInfos` for class, without looking into bases.

    Only attributes defined directly in `vars(typ)` that are `PydanticDescriptorProxy`
    instances are considered. Each one is routed to the mapping matching its `decorator_info` type.

    Args:
        typ: The class whose own namespace is scanned.
        collect_to_replace: Whether to also collect `(attribute name, proxy.wrapped)` pairs.

    Returns:
        The collected `DecoratorInfos`, and the list of `(name, wrapped)` pairs
        (empty when `collect_to_replace` is `False`).
    """
    res = DecoratorInfos()
    to_replace: list[tuple[str, Any]] = []

    for var_name, var_value in vars(typ).items():
        if not isinstance(var_value, PydanticDescriptorProxy):
            continue
        info = var_value.decorator_info
        if isinstance(info, ValidatorDecoratorInfo):
            res.validators[var_name] = Decorator.build(typ, cls_var_name=var_name, shim=var_value.shim, info=info)
        elif isinstance(info, FieldValidatorDecoratorInfo):
            res.field_validators[var_name] = Decorator.build(
                typ, cls_var_name=var_name, shim=var_value.shim, info=info
            )
        elif isinstance(info, RootValidatorDecoratorInfo):
            res.root_validators[var_name] = Decorator.build(
                typ, cls_var_name=var_name, shim=var_value.shim, info=info
            )
        elif isinstance(info, FieldSerializerDecoratorInfo):
            res.field_serializers[var_name] = Decorator.build(
                typ, cls_var_name=var_name, shim=var_value.shim, info=info
            )
        elif isinstance(info, ModelValidatorDecoratorInfo):
            res.model_validators[var_name] = Decorator.build(
                typ, cls_var_name=var_name, shim=var_value.shim, info=info
            )
        elif isinstance(info, ModelSerializerDecoratorInfo):
            res.model_serializers[var_name] = Decorator.build(
                typ, cls_var_name=var_name, shim=var_value.shim, info=info
            )
        else:
            from ..fields import ComputedFieldInfo

            # Every other `DecoratorInfo` member was handled above, so the remaining case must
            # be a computed field. Check the info object (not the proxy) for that invariant.
            assert isinstance(info, ComputedFieldInfo), f'unexpected decorator info: {info!r}'
            res.computed_fields[var_name] = Decorator.build(typ, cls_var_name=var_name, shim=None, info=info)
        if collect_to_replace:
            to_replace.append((var_name, var_value.wrapped))

    return res, to_replace
548
549
def inspect_validator(
    validator: Callable[..., Any], *, mode: FieldValidatorModes, type: Literal['field', 'model']
) -> bool:
    """Look at a field or model validator function and determine whether it takes an info argument.

    An error is raised if the function has an invalid signature.

    Args:
        validator: The validator function to inspect.
        mode: The proposed validator mode.
        type: The type of validator, either 'field' or 'model'.

    Returns:
        Whether the validator takes an info argument.

    Raises:
        PydanticUserError: If the number of required positional parameters is not 2 or 3
            for `mode='wrap'`, or not 1 or 2 for the other modes.
    """
    try:
        sig = signature_no_eval(validator)
    except (ValueError, TypeError):
        # `inspect.signature` might not be able to infer a signature, e.g. with C objects.
        # In this case, we assume no info argument is present:
        return False
    n_positional = count_positional_required_params(sig)
    if mode == 'wrap':
        # 3 positional parameters -> info argument; 2 -> no info argument.
        if n_positional == 3:
            return True
        elif n_positional == 2:
            return False
    else:
        assert mode in {'before', 'after', 'plain'}, f"invalid mode: {mode!r}, expected 'before', 'after' or 'plain'"
        # 2 positional parameters -> info argument; 1 -> no info argument.
        if n_positional == 2:
            return True
        elif n_positional == 1:
            return False

    raise PydanticUserError(
        f'Unrecognized {type} validator function signature for {validator} with `mode={mode}`: {sig}',
        code='validator-signature',
    )
588
589
def inspect_field_serializer(serializer: Callable[..., Any], mode: Literal['plain', 'wrap']) -> tuple[bool, bool]:
    """Classify a field serializer: is it an instance method, and does it accept an info argument?

    A function whose first parameter is named `self` is treated as a field serializer method.
    That leading `self` is then excluded when counting positional parameters.

    Args:
        serializer: The serializer function to inspect.
        mode: The serializer mode, either 'plain' or 'wrap'.

    Returns:
        Tuple of (is_field_serializer, info_arg).

    Raises:
        PydanticUserError: If the signature does not match any accepted shape for `mode`.
    """
    try:
        sig = signature_no_eval(serializer)
    except (ValueError, TypeError):
        # No signature available (e.g. C objects): assume a plain function without info argument.
        return (False, False)

    leading = next(iter(sig.parameters.values()), None)
    is_field_serializer = leading is not None and leading.name == 'self'

    positional_count = count_positional_required_params(sig)
    info_arg = _serializer_info_arg(mode, positional_count - 1 if is_field_serializer else positional_count)
    if info_arg is None:
        raise PydanticUserError(
            f'Unrecognized field_serializer function signature for {serializer} with `mode={mode}`:{sig}',
            code='field-serializer-signature',
        )

    return is_field_serializer, info_arg
627
628
def inspect_annotated_serializer(serializer: Callable[..., Any], mode: Literal['plain', 'wrap']) -> bool:
    """Decide whether a serializer attached through `Annotated` accepts an info argument.

    Args:
        serializer: The serializer function to check.
        mode: The serializer mode, either 'plain' or 'wrap'.

    Returns:
        info_arg

    Raises:
        PydanticUserError: If the signature does not match any accepted shape for `mode`.
    """
    try:
        sig = signature_no_eval(serializer)
    except (ValueError, TypeError):
        # No signature available (e.g. C objects): assume no info argument.
        return False

    info_arg = _serializer_info_arg(mode, count_positional_required_params(sig))
    if info_arg is not None:
        return info_arg
    raise PydanticUserError(
        f'Unrecognized field_serializer function signature for {serializer} with `mode={mode}`:{sig}',
        code='field-serializer-signature',
    )
655
656
def inspect_model_serializer(serializer: Callable[..., Any], mode: Literal['plain', 'wrap']) -> bool:
    """Validate a `@model_serializer` target and decide whether it accepts an info argument.

    Args:
        serializer: The serializer function to check.
        mode: The serializer mode, either 'plain' or 'wrap'.

    Returns:
        `info_arg` - whether the function expects an info argument.

    Raises:
        PydanticUserError: If `serializer` is a staticmethod/classmethod, its first parameter
            is not `self`, or its signature does not match an accepted shape for `mode`.
    """
    is_static_or_class = isinstance(serializer, (staticmethod, classmethod))
    if is_static_or_class or not is_instance_method_from_sig(serializer):
        raise PydanticUserError(
            '`@model_serializer` must be applied to instance methods', code='model-serializer-instance-method'
        )

    sig = signature_no_eval(serializer)
    info_arg = _serializer_info_arg(mode, count_positional_required_params(sig))
    if info_arg is not None:
        return info_arg
    raise PydanticUserError(
        f'Unrecognized model_serializer function signature for {serializer} with `mode={mode}`:{sig}',
        code='model-serializer-signature',
    )
683
684
685def _serializer_info_arg(mode: Literal['plain', 'wrap'], n_positional: int) -> bool | None:
686 if mode == 'plain':
687 if n_positional == 1:
688 # (input_value: Any, /) -> Any
689 return False
690 elif n_positional == 2:
691 # (model: Any, input_value: Any, /) -> Any
692 return True
693 else:
694 assert mode == 'wrap', f"invalid mode: {mode!r}, expected 'plain' or 'wrap'"
695 if n_positional == 2:
696 # (input_value: Any, serializer: SerializerFunctionWrapHandler, /) -> Any
697 return False
698 elif n_positional == 3:
699 # (input_value: Any, serializer: SerializerFunctionWrapHandler, info: SerializationInfo, /) -> Any
700 return True
701
702 return None
703
704
# Objects accepted by the signature-inspection helpers below (raw functions or method wrappers).
AnyDecoratorCallable: TypeAlias = (
    'Union[classmethod[Any, Any, Any], staticmethod[Any, Any], partialmethod[Any], Callable[..., Any]]'
)
708
709
def is_instance_method_from_sig(function: AnyDecoratorCallable) -> bool:
    """Report whether `function` looks like an instance method.

    Wrappers are peeled with `unwrap_wrapped_function` first; the check is purely by name:
    the first parameter must be called `self`.

    Args:
        function: The function to check.

    Returns:
        `True` if the function is an instance method, `False` otherwise.
    """
    parameters = signature_no_eval(unwrap_wrapped_function(function)).parameters
    leading = next(iter(parameters.values()), None)
    return leading is not None and leading.name == 'self'
727
728
def ensure_classmethod_based_on_signature(function: AnyDecoratorCallable) -> Any:
    """Wrap `function` in `classmethod` when its first parameter is `cls` and it is not one already.

    Args:
        function: The function to apply the decorator on.

    Return:
        Either `classmethod(function)` or `function` unchanged.
    """
    # Look through partial/partialmethod layers only, so an existing classmethod stays visible.
    already_classmethod = isinstance(
        unwrap_wrapped_function(function, unwrap_class_static_method=False), classmethod
    )
    if already_classmethod or not _is_classmethod_from_sig(function):
        return function
    return classmethod(function)  # type: ignore[arg-type]
743
744
def _is_classmethod_from_sig(function: AnyDecoratorCallable) -> bool:
    """Return whether the unwrapped `function` takes a first parameter named `cls`."""
    parameters = signature_no_eval(unwrap_wrapped_function(function)).parameters
    leading = next(iter(parameters.values()), None)
    return leading is not None and leading.name == 'cls'
751
752
def unwrap_wrapped_function(
    func: Any,
    *,
    unwrap_partial: bool = True,
    unwrap_class_static_method: bool = True,
) -> Any:
    """Repeatedly strip wrapper layers until an unwrappable object remains.

    `property` and `cached_property` are always peeled. A `property` yields its `fget`.

    Args:
        func: The function to unwrap.
        unwrap_partial: If True (default), unwrap partial and partialmethod decorators.
        unwrap_class_static_method: If True (default), also unwrap classmethod and staticmethod
            decorators. If False, only unwrap partial and partialmethod decorators.

    Returns:
        The underlying function of the wrapped function.
    """
    peelable = (property, cached_property)
    if unwrap_partial:
        peelable += (partial, partialmethod)
    if unwrap_class_static_method:
        peelable += (staticmethod, classmethod)

    current = func
    while isinstance(current, peelable):
        if unwrap_class_static_method and isinstance(current, (classmethod, staticmethod)):
            current = current.__func__
        elif isinstance(current, (partial, partialmethod)):
            current = current.func
        elif isinstance(current, property):
            current = current.fget  # arbitrary choice, convenient for computed fields
        else:
            # Only `cached_property` can remain at this point.
            assert isinstance(current, cached_property)
            current = current.func  # type: ignore
    return current
791
792
# Objects `get_callable_return_type` inspects directly. Any other callable is treated as an
# instance, and its type's `__call__` is inspected instead.
_function_like = (
    partial,
    partialmethod,
    types.FunctionType,
    types.BuiltinFunctionType,
    types.MethodType,
    types.WrapperDescriptorType,
    types.MethodWrapperType,
    types.MemberDescriptorType,
)
803
804
def get_callable_return_type(
    callable_obj: Any,
    globalns: GlobalsNamespace | None = None,
    localns: MappingNamespace | None = None,
) -> Any | PydanticUndefinedType:
    """Determine the annotated return type of a callable.

    Classes are their own return type. For callable instances (anything not in `_function_like`),
    the type's `__call__` is inspected.

    Args:
        callable_obj: The callable to analyze.
        globalns: The globals namespace to use during type annotation evaluation.
        localns: The locals namespace to use during type annotation evaluation.

    Returns:
        The function return type, or `PydanticUndefined` when no return annotation is found.
    """
    if isinstance(callable_obj, type):
        # Calling a class produces an instance of it (e.g. `int()` -> `int`).
        return callable_obj

    target = callable_obj
    if not isinstance(target, _function_like):
        dunder_call = getattr(type(target), '__call__', None)  # noqa: B004
        if dunder_call is not None:
            target = dunder_call

    annotations_found = get_function_type_hints(
        unwrap_wrapped_function(target),
        include_keys={'return'},
        globalns=globalns,
        localns=localns,
    )
    return annotations_found.get('return', PydanticUndefined)
837
838
def count_positional_required_params(sig: Signature) -> int:
    """Count the positional parameters of a validation/serialization function signature.

    Only parameters that `can_be_positional` accepts are counted. Apart from the first one,
    a parameter must also lack a default. The first parameter is the value being validated or
    serialized, and may have a default (e.g. `float`, whose signature is `(x=0, /)`). Any other
    parameter, such as the info argument, is expected to be required.

    Returns:
        The number of positional arguments of a signature.
    """
    total = 0
    for index, param in enumerate(sig.parameters.values()):
        if not can_be_positional(param):
            continue
        if index == 0 or param.default is Parameter.empty:
            total += 1
    return total
859
860
def ensure_property(f: Any) -> Any:
    """Return `f` untouched if it is already a descriptor, otherwise wrap it in `property`.

    Method descriptors (such as `cached_property`) and data descriptors (such as `property`)
    pass through unchanged. Plain functions become a `property` with `f` as getter.

    Args:
        f: The function to check.

    Returns:
        `f` itself, or `property(f)`.
    """
    is_descriptor = ismethoddescriptor(f) or isdatadescriptor(f)
    return f if is_descriptor else property(f)