1"""Logic related to validators applied to models etc. via the `@field_validator` and `@model_validator` decorators."""
2
3from __future__ import annotations as _annotations
4
5import types
6from collections import deque
7from collections.abc import Iterable
8from dataclasses import dataclass, field
9from functools import cached_property, partial, partialmethod
10from inspect import Parameter, Signature, isdatadescriptor, ismethoddescriptor, signature
11from itertools import islice
12from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generic, Literal, TypeVar, Union
13
14from pydantic_core import PydanticUndefined, PydanticUndefinedType, core_schema
15from typing_extensions import TypeAlias, is_typeddict
16
17from ..errors import PydanticUserError
18from ._core_utils import get_type_ref
19from ._internal_dataclass import slots_true
20from ._namespace_utils import GlobalsNamespace, MappingNamespace
21from ._typing_extra import get_function_type_hints
22from ._utils import can_be_positional
23
24if TYPE_CHECKING:
25 from ..fields import ComputedFieldInfo
26 from ..functional_validators import FieldValidatorModes
27
28
@dataclass(**slots_true)
class ValidatorDecoratorInfo:
    """A container for data from `@validator` so that we can access it
    while building the pydantic-core schema.

    Attributes:
        decorator_repr: A class variable representing the decorator string, '@validator'.
        fields: A tuple of field names the validator should be called on.
        mode: The proposed validator mode, either `'before'` or `'after'`.
        each_item: For complex objects (sets, lists etc.) whether to validate individual
            elements rather than the whole object.
        always: Whether this method and other validators should be called even if the value is missing.
        check_fields: Whether to check that the fields actually exist on the model.
    """

    # ClassVar, so it is not a dataclass field and is not accepted by `__init__`.
    decorator_repr: ClassVar[str] = '@validator'

    fields: tuple[str, ...]
    mode: Literal['before', 'after']
    each_item: bool
    always: bool
    check_fields: bool | None
51
52
@dataclass(**slots_true)
class FieldValidatorDecoratorInfo:
    """A container for data from `@field_validator` so that we can access it
    while building the pydantic-core schema.

    Attributes:
        decorator_repr: A class variable representing the decorator string, '@field_validator'.
        fields: A tuple of field names the validator should be called on.
        mode: The proposed validator mode.
        check_fields: Whether to check that the fields actually exist on the model.
        json_schema_input_type: The input type of the function. This is only used to generate
            the appropriate JSON Schema (in validation mode) and can only be specified
            when `mode` is either `'before'`, `'plain'` or `'wrap'`.
    """

    # ClassVar, so it is not a dataclass field and is not accepted by `__init__`.
    decorator_repr: ClassVar[str] = '@field_validator'

    fields: tuple[str, ...]
    mode: FieldValidatorModes
    check_fields: bool | None
    json_schema_input_type: Any
74
75
@dataclass(**slots_true)
class RootValidatorDecoratorInfo:
    """A container for data from `@root_validator` so that we can access it
    while building the pydantic-core schema.

    Attributes:
        decorator_repr: A class variable representing the decorator string, '@root_validator'.
        mode: The proposed validator mode, either `'before'` or `'after'`.
    """

    # ClassVar, so it is not a dataclass field and is not accepted by `__init__`.
    decorator_repr: ClassVar[str] = '@root_validator'
    mode: Literal['before', 'after']
88
89
@dataclass(**slots_true)
class FieldSerializerDecoratorInfo:
    """A container for data from `@field_serializer` so that we can access it
    while building the pydantic-core schema.

    Attributes:
        decorator_repr: A class variable representing the decorator string, '@field_serializer'.
        fields: A tuple of field names the serializer should be called on.
        mode: The proposed serializer mode, either `'plain'` or `'wrap'`.
        return_type: The type of the serializer's return value.
        when_used: The serialization condition. Accepts a string with values `'always'`, `'unless-none'`, `'json'`,
            and `'json-unless-none'`.
        check_fields: Whether to check that the fields actually exist on the model.
    """

    # ClassVar, so it is not a dataclass field and is not accepted by `__init__`.
    decorator_repr: ClassVar[str] = '@field_serializer'
    fields: tuple[str, ...]
    mode: Literal['plain', 'wrap']
    return_type: Any
    when_used: core_schema.WhenUsed
    check_fields: bool | None
111
112
@dataclass(**slots_true)
class ModelSerializerDecoratorInfo:
    """A container for data from `@model_serializer` so that we can access it
    while building the pydantic-core schema.

    Attributes:
        decorator_repr: A class variable representing the decorator string, '@model_serializer'.
        mode: The proposed serializer mode, either `'plain'` or `'wrap'`.
        return_type: The type of the serializer's return value.
        when_used: The serialization condition. Accepts a string with values `'always'`, `'unless-none'`, `'json'`,
            and `'json-unless-none'`.
    """

    # ClassVar, so it is not a dataclass field and is not accepted by `__init__`.
    decorator_repr: ClassVar[str] = '@model_serializer'
    mode: Literal['plain', 'wrap']
    return_type: Any
    when_used: core_schema.WhenUsed
130
131
@dataclass(**slots_true)
class ModelValidatorDecoratorInfo:
    """A container for data from `@model_validator` so that we can access it
    while building the pydantic-core schema.

    Attributes:
        decorator_repr: A class variable representing the decorator string, '@model_validator'.
        mode: The proposed validator mode, one of `'wrap'`, `'before'` or `'after'`.
    """

    # ClassVar, so it is not a dataclass field and is not accepted by `__init__`.
    decorator_repr: ClassVar[str] = '@model_validator'
    mode: Literal['wrap', 'before', 'after']
144
145
# Union of every decorator-info container that can be attached to a decorated class attribute.
# Spelled as a string because `ComputedFieldInfo` is only imported under `TYPE_CHECKING`.
DecoratorInfo: TypeAlias = """Union[
    ValidatorDecoratorInfo,
    FieldValidatorDecoratorInfo,
    RootValidatorDecoratorInfo,
    FieldSerializerDecoratorInfo,
    ModelSerializerDecoratorInfo,
    ModelValidatorDecoratorInfo,
    ComputedFieldInfo,
]"""

# Return type of the callable being decorated.
ReturnType = TypeVar('ReturnType')
# The kinds of objects a `PydanticDescriptorProxy` may wrap (string alias, only meaningful to type checkers).
DecoratedType: TypeAlias = (
    'Union[classmethod[Any, Any, ReturnType], staticmethod[Any, ReturnType], Callable[..., ReturnType], property]'
)
160
161
@dataclass  # can't use slots here since we set attributes on `__post_init__`
class PydanticDescriptorProxy(Generic[ReturnType]):
    """Wrap a classmethod, staticmethod, property or unbound function
    and act as a descriptor that allows us to detect decorated items
    from the class' attributes.

    This class' __get__ returns the wrapped item's __get__ result,
    which makes it transparent for classmethods and staticmethods.

    Attributes:
        wrapped: The decorator that has to be wrapped.
        decorator_info: The decorator info.
        shim: A wrapper function to wrap V1 style function.
    """

    wrapped: DecoratedType[ReturnType]
    decorator_info: DecoratorInfo
    shim: Callable[[Callable[..., Any]], Callable[..., Any]] | None = None

    def __post_init__(self):
        # If the wrapped object offers `setter`/`deleter` (e.g. a `property`), expose the same names on
        # the proxy so `@proxy.setter` keeps working; the call is routed through `_call_wrapped_attr`.
        for attr in 'setter', 'deleter':
            if hasattr(self.wrapped, attr):
                f = partial(self._call_wrapped_attr, name=attr)
                setattr(self, attr, f)

    def _call_wrapped_attr(self, func: Callable[[Any], None], *, name: str) -> PydanticDescriptorProxy[ReturnType]:
        """Apply the wrapped object's `name` method (`setter`/`deleter`) to `func` and keep the proxy.

        Args:
            func: The function passed to e.g. `@proxy.setter`.
            name: The attribute of the wrapped object to call (`'setter'` or `'deleter'`).

        Returns:
            This proxy, now wrapping the object returned by `wrapped.<name>(func)`.
        """
        self.wrapped = getattr(self.wrapped, name)(func)
        if isinstance(self.wrapped, property):
            # update ComputedFieldInfo.wrapped_property
            from ..fields import ComputedFieldInfo

            if isinstance(self.decorator_info, ComputedFieldInfo):
                self.decorator_info.wrapped_property = self.wrapped
        return self

    def __get__(self, obj: object | None, obj_type: type[object] | None = None) -> PydanticDescriptorProxy[ReturnType]:
        """Delegate descriptor access to the wrapped object, or return it as-is if it is not a descriptor."""
        try:
            return self.wrapped.__get__(obj, obj_type)
        except AttributeError:
            # not a descriptor, e.g. a partial object
            return self.wrapped  # type: ignore[return-value]

    def __set_name__(self, instance: Any, name: str) -> None:
        """Forward `__set_name__` to the wrapped object when it supports it."""
        if hasattr(self.wrapped, '__set_name__'):
            self.wrapped.__set_name__(instance, name)  # pyright: ignore[reportFunctionMemberAccess]

    def __getattr__(self, name: str, /) -> Any:
        """Forward checks for __isabstractmethod__ and such.

        Raises:
            AttributeError: If the wrapped object lacks `name`, or if the proxy itself has
                no `wrapped` attribute yet.
        """
        if name == 'wrapped':
            # `wrapped` is only missing on an instance whose `__init__` never ran (e.g. one built via
            # `cls.__new__` by `copy`/`pickle`, which then probe for `__setstate__`). Without this guard,
            # `self.wrapped` below would re-enter `__getattr__` and recurse until `RecursionError`.
            raise AttributeError(f'{type(self).__name__!r} object has no attribute {name!r}')
        return getattr(self.wrapped, name)
211
212
# The concrete decorator-info type carried by a `Decorator` (any member of `DecoratorInfo`).
DecoratorInfoType = TypeVar('DecoratorInfoType', bound=DecoratorInfo)
214
215
@dataclass(**slots_true)
class Decorator(Generic[DecoratorInfoType]):
    """Pairs decorator metadata with the function it decorates, resolved on a concrete class.

    The metadata is known as soon as the decorator runs, while the function can only be
    resolved once the owning class exists; this container joins the two so the
    core-schema builder has both.

    Attributes:
        cls_ref: The class ref.
        cls_var_name: The decorated function name.
        func: The decorated function.
        shim: A wrapper function to wrap V1 style function.
        info: The decorator info.
    """

    cls_ref: str
    cls_var_name: str
    func: Callable[..., Any]
    shim: Callable[[Any], Any] | None
    info: DecoratorInfoType

    @staticmethod
    def build(
        cls_: Any,
        *,
        cls_var_name: str,
        shim: Callable[[Any], Any] | None,
        info: DecoratorInfoType,
    ) -> Decorator[DecoratorInfoType]:
        """Resolve `cls_var_name` on `cls_` and package it with `info`.

        Args:
            cls_: The class the function is looked up on.
            cls_var_name: The decorated function name.
            shim: Optional wrapper applied to the looked-up function (V1 style).
            info: The decorator info.

        Returns:
            A `Decorator` whose `func` is the resolved (and possibly shimmed) function.
        """
        resolved = get_attribute_from_bases(cls_, cls_var_name)
        if shim is not None:
            resolved = shim(resolved)
        resolved = unwrap_wrapped_function(resolved, unwrap_partial=False)

        if not callable(resolved):
            # Reached for classmethod properties: attribute access evaluated them. Read the raw
            # `__dict__` entry instead so no `__get__` runs, and unwrap the proxy's target.
            raw_attr = get_attribute_from_base_dicts(cls_, cls_var_name)
            if isinstance(raw_attr, PydanticDescriptorProxy):
                resolved = unwrap_wrapped_function(raw_attr.wrapped)

        return Decorator(
            cls_ref=get_type_ref(cls_),
            cls_var_name=cls_var_name,
            func=resolved,
            shim=shim,
            info=info,
        )

    def bind_to_cls(self, cls: Any) -> Decorator[DecoratorInfoType]:
        """Rebuild this decorator with its function resolved on `cls`.

        Args:
            cls: the class.

        Returns:
            A fresh `Decorator` with the same name, shim and info.
        """
        return self.build(cls, cls_var_name=self.cls_var_name, shim=self.shim, info=self.info)
288
289
290def get_bases(tp: type[Any]) -> tuple[type[Any], ...]:
291 """Get the base classes of a class or typeddict.
292
293 Args:
294 tp: The type or class to get the bases.
295
296 Returns:
297 The base classes.
298 """
299 if is_typeddict(tp):
300 return tp.__orig_bases__ # type: ignore
301 try:
302 return tp.__bases__
303 except AttributeError:
304 return ()
305
306
307def mro(tp: type[Any]) -> tuple[type[Any], ...]:
308 """Calculate the Method Resolution Order of bases using the C3 algorithm.
309
310 See https://www.python.org/download/releases/2.3/mro/
311 """
312 # try to use the existing mro, for performance mainly
313 # but also because it helps verify the implementation below
314 if not is_typeddict(tp):
315 try:
316 return tp.__mro__
317 except AttributeError:
318 # GenericAlias and some other cases
319 pass
320
321 bases = get_bases(tp)
322 return (tp,) + mro_for_bases(bases)
323
324
def mro_for_bases(bases: tuple[type[Any], ...]) -> tuple[type[Any], ...]:
    """Compute the C3 linearization of the given base classes.

    The result is the MRO a class with exactly these direct `bases` would have,
    minus the class itself.

    Args:
        bases: The direct base classes, in declaration order.

    Returns:
        The merged linearization of `bases` (empty for no bases).

    Raises:
        TypeError: If no consistent C3 linearization exists.
    """

    def merge_seqs(seqs: list[deque[type[Any]]]) -> Iterable[type[Any]]:
        while True:
            non_empty = [seq for seq in seqs if seq]
            if not non_empty:
                # Nothing left to process, we're done.
                return
            candidate: type[Any] | None = None
            for seq in non_empty:  # Find merge candidates among seq heads.
                head = seq[0]
                # A head is acceptable only if it does not appear in the tail of any sequence.
                if not any(head in islice(s, 1, None) for s in non_empty):
                    candidate = head
                    break
            # Compare against None explicitly: a class object can be falsy when its
            # metaclass defines `__bool__` or `__len__`.
            if candidate is None:
                raise TypeError('Inconsistent hierarchy, no C3 MRO is possible')
            yield candidate
            for seq in non_empty:
                # Remove candidate.
                if seq[0] == candidate:
                    seq.popleft()

    seqs = [deque(mro(base)) for base in bases] + [deque(bases)]
    return tuple(merge_seqs(seqs))
351
352
# Unique marker for "attribute not present in `__dict__`", distinct from an attribute whose value is None.
_sentinel = object()
354
355
356def get_attribute_from_bases(tp: type[Any] | tuple[type[Any], ...], name: str) -> Any:
357 """Get the attribute from the next class in the MRO that has it,
358 aiming to simulate calling the method on the actual class.
359
360 The reason for iterating over the mro instead of just getting
361 the attribute (which would do that for us) is to support TypedDict,
362 which lacks a real __mro__, but can have a virtual one constructed
363 from its bases (as done here).
364
365 Args:
366 tp: The type or class to search for the attribute. If a tuple, this is treated as a set of base classes.
367 name: The name of the attribute to retrieve.
368
369 Returns:
370 Any: The attribute value, if found.
371
372 Raises:
373 AttributeError: If the attribute is not found in any class in the MRO.
374 """
375 if isinstance(tp, tuple):
376 for base in mro_for_bases(tp):
377 attribute = base.__dict__.get(name, _sentinel)
378 if attribute is not _sentinel:
379 attribute_get = getattr(attribute, '__get__', None)
380 if attribute_get is not None:
381 return attribute_get(None, tp)
382 return attribute
383 raise AttributeError(f'{name} not found in {tp}')
384 else:
385 try:
386 return getattr(tp, name)
387 except AttributeError:
388 return get_attribute_from_bases(mro(tp), name)
389
390
def get_attribute_from_base_dicts(tp: type[Any], name: str) -> Any:
    """Get an attribute out of the `__dict__` following the MRO.

    This prevents the call to `__get__` on the descriptor, and allows
    us to get the original function for classmethod properties.

    Classes are searched in MRO order (`tp` first), so the nearest definition wins,
    just as with normal attribute lookup. This matters when a subclass overrides an
    attribute defined on a base class.

    Args:
        tp: The type or class to search for the attribute.
        name: The name of the attribute to retrieve.

    Returns:
        Any: The raw `__dict__` value, if found.

    Raises:
        KeyError: If the attribute is not found in any class's `__dict__` in the MRO.
    """
    for base in mro(tp):
        base_dict = base.__dict__
        if name in base_dict:
            return base_dict[name]
    raise KeyError(name)
410
411
@dataclass(**slots_true)
class DecoratorInfos:
    """Mapping of name in the class namespace to decorator info.

    Note that the name in the class namespace is the function or attribute name,
    not the field name!
    """

    validators: dict[str, Decorator[ValidatorDecoratorInfo]] = field(default_factory=dict)
    field_validators: dict[str, Decorator[FieldValidatorDecoratorInfo]] = field(default_factory=dict)
    root_validators: dict[str, Decorator[RootValidatorDecoratorInfo]] = field(default_factory=dict)
    field_serializers: dict[str, Decorator[FieldSerializerDecoratorInfo]] = field(default_factory=dict)
    model_serializers: dict[str, Decorator[ModelSerializerDecoratorInfo]] = field(default_factory=dict)
    model_validators: dict[str, Decorator[ModelValidatorDecoratorInfo]] = field(default_factory=dict)
    computed_fields: dict[str, Decorator[ComputedFieldInfo]] = field(default_factory=dict)

    @staticmethod
    def build(model_dc: type[Any]) -> DecoratorInfos:  # noqa: C901 (ignore complexity)
        """Collect the decorators defined on `model_dc` and on all of its bases.

        We want to collect all `PydanticDescriptorProxy` instances that exist as
        attributes in the namespace of the class (a BaseModel or dataclass) that
        called us, and we want them in the order of the bases. So instead of getting
        them all from the leaf class, we traverse the bases from root (the oldest
        ancestor class) to leaf. We collect all of the instances as we go, replacing
        any duplicate with the last one we see to mimic how function overriding works
        with inheritance. A replacement takes the position of the function it
        replaced, so the order is maintained.

        Side effect: if `model_dc` defines any decorated attributes, the result is
        stored on `model_dc.__pydantic_decorators__`. Each proxy attribute is then
        replaced on the class by the object it wraps.

        Args:
            model_dc: The class to collect decorators for.

        Returns:
            The collected decorators, bound to `model_dc`.

        Raises:
            PydanticUserError: If two differently-named `@field_serializer` functions
                target the same field (same-name overrides in subclasses are allowed).
        """
        # reminder: dicts are ordered and replacement does not alter the order
        res = DecoratorInfos()
        for base in reversed(mro(model_dc)[1:]):
            # Reuse decorators already collected for `base`, otherwise collect them now.
            existing: DecoratorInfos | None = base.__dict__.get('__pydantic_decorators__')
            if existing is None:
                existing = DecoratorInfos.build(base)
            # Re-bind inherited decorators so their function is resolved on `model_dc`.
            res.validators.update({k: v.bind_to_cls(model_dc) for k, v in existing.validators.items()})
            res.field_validators.update({k: v.bind_to_cls(model_dc) for k, v in existing.field_validators.items()})
            res.root_validators.update({k: v.bind_to_cls(model_dc) for k, v in existing.root_validators.items()})
            res.field_serializers.update({k: v.bind_to_cls(model_dc) for k, v in existing.field_serializers.items()})
            res.model_serializers.update({k: v.bind_to_cls(model_dc) for k, v in existing.model_serializers.items()})
            res.model_validators.update({k: v.bind_to_cls(model_dc) for k, v in existing.model_validators.items()})
            res.computed_fields.update({k: v.bind_to_cls(model_dc) for k, v in existing.computed_fields.items()})

        to_replace: list[tuple[str, Any]] = []

        for var_name, var_value in vars(model_dc).items():
            if isinstance(var_value, PydanticDescriptorProxy):
                info = var_value.decorator_info
                if isinstance(info, ValidatorDecoratorInfo):
                    res.validators[var_name] = Decorator.build(
                        model_dc, cls_var_name=var_name, shim=var_value.shim, info=info
                    )
                elif isinstance(info, FieldValidatorDecoratorInfo):
                    res.field_validators[var_name] = Decorator.build(
                        model_dc, cls_var_name=var_name, shim=var_value.shim, info=info
                    )
                elif isinstance(info, RootValidatorDecoratorInfo):
                    res.root_validators[var_name] = Decorator.build(
                        model_dc, cls_var_name=var_name, shim=var_value.shim, info=info
                    )
                elif isinstance(info, FieldSerializerDecoratorInfo):
                    # check whether a serializer function is already registered for fields
                    for field_serializer_decorator in res.field_serializers.values():
                        # check that each field has at most one serializer function.
                        # serializer functions for the same field in subclasses are allowed,
                        # and are treated as overrides
                        if field_serializer_decorator.cls_var_name == var_name:
                            continue
                        for f in info.fields:
                            if f in field_serializer_decorator.info.fields:
                                raise PydanticUserError(
                                    'Multiple field serializer functions were defined '
                                    f'for field {f!r}, this is not allowed.',
                                    code='multiple-field-serializers',
                                )
                    res.field_serializers[var_name] = Decorator.build(
                        model_dc, cls_var_name=var_name, shim=var_value.shim, info=info
                    )
                elif isinstance(info, ModelValidatorDecoratorInfo):
                    res.model_validators[var_name] = Decorator.build(
                        model_dc, cls_var_name=var_name, shim=var_value.shim, info=info
                    )
                elif isinstance(info, ModelSerializerDecoratorInfo):
                    res.model_serializers[var_name] = Decorator.build(
                        model_dc, cls_var_name=var_name, shim=var_value.shim, info=info
                    )
                else:
                    from ..fields import ComputedFieldInfo

                    # The only remaining member of the `DecoratorInfo` union.
                    assert isinstance(info, ComputedFieldInfo)
                    res.computed_fields[var_name] = Decorator.build(
                        model_dc, cls_var_name=var_name, shim=None, info=info
                    )
                to_replace.append((var_name, var_value.wrapped))
        if to_replace:
            # If we can save `__pydantic_decorators__` on the class we'll be able to check for it above
            # so then we don't need to re-process the type, which means we can discard our descriptor wrappers
            # and replace them with the thing they are wrapping (see the other setattr call below)
            # which allows validator class methods to also function as regular class methods
            model_dc.__pydantic_decorators__ = res
            for name, value in to_replace:
                setattr(model_dc, name, value)
        return res
516
517
def inspect_validator(validator: Callable[..., Any], mode: FieldValidatorModes) -> bool:
    """Look at a field or model validator function and determine whether it takes an info argument.

    An error is raised if the function has an invalid signature.

    Args:
        validator: The validator function to inspect.
        mode: The proposed validator mode.

    Returns:
        Whether the validator takes an info argument.

    Raises:
        PydanticUserError: If the number of required positional parameters does not
            match an accepted signature for `mode`.
    """
    try:
        sig = signature(validator)
    except (ValueError, TypeError):
        # `inspect.signature` might not be able to infer a signature, e.g. with C objects.
        # In this case, we assume no info argument is present:
        return False
    n_positional = count_positional_required_params(sig)
    if mode == 'wrap':
        # Wrap validators take two positional params, plus one more when they accept info.
        if n_positional == 3:
            return True
        elif n_positional == 2:
            return False
    else:
        assert mode in {'before', 'after', 'plain'}, f"invalid mode: {mode!r}, expected 'before', 'after' or 'plain'"
        # Other modes take one positional param, plus one more when they accept info.
        if n_positional == 2:
            return True
        elif n_positional == 1:
            return False

    raise PydanticUserError(
        f'Unrecognized field_validator function signature for {validator} with `mode={mode}`:{sig}',
        code='validator-signature',
    )
553
554
def inspect_field_serializer(serializer: Callable[..., Any], mode: Literal['plain', 'wrap']) -> tuple[bool, bool]:
    """Classify a field serializer: is it a method (first param `self`), and does it want `info`?

    An error is raised if the function has an invalid signature.

    Args:
        serializer: The serializer function to inspect.
        mode: The serializer mode, either 'plain' or 'wrap'.

    Returns:
        Tuple of (is_field_serializer, info_arg). If no signature can be obtained,
        both are False.

    Raises:
        PydanticUserError: If the positional parameter count fits neither form for `mode`.
    """
    try:
        sig = signature(serializer)
    except (ValueError, TypeError):
        # No introspectable signature (e.g. C objects): assume a plain function without info.
        return (False, False)

    first_param = next(iter(sig.parameters.values()), None)
    takes_self = first_param is not None and first_param.name == 'self'

    # `self` is not one of the serializer's own arguments, so leave it out of the count.
    positional = count_positional_required_params(sig)
    if takes_self:
        positional -= 1
    info_arg = _serializer_info_arg(mode, positional)

    if info_arg is None:
        raise PydanticUserError(
            f'Unrecognized field_serializer function signature for {serializer} with `mode={mode}`:{sig}',
            code='field-serializer-signature',
        )
    return takes_self, info_arg
592
593
def inspect_annotated_serializer(serializer: Callable[..., Any], mode: Literal['plain', 'wrap']) -> bool:
    """Decide whether an `Annotated`-attached serializer function expects an info argument.

    An error is raised if the function has an invalid signature.

    Args:
        serializer: The serializer function to check.
        mode: The serializer mode, either 'plain' or 'wrap'.

    Returns:
        info_arg. False when no signature can be obtained.

    Raises:
        PydanticUserError: If the positional parameter count fits neither form for `mode`.
    """
    try:
        sig = signature(serializer)
    except (ValueError, TypeError):
        # No introspectable signature (e.g. C objects): assume no info argument.
        return False

    positional = count_positional_required_params(sig)
    info_arg = _serializer_info_arg(mode, positional)
    if info_arg is not None:
        return info_arg
    raise PydanticUserError(
        f'Unrecognized field_serializer function signature for {serializer} with `mode={mode}`:{sig}',
        code='field-serializer-signature',
    )
620
621
def inspect_model_serializer(serializer: Callable[..., Any], mode: Literal['plain', 'wrap']) -> bool:
    """Validate a `@model_serializer` target and report whether it expects an info argument.

    The function must be an instance method: not a staticmethod/classmethod object, and its
    first parameter must be named `self`.

    Args:
        serializer: The serializer function to check.
        mode: The serializer mode, either 'plain' or 'wrap'.

    Returns:
        `info_arg` - whether the function expects an info argument.

    Raises:
        PydanticUserError: If the function is not an instance method, or its positional
            parameter count fits neither form for `mode`.
    """
    explicitly_static_or_class = isinstance(serializer, (staticmethod, classmethod))
    if explicitly_static_or_class or not is_instance_method_from_sig(serializer):
        raise PydanticUserError(
            '`@model_serializer` must be applied to instance methods', code='model-serializer-instance-method'
        )

    sig = signature(serializer)
    positional = count_positional_required_params(sig)
    info_arg = _serializer_info_arg(mode, positional)
    if info_arg is not None:
        return info_arg
    raise PydanticUserError(
        f'Unrecognized model_serializer function signature for {serializer} with `mode={mode}`:{sig}',
        code='model-serializer-signature',
    )
648
649
650def _serializer_info_arg(mode: Literal['plain', 'wrap'], n_positional: int) -> bool | None:
651 if mode == 'plain':
652 if n_positional == 1:
653 # (input_value: Any, /) -> Any
654 return False
655 elif n_positional == 2:
656 # (model: Any, input_value: Any, /) -> Any
657 return True
658 else:
659 assert mode == 'wrap', f"invalid mode: {mode!r}, expected 'plain' or 'wrap'"
660 if n_positional == 2:
661 # (input_value: Any, serializer: SerializerFunctionWrapHandler, /) -> Any
662 return False
663 elif n_positional == 3:
664 # (input_value: Any, serializer: SerializerFunctionWrapHandler, info: SerializationInfo, /) -> Any
665 return True
666
667 return None
668
669
# Callables, possibly still wrapped in classmethod/staticmethod/partialmethod, accepted by the
# signature-inspection helpers in this module (string alias, only meaningful to type checkers).
AnyDecoratorCallable: TypeAlias = (
    'Union[classmethod[Any, Any, Any], staticmethod[Any, Any], partialmethod[Any], Callable[..., Any]]'
)
673
674
def is_instance_method_from_sig(function: AnyDecoratorCallable) -> bool:
    """Guess from the signature whether `function` is an instance method.

    The check is purely by naming convention: after removing wrappers, the first
    parameter must be called `self`.

    Args:
        function: The function to check.

    Returns:
        `True` if the first parameter is named `self`, `False` otherwise (including no parameters).
    """
    param_names = signature(unwrap_wrapped_function(function)).parameters
    return next(iter(param_names), None) == 'self'
692
693
def ensure_classmethod_based_on_signature(function: AnyDecoratorCallable) -> Any:
    """Wrap `function` in `classmethod` when its first parameter is named `cls`.

    Functions that are already a `classmethod` (possibly behind partial/partialmethod
    layers) are returned unchanged, as are functions whose first parameter is not `cls`.

    Args:
        function: The function to apply the decorator on.

    Return:
        Either `classmethod(function)` or `function` itself.
    """
    already_classmethod = isinstance(unwrap_wrapped_function(function, unwrap_class_static_method=False), classmethod)
    if already_classmethod or not _is_classmethod_from_sig(function):
        return function
    return classmethod(function)  # type: ignore[arg-type]
708
709
def _is_classmethod_from_sig(function: AnyDecoratorCallable) -> bool:
    """Return True if the unwrapped `function`'s first parameter is named `cls`."""
    param_names = signature(unwrap_wrapped_function(function)).parameters
    return next(iter(param_names), None) == 'cls'
716
717
def unwrap_wrapped_function(
    func: Any,
    *,
    unwrap_partial: bool = True,
    unwrap_class_static_method: bool = True,
) -> Any:
    """Peel wrapper layers off `func` until none of the supported wrappers remain.

    `property` and `cached_property` are always peeled. `functools.partial` /
    `partialmethod` and `staticmethod` / `classmethod` are peeled only when the
    matching flag is set. For a `property`, only the getter is followed.

    Args:
        func: The function to unwrap.
        unwrap_partial: If True (default), unwrap partial and partialmethod decorators.
        unwrap_class_static_method: If True (default), also unwrap classmethod and staticmethod
            decorators. If False, only unwrap partial and partialmethod decorators.

    Returns:
        The innermost object reached (may be None for a property without a getter).
    """
    peelable = (property, cached_property)
    if unwrap_partial:
        peelable += (partial, partialmethod)
    if unwrap_class_static_method:
        peelable += (staticmethod, classmethod)

    current = func
    while isinstance(current, peelable):
        if unwrap_class_static_method and isinstance(current, (classmethod, staticmethod)):
            current = current.__func__
            continue
        if isinstance(current, (partial, partialmethod)):
            current = current.func
            continue
        if isinstance(current, property):
            # Following the getter is an arbitrary but convenient choice for computed fields.
            current = current.fget
            continue
        # By elimination, only cached_property can be left here.
        assert isinstance(current, cached_property)
        current = current.func  # type: ignore

    return current
756
757
# Objects that `get_callable_return_type` inspects directly. Any other non-type callable is
# assumed to be an instance whose class `__call__` should be inspected instead.
_function_like = (
    partial,
    partialmethod,
    types.FunctionType,
    types.BuiltinFunctionType,
    types.MethodType,
    types.WrapperDescriptorType,
    types.MethodWrapperType,
    types.MemberDescriptorType,
)
768
769
def get_callable_return_type(
    callable_obj: Any,
    globalns: GlobalsNamespace | None = None,
    localns: MappingNamespace | None = None,
) -> Any | PydanticUndefinedType:
    """Return the annotated return type of a callable, if it has one.

    Classes are treated as returning instances of themselves. For callable instances
    (anything not function-like), the class's `__call__` is inspected instead.

    Args:
        callable_obj: The callable to analyze.
        globalns: The globals namespace to use during type annotation evaluation.
        localns: The locals namespace to use during type annotation evaluation.

    Returns:
        The function return type, or `PydanticUndefined` when no return annotation is found.
    """
    if isinstance(callable_obj, type):
        # Calling a type yields an instance of it, e.g. `int()` -> `int`.
        return callable_obj

    target = callable_obj
    if not isinstance(target, _function_like):
        dunder_call = getattr(type(target), '__call__', None)  # noqa: B004
        if dunder_call is not None:
            target = dunder_call

    return_hints = get_function_type_hints(
        unwrap_wrapped_function(target),
        include_keys={'return'},
        globalns=globalns,
        localns=localns,
    )
    return return_hints.get('return', PydanticUndefined)
802
803
def count_positional_required_params(sig: Signature) -> int:
    """Count the positional parameters of `sig` that a caller must supply.

    Intended only for validation/serialization callables. The first parameter is the
    value being validated or serialized and is always counted, even with a default
    (e.g. `float` has signature `(x=0, /)`). Every later positional parameter, such as
    the info argument, counts only if it has no default.

    Returns:
        The number of positional arguments of a signature.
    """
    total = 0
    for index, param in enumerate(sig.parameters.values()):
        if not can_be_positional(param):
            continue
        if index == 0 or param.default is Parameter.empty:
            total += 1
    return total
824
825
def ensure_property(f: Any) -> Any:
    """Return `f` as a descriptor, wrapping it in `property` only when needed.

    Method descriptors and data descriptors (e.g. `property`, `cached_property`,
    `staticmethod`) are returned unchanged. Anything else, such as a plain function,
    is wrapped in `property`.

    Args:
        f: The function to check.

    Returns:
        `f` itself, or `property(f)`.
    """
    already_descriptor = ismethoddescriptor(f) or isdatadescriptor(f)
    return f if already_descriptor else property(f)