1"""Logic related to validators applied to models etc. via the `@field_validator` and `@model_validator` decorators."""
2
3from __future__ import annotations as _annotations
4
5import types
6from collections import deque
7from collections.abc import Iterable
8from dataclasses import dataclass, field
9from functools import cached_property, partial, partialmethod
10from inspect import Parameter, Signature, isdatadescriptor, ismethoddescriptor, signature
11from itertools import islice
12from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generic, Literal, TypeVar, Union
13
14from pydantic_core import PydanticUndefined, PydanticUndefinedType, core_schema
15from typing_extensions import TypeAlias, is_typeddict
16
17from ..errors import PydanticUserError
18from ._core_utils import get_type_ref
19from ._internal_dataclass import slots_true
20from ._namespace_utils import GlobalsNamespace, MappingNamespace
21from ._typing_extra import get_function_type_hints
22from ._utils import can_be_positional
23
24if TYPE_CHECKING:
25 from ..fields import ComputedFieldInfo
26 from ..functional_validators import FieldValidatorModes
27 from ._config import ConfigWrapper
28
29
@dataclass(**slots_true)
class ValidatorDecoratorInfo:
    """Holds the arguments given to the legacy `@validator` decorator.

    The data is kept around so it is available later, when the pydantic-core
    schema is being built.

    Attributes:
        decorator_repr: Class-level constant with the decorator's textual form, '@validator'.
        fields: Names of the fields this validator applies to.
        mode: The proposed validator mode, `'before'` or `'after'`.
        each_item: For container values (lists, sets, ...), whether the individual
            elements are validated instead of the container as a whole.
        always: Whether this method and other validators should be called even if the value is missing.
        check_fields: Whether to check that the named fields actually exist on the model.
    """

    decorator_repr: ClassVar[str] = '@validator'

    fields: tuple[str, ...]
    mode: Literal['before', 'after']
    each_item: bool
    always: bool
    check_fields: bool | None
52
53
@dataclass(**slots_true)
class FieldValidatorDecoratorInfo:
    """A container for data from `@field_validator` so that we can access it
    while building the pydantic-core schema.

    Attributes:
        decorator_repr: A class variable representing the decorator string, '@field_validator'.
        fields: A tuple of field names the validator should be called on.
        mode: The proposed validator mode.
        check_fields: Whether to check that the fields actually exist on the model.
        json_schema_input_type: The input type of the function. This is only used to generate
            the appropriate JSON Schema (in validation mode) and can only be specified
            when `mode` is either `'before'`, `'plain'` or `'wrap'`.
    """

    decorator_repr: ClassVar[str] = '@field_validator'

    fields: tuple[str, ...]
    mode: FieldValidatorModes
    check_fields: bool | None
    json_schema_input_type: Any
75
76
@dataclass(**slots_true)
class RootValidatorDecoratorInfo:
    """Holds the arguments given to the legacy `@root_validator` decorator.

    Kept so that the information is still available while the pydantic-core
    schema is being built.

    Attributes:
        decorator_repr: Class-level constant with the decorator's textual form, '@root_validator'.
        mode: The proposed validator mode, `'before'` or `'after'`.
    """

    decorator_repr: ClassVar[str] = '@root_validator'

    mode: Literal['before', 'after']
89
90
@dataclass(**slots_true)
class FieldSerializerDecoratorInfo:
    """Holds the arguments given to `@field_serializer`.

    Kept so that the information is still available while the pydantic-core
    schema is being built.

    Attributes:
        decorator_repr: Class-level constant with the decorator's textual form, '@field_serializer'.
        fields: Names of the fields this serializer applies to.
        mode: The proposed serializer mode, `'plain'` or `'wrap'`.
        return_type: The type of the value the serializer returns.
        when_used: When the serializer applies; one of `'always'`, `'unless-none'`,
            `'json'` and `'json-unless-none'`.
        check_fields: Whether to check that the named fields actually exist on the model.
    """

    decorator_repr: ClassVar[str] = '@field_serializer'

    fields: tuple[str, ...]
    mode: Literal['plain', 'wrap']
    return_type: Any
    when_used: core_schema.WhenUsed
    check_fields: bool | None
112
113
@dataclass(**slots_true)
class ModelSerializerDecoratorInfo:
    """Holds the arguments given to `@model_serializer`.

    Kept so that the information is still available while the pydantic-core
    schema is being built.

    Attributes:
        decorator_repr: Class-level constant with the decorator's textual form, '@model_serializer'.
        mode: The proposed serializer mode, `'plain'` or `'wrap'`.
        return_type: The type of the value the serializer returns.
        when_used: When the serializer applies; one of `'always'`, `'unless-none'`,
            `'json'` and `'json-unless-none'`.
    """

    decorator_repr: ClassVar[str] = '@model_serializer'

    mode: Literal['plain', 'wrap']
    return_type: Any
    when_used: core_schema.WhenUsed
131
132
@dataclass(**slots_true)
class ModelValidatorDecoratorInfo:
    """A container for data from `@model_validator` so that we can access it
    while building the pydantic-core schema.

    Attributes:
        decorator_repr: A class variable representing the decorator string, '@model_validator'.
        mode: The proposed validator mode, one of `'wrap'`, `'before'` or `'after'`.
    """

    decorator_repr: ClassVar[str] = '@model_validator'
    mode: Literal['wrap', 'before', 'after']
145
146
# Union of every decorator-metadata container defined for this module's decorators.
# It is spelled as a string so the member names are never evaluated at runtime; in
# particular `ComputedFieldInfo` is only imported under `TYPE_CHECKING`.
DecoratorInfo: TypeAlias = """Union[
    ValidatorDecoratorInfo,
    FieldValidatorDecoratorInfo,
    RootValidatorDecoratorInfo,
    FieldSerializerDecoratorInfo,
    ModelSerializerDecoratorInfo,
    ModelValidatorDecoratorInfo,
    ComputedFieldInfo,
]"""

# Type variable standing for the return type of a decorated callable/property.
ReturnType = TypeVar('ReturnType')
# The kinds of objects a `PydanticDescriptorProxy` may wrap (string alias, only
# interpreted by type checkers).
DecoratedType: TypeAlias = (
    'Union[classmethod[Any, Any, ReturnType], staticmethod[Any, ReturnType], Callable[..., ReturnType], property]'
)
161
162
@dataclass  # can't use slots here since we set attributes on `__post_init__`
class PydanticDescriptorProxy(Generic[ReturnType]):
    """Wrap a classmethod, staticmethod, property or unbound function
    and act as a descriptor that allows us to detect decorated items
    from the class' attributes.

    This class' __get__ returns the wrapped item's __get__ result,
    which makes it transparent for classmethods and staticmethods.

    Attributes:
        wrapped: The decorator that has to be wrapped.
        decorator_info: The decorator info.
        shim: A wrapper function to wrap V1 style function.
    """

    wrapped: DecoratedType[ReturnType]
    decorator_info: DecoratorInfo
    shim: Callable[[Callable[..., Any]], Callable[..., Any]] | None = None

    def __post_init__(self):
        # If the wrapped object exposes `setter`/`deleter` (e.g. a `property`), expose them on
        # the proxy too, so that `@name.setter` keeps working after the decorator was applied.
        for attr in 'setter', 'deleter':
            if hasattr(self.wrapped, attr):
                f = partial(self._call_wrapped_attr, name=attr)
                setattr(self, attr, f)

    def _call_wrapped_attr(self, func: Callable[[Any], None], *, name: str) -> PydanticDescriptorProxy[ReturnType]:
        """Apply the wrapped object's `name` attribute (`setter`/`deleter`) to `func`.

        The result replaces `self.wrapped`, and the proxy itself is returned so the
        class attribute keeps pointing at the proxy.
        """
        self.wrapped = getattr(self.wrapped, name)(func)
        if isinstance(self.wrapped, property):
            # update ComputedFieldInfo.wrapped_property
            from ..fields import ComputedFieldInfo

            if isinstance(self.decorator_info, ComputedFieldInfo):
                self.decorator_info.wrapped_property = self.wrapped
        return self

    def __get__(self, obj: object | None, obj_type: type[object] | None = None) -> PydanticDescriptorProxy[ReturnType]:
        try:
            return self.wrapped.__get__(obj, obj_type)  # pyright: ignore[reportReturnType]
        except AttributeError:
            # not a descriptor, e.g. a partial object
            return self.wrapped  # type: ignore[return-value]

    def __set_name__(self, instance: Any, name: str) -> None:
        # Python calls this with the owning class (despite the parameter name); forward it
        # to the wrapped object when that object also participates in the protocol.
        if hasattr(self.wrapped, '__set_name__'):
            self.wrapped.__set_name__(instance, name)  # pyright: ignore[reportFunctionMemberAccess]

    def __getattr__(self, name: str, /) -> Any:
        """Forward checks for __isabstractmethod__ and such."""
        if name == 'wrapped':
            # `__getattr__` only runs when normal lookup fails, so `wrapped` is missing here only
            # on a not-yet-initialised instance (e.g. the bare object `copy`/`pickle` creates before
            # restoring state). Without this guard, `self.wrapped` below would re-enter
            # `__getattr__` forever and raise RecursionError instead of AttributeError.
            raise AttributeError(name)
        return getattr(self.wrapped, name)
212
213
# Type variable used to parametrize `Decorator` by the specific decorator-info class it holds.
DecoratorInfoType = TypeVar('DecoratorInfoType', bound=DecoratorInfo)
215
216
@dataclass(**slots_true)
class Decorator(Generic[DecoratorInfoType]):
    """Joins a decorator's metadata with the function it decorates.

    The metadata comes from the decorator call itself and is available immediately,
    whereas the concrete function can only be resolved once the owning class exists.
    This container pairs the two so they can be used while building the core schema.

    Attributes:
        cls_ref: Reference string (from `get_type_ref`) of the class the function was resolved on.
        cls_var_name: Name of the decorated attribute in the class namespace.
        func: The resolved decorated function.
        shim: Optional wrapper used to adapt V1-style functions.
        info: The decorator info.
    """

    cls_ref: str
    cls_var_name: str
    func: Callable[..., Any]
    shim: Callable[[Any], Any] | None
    info: DecoratorInfoType

    @staticmethod
    def build(
        cls_: Any,
        *,
        cls_var_name: str,
        shim: Callable[[Any], Any] | None,
        info: DecoratorInfoType,
    ) -> Decorator[DecoratorInfoType]:
        """Resolve `cls_var_name` on `cls_` and pair the result with `info`.

        Args:
            cls_: The class to resolve the decorated attribute on.
            cls_var_name: Name of the decorated attribute.
            shim: Optional wrapper applied to the resolved function (for V1-style functions).
            info: The decorator info.

        Returns:
            A new `Decorator` instance.
        """
        resolved = get_attribute_from_bases(cls_, cls_var_name)
        if shim is not None:
            resolved = shim(resolved)
        resolved = unwrap_wrapped_function(resolved, unwrap_partial=False)

        if not callable(resolved):
            # Classmethod properties end up in this branch. Read the raw `__dict__` entry
            # instead, which avoids the binding `__get__` call, and unwrap it if it is a proxy.
            raw_attr = get_attribute_from_base_dicts(cls_, cls_var_name)
            if isinstance(raw_attr, PydanticDescriptorProxy):
                resolved = unwrap_wrapped_function(raw_attr.wrapped)

        return Decorator(
            cls_ref=get_type_ref(cls_),
            cls_var_name=cls_var_name,
            func=resolved,
            shim=shim,
            info=info,
        )

    def bind_to_cls(self, cls: Any) -> Decorator[DecoratorInfoType]:
        """Re-resolve this decorator against `cls`, reusing its name, shim and info.

        Args:
            cls: The class to bind to.

        Returns:
            A new `Decorator` instance.
        """
        return self.build(cls, cls_var_name=self.cls_var_name, shim=self.shim, info=self.info)
289
290
291def get_bases(tp: type[Any]) -> tuple[type[Any], ...]:
292 """Get the base classes of a class or typeddict.
293
294 Args:
295 tp: The type or class to get the bases.
296
297 Returns:
298 The base classes.
299 """
300 if is_typeddict(tp):
301 return tp.__orig_bases__ # type: ignore
302 try:
303 return tp.__bases__
304 except AttributeError:
305 return ()
306
307
308def mro(tp: type[Any]) -> tuple[type[Any], ...]:
309 """Calculate the Method Resolution Order of bases using the C3 algorithm.
310
311 See https://www.python.org/download/releases/2.3/mro/
312 """
313 # try to use the existing mro, for performance mainly
314 # but also because it helps verify the implementation below
315 if not is_typeddict(tp):
316 try:
317 return tp.__mro__
318 except AttributeError:
319 # GenericAlias and some other cases
320 pass
321
322 bases = get_bases(tp)
323 return (tp,) + mro_for_bases(bases)
324
325
def mro_for_bases(bases: tuple[type[Any], ...]) -> tuple[type[Any], ...]:
    """Merge the MROs of `bases` using the C3 algorithm.

    The result is the MRO a class with these direct bases would have, without the
    class itself (see https://www.python.org/download/releases/2.3/mro/).

    Args:
        bases: The direct bases, in declaration order.

    Returns:
        The linearized bases.

    Raises:
        TypeError: If no consistent C3 linearization exists.
    """

    def merge_seqs(seqs: list[deque[type[Any]]]) -> Iterable[type[Any]]:
        while True:
            non_empty = [seq for seq in seqs if seq]
            if not non_empty:
                # Nothing left to process, we're done.
                return
            # A valid candidate is the head of some sequence that appears in the tail
            # of no sequence.
            candidate: type[Any] | None = None
            for seq in non_empty:
                head = seq[0]
                if not any(head in islice(s, 1, None) for s in non_empty):
                    candidate = head
                    break
            # Compare with `None` instead of relying on truthiness: a class whose metaclass
            # defines `__bool__`/`__len__` may be falsy yet still be a valid candidate.
            if candidate is None:
                raise TypeError('Inconsistent hierarchy, no C3 MRO is possible')
            yield candidate
            for seq in non_empty:
                # Remove the chosen candidate from the front of every sequence it heads.
                if seq[0] is candidate:
                    seq.popleft()

    seqs = [deque(mro(base)) for base in bases] + [deque(bases)]
    return tuple(merge_seqs(seqs))
352
353
# Unique marker used with `__dict__.get(...)` to tell "attribute absent" apart from an
# attribute whose stored value is `None`.
_sentinel = object()
355
356
357def get_attribute_from_bases(tp: type[Any] | tuple[type[Any], ...], name: str) -> Any:
358 """Get the attribute from the next class in the MRO that has it,
359 aiming to simulate calling the method on the actual class.
360
361 The reason for iterating over the mro instead of just getting
362 the attribute (which would do that for us) is to support TypedDict,
363 which lacks a real __mro__, but can have a virtual one constructed
364 from its bases (as done here).
365
366 Args:
367 tp: The type or class to search for the attribute. If a tuple, this is treated as a set of base classes.
368 name: The name of the attribute to retrieve.
369
370 Returns:
371 Any: The attribute value, if found.
372
373 Raises:
374 AttributeError: If the attribute is not found in any class in the MRO.
375 """
376 if isinstance(tp, tuple):
377 for base in mro_for_bases(tp):
378 attribute = base.__dict__.get(name, _sentinel)
379 if attribute is not _sentinel:
380 attribute_get = getattr(attribute, '__get__', None)
381 if attribute_get is not None:
382 return attribute_get(None, tp)
383 return attribute
384 raise AttributeError(f'{name} not found in {tp}')
385 else:
386 try:
387 return getattr(tp, name)
388 except AttributeError:
389 return get_attribute_from_bases(mro(tp), name)
390
391
def get_attribute_from_base_dicts(tp: type[Any], name: str) -> Any:
    """Get an attribute out of the `__dict__` of the first class in the MRO that defines it.

    Reading the raw `__dict__` entry avoids calling `__get__` on descriptors, which lets
    us get the original function for classmethod properties.

    The MRO is searched from `tp` itself towards the root (not the other way around), so
    a definition in a subclass shadows one in a parent class, just like normal attribute
    lookup. This keeps the result consistent with `get_attribute_from_bases`.

    Args:
        tp: The type or class to search for the attribute.
        name: The name of the attribute to retrieve.

    Returns:
        Any: The attribute value, if found.

    Raises:
        KeyError: If the attribute is not found in any class's `__dict__` in the MRO.
    """
    for base in mro(tp):
        if name in base.__dict__:
            return base.__dict__[name]
    return tp.__dict__[name]  # raise the error
411
412
@dataclass(**slots_true)
class DecoratorInfos:
    """Mapping of name in the class namespace to decorator info.

    Note that the name in the class namespace is the function or attribute name,
    not the field name!
    """

    validators: dict[str, Decorator[ValidatorDecoratorInfo]] = field(default_factory=dict)
    field_validators: dict[str, Decorator[FieldValidatorDecoratorInfo]] = field(default_factory=dict)
    root_validators: dict[str, Decorator[RootValidatorDecoratorInfo]] = field(default_factory=dict)
    field_serializers: dict[str, Decorator[FieldSerializerDecoratorInfo]] = field(default_factory=dict)
    model_serializers: dict[str, Decorator[ModelSerializerDecoratorInfo]] = field(default_factory=dict)
    model_validators: dict[str, Decorator[ModelValidatorDecoratorInfo]] = field(default_factory=dict)
    computed_fields: dict[str, Decorator[ComputedFieldInfo]] = field(default_factory=dict)

    @staticmethod
    def build(model_dc: type[Any]) -> DecoratorInfos:  # noqa: C901 (ignore complexity)
        """Collect the decorators defined on `model_dc` and on all of its bases.

        We want to collect all `PydanticDescriptorProxy` instances that exist as attributes
        in the namespace of the class (a BaseModel or dataclass) that called us, but in the
        order of the bases. So instead of getting them all from the leaf class (the class
        that called us), we traverse the bases from root (the oldest ancestor class) to
        leaf and collect all of the instances as we go. Any duplicate is replaced with the
        last one we see, to mimic how function overriding works with inheritance. A
        replacement keeps the position of the function it replaces, so the order is kept.

        Args:
            model_dc: The class whose decorators should be collected.

        Returns:
            The collected decorators, each bound to `model_dc`.

        Raises:
            PydanticUserError: If two different field serializer functions in `model_dc`'s
                namespace (or inherited under another name) target the same field.
        """
        # reminder: dicts are ordered and replacement does not alter the order
        res = DecoratorInfos()
        for base in reversed(mro(model_dc)[1:]):
            existing: DecoratorInfos | None = base.__dict__.get('__pydantic_decorators__')
            if existing is None:
                existing = DecoratorInfos.build(base)
            res.validators.update({k: v.bind_to_cls(model_dc) for k, v in existing.validators.items()})
            res.field_validators.update({k: v.bind_to_cls(model_dc) for k, v in existing.field_validators.items()})
            res.root_validators.update({k: v.bind_to_cls(model_dc) for k, v in existing.root_validators.items()})
            res.field_serializers.update({k: v.bind_to_cls(model_dc) for k, v in existing.field_serializers.items()})
            res.model_serializers.update({k: v.bind_to_cls(model_dc) for k, v in existing.model_serializers.items()})
            res.model_validators.update({k: v.bind_to_cls(model_dc) for k, v in existing.model_validators.items()})
            res.computed_fields.update({k: v.bind_to_cls(model_dc) for k, v in existing.computed_fields.items()})

        to_replace: list[tuple[str, Any]] = []

        for var_name, var_value in vars(model_dc).items():
            if isinstance(var_value, PydanticDescriptorProxy):
                info = var_value.decorator_info
                if isinstance(info, ValidatorDecoratorInfo):
                    res.validators[var_name] = Decorator.build(
                        model_dc, cls_var_name=var_name, shim=var_value.shim, info=info
                    )
                elif isinstance(info, FieldValidatorDecoratorInfo):
                    res.field_validators[var_name] = Decorator.build(
                        model_dc, cls_var_name=var_name, shim=var_value.shim, info=info
                    )
                elif isinstance(info, RootValidatorDecoratorInfo):
                    res.root_validators[var_name] = Decorator.build(
                        model_dc, cls_var_name=var_name, shim=var_value.shim, info=info
                    )
                elif isinstance(info, FieldSerializerDecoratorInfo):
                    # check whether a serializer function is already registered for fields
                    for field_serializer_decorator in res.field_serializers.values():
                        # check that each field has at most one serializer function.
                        # serializer functions for the same field in subclasses are allowed,
                        # and are treated as overrides
                        if field_serializer_decorator.cls_var_name == var_name:
                            continue
                        for f in info.fields:
                            if f in field_serializer_decorator.info.fields:
                                raise PydanticUserError(
                                    'Multiple field serializer functions were defined '
                                    f'for field {f!r}, this is not allowed.',
                                    code='multiple-field-serializers',
                                )
                    res.field_serializers[var_name] = Decorator.build(
                        model_dc, cls_var_name=var_name, shim=var_value.shim, info=info
                    )
                elif isinstance(info, ModelValidatorDecoratorInfo):
                    res.model_validators[var_name] = Decorator.build(
                        model_dc, cls_var_name=var_name, shim=var_value.shim, info=info
                    )
                elif isinstance(info, ModelSerializerDecoratorInfo):
                    res.model_serializers[var_name] = Decorator.build(
                        model_dc, cls_var_name=var_name, shim=var_value.shim, info=info
                    )
                else:
                    from ..fields import ComputedFieldInfo

                    # `ComputedFieldInfo` is the only member of the `DecoratorInfo` union not
                    # handled above. Check `info` (the proxy itself is never a ComputedFieldInfo).
                    assert isinstance(info, ComputedFieldInfo), f'unexpected decorator info: {info!r}'
                    res.computed_fields[var_name] = Decorator.build(
                        model_dc, cls_var_name=var_name, shim=None, info=info
                    )
                to_replace.append((var_name, var_value.wrapped))
        if to_replace:
            # If we can save `__pydantic_decorators__` on the class we'll be able to check for it above
            # so then we don't need to re-process the type, which means we can discard our descriptor wrappers
            # and replace them with the thing they are wrapping (the `setattr` call below),
            # which allows validator class methods to also function as regular class methods
            model_dc.__pydantic_decorators__ = res
            for name, value in to_replace:
                setattr(model_dc, name, value)
        return res

    def update_from_config(self, config_wrapper: ConfigWrapper) -> None:
        """Update the decorator infos from the configuration of the class they are attached to."""
        for name, computed_field_dec in self.computed_fields.items():
            computed_field_dec.info._update_from_config(config_wrapper, name)
522
523
def inspect_validator(validator: Callable[..., Any], mode: FieldValidatorModes) -> bool:
    """Look at a field or model validator function and determine whether it takes an info argument.

    An error is raised if the function has an invalid signature.

    Args:
        validator: The validator function to inspect.
        mode: The proposed validator mode.

    Returns:
        Whether the validator takes an info argument.

    Raises:
        PydanticUserError: If the number of required positional parameters matches
            neither accepted form for `mode`.
    """
    try:
        sig = signature(validator)
    except (ValueError, TypeError):
        # `inspect.signature` might not be able to infer a signature, e.g. with C objects.
        # In this case, we assume no info argument is present:
        return False
    n_positional = count_positional_required_params(sig)
    if mode == 'wrap':
        # wrap validators: 3 positional params -> with info argument, 2 -> without
        if n_positional == 3:
            return True
        elif n_positional == 2:
            return False
    else:
        assert mode in {'before', 'after', 'plain'}, f"invalid mode: {mode!r}, expected 'before', 'after' or 'plain'"
        # other modes: 2 positional params -> with info argument, 1 -> without
        if n_positional == 2:
            return True
        elif n_positional == 1:
            return False

    raise PydanticUserError(
        f'Unrecognized field_validator function signature for {validator} with `mode={mode}`:{sig}',
        code='validator-signature',
    )
559
560
def inspect_field_serializer(serializer: Callable[..., Any], mode: Literal['plain', 'wrap']) -> tuple[bool, bool]:
    """Classify a field serializer function by its signature.

    A function whose first parameter is named `self` is reported as a field serializer
    (instance method). Its `self` parameter is excluded when checking for the info argument.
    An invalid signature raises an error.

    Args:
        serializer: The serializer function to inspect.
        mode: The serializer mode, either 'plain' or 'wrap'.

    Returns:
        Tuple of (is_field_serializer, info_arg).
    """
    try:
        sig = signature(serializer)
    except (ValueError, TypeError):
        # No signature could be inferred (e.g. some C objects): assume neither a method
        # nor an info argument.
        return (False, False)

    first_param = next(iter(sig.parameters.values()), None)
    is_field_serializer = first_param is not None and first_param.name == 'self'

    n_positional = count_positional_required_params(sig)
    # Instance methods receive `self` on top of the serializer arguments, so discount it.
    effective = n_positional - 1 if is_field_serializer else n_positional
    info_arg = _serializer_info_arg(mode, effective)

    if info_arg is None:
        raise PydanticUserError(
            f'Unrecognized field_serializer function signature for {serializer} with `mode={mode}`:{sig}',
            code='field-serializer-signature',
        )
    return is_field_serializer, info_arg
598
599
def inspect_annotated_serializer(serializer: Callable[..., Any], mode: Literal['plain', 'wrap']) -> bool:
    """Report whether a serializer used through `Annotated` accepts an info argument.

    An invalid signature raises an error.

    Args:
        serializer: The serializer function to check.
        mode: The serializer mode, either 'plain' or 'wrap'.

    Returns:
        info_arg
    """
    try:
        sig = signature(serializer)
    except (ValueError, TypeError):
        # No signature could be inferred (e.g. some C objects): assume no info argument.
        return False

    n_positional = count_positional_required_params(sig)
    info_arg = _serializer_info_arg(mode, n_positional)
    if info_arg is not None:
        return info_arg
    raise PydanticUserError(
        f'Unrecognized field_serializer function signature for {serializer} with `mode={mode}`:{sig}',
        code='field-serializer-signature',
    )
626
627
def inspect_model_serializer(serializer: Callable[..., Any], mode: Literal['plain', 'wrap']) -> bool:
    """Report whether a model serializer accepts an info argument.

    The serializer must be a plain instance method (first parameter named `self`, not
    wrapped in `staticmethod`/`classmethod`), and its signature must match `mode`;
    otherwise an error is raised.

    Args:
        serializer: The serializer function to check.
        mode: The serializer mode, either 'plain' or 'wrap'.

    Returns:
        `info_arg` - whether the function expects an info argument.
    """
    wrapped_in_descriptor = isinstance(serializer, (staticmethod, classmethod))
    if wrapped_in_descriptor or not is_instance_method_from_sig(serializer):
        raise PydanticUserError(
            '`@model_serializer` must be applied to instance methods', code='model-serializer-instance-method'
        )

    sig = signature(serializer)
    n_positional = count_positional_required_params(sig)
    info_arg = _serializer_info_arg(mode, n_positional)
    if info_arg is not None:
        return info_arg
    raise PydanticUserError(
        f'Unrecognized model_serializer function signature for {serializer} with `mode={mode}`:{sig}',
        code='model-serializer-signature',
    )
654
655
656def _serializer_info_arg(mode: Literal['plain', 'wrap'], n_positional: int) -> bool | None:
657 if mode == 'plain':
658 if n_positional == 1:
659 # (input_value: Any, /) -> Any
660 return False
661 elif n_positional == 2:
662 # (model: Any, input_value: Any, /) -> Any
663 return True
664 else:
665 assert mode == 'wrap', f"invalid mode: {mode!r}, expected 'plain' or 'wrap'"
666 if n_positional == 2:
667 # (input_value: Any, serializer: SerializerFunctionWrapHandler, /) -> Any
668 return False
669 elif n_positional == 3:
670 # (input_value: Any, serializer: SerializerFunctionWrapHandler, info: SerializationInfo, /) -> Any
671 return True
672
673 return None
674
675
# String alias (only interpreted by type checkers) for the objects accepted by the
# signature-inspection helpers below: class/static methods, partialmethods and plain callables.
AnyDecoratorCallable: TypeAlias = (
    'Union[classmethod[Any, Any, Any], staticmethod[Any, Any], partialmethod[Any], Callable[..., Any]]'
)
679
680
def is_instance_method_from_sig(function: AnyDecoratorCallable) -> bool:
    """Guess whether `function` is an instance method by its first parameter's name.

    The function is first unwrapped (see `unwrap_wrapped_function`). It counts as an
    instance method when its first parameter is called `self`.

    Args:
        function: The function to check.

    Returns:
        `True` if the function is an instance method, `False` otherwise.
    """
    params = signature(unwrap_wrapped_function(function)).parameters
    leading = next(iter(params.values()), None)
    return leading is not None and leading.name == 'self'
698
699
def ensure_classmethod_based_on_signature(function: AnyDecoratorCallable) -> Any:
    """Wrap `function` in `classmethod` when its signature suggests it should be one.

    The wrapping happens only if `function` (after unwrapping partials, but not class or
    static methods) is not already a `classmethod`, and its first parameter is named `cls`.

    Args:
        function: The function to apply the decorator on.

    Return:
        Either `classmethod(function)` or `function` unchanged.
    """
    already_classmethod = isinstance(
        unwrap_wrapped_function(function, unwrap_class_static_method=False), classmethod
    )
    if already_classmethod or not _is_classmethod_from_sig(function):
        return function
    return classmethod(function)  # type: ignore[arg-type]
714
715
def _is_classmethod_from_sig(function: AnyDecoratorCallable) -> bool:
    """Return whether the unwrapped `function`'s first parameter is named `cls`."""
    params = signature(unwrap_wrapped_function(function)).parameters
    leading = next(iter(params.values()), None)
    return leading is not None and leading.name == 'cls'
722
723
def unwrap_wrapped_function(
    func: Any,
    *,
    unwrap_partial: bool = True,
    unwrap_class_static_method: bool = True,
) -> Any:
    """Repeatedly peel wrapper objects off `func` until none of the recognised ones remain.

    `property` (its getter is followed) and `cached_property` are always peeled.
    `functools.partial`/`functools.partialmethod` and `staticmethod`/`classmethod`
    are peeled depending on the flags.

    Args:
        func: The function to unwrap.
        unwrap_partial: If True (default), unwrap partial and partialmethod objects.
        unwrap_class_static_method: If True (default), also unwrap classmethod and staticmethod
            objects; if False, they are returned as they are.

    Returns:
        The innermost object reached.
    """
    peelable = (property, cached_property)
    if unwrap_partial:
        peelable += (partial, partialmethod)
    if unwrap_class_static_method:
        peelable += (staticmethod, classmethod)

    current = func
    while isinstance(current, peelable):
        if unwrap_class_static_method and isinstance(current, (classmethod, staticmethod)):
            current = current.__func__
        elif isinstance(current, (partial, partialmethod)):
            current = current.func
        elif isinstance(current, property):
            # Following the getter is an arbitrary choice, but convenient for computed fields.
            current = current.fget
        else:
            # Only `cached_property` can remain at this point.
            assert isinstance(current, cached_property)
            current = current.func  # type: ignore
    return current
762
763
# Callable types whose own type hints `get_callable_return_type` reads directly; any other
# (non-type) callable is analysed through its type's `__call__` instead.
_function_like = (
    partial,
    partialmethod,
    types.FunctionType,
    types.BuiltinFunctionType,
    types.MethodType,
    types.WrapperDescriptorType,
    types.MethodWrapperType,
    types.MemberDescriptorType,
)
774
775
def get_callable_return_type(
    callable_obj: Any,
    globalns: GlobalsNamespace | None = None,
    localns: MappingNamespace | None = None,
) -> Any | PydanticUndefinedType:
    """Determine the annotated return type of a callable.

    A class is its own return type. For other callables that are not function-like,
    the hints of their type's `__call__` are used.

    Args:
        callable_obj: The callable to analyze.
        globalns: The globals namespace to use during type annotation evaluation.
        localns: The locals namespace to use during type annotation evaluation.

    Returns:
        The function return type, or `PydanticUndefined` if there is no return annotation.
    """
    if isinstance(callable_obj, type):
        # Calling a class produces an instance of it (e.g. `int()` gives an `int`).
        return callable_obj

    target = callable_obj
    if not isinstance(target, _function_like):
        dunder_call = getattr(type(target), '__call__', None)  # noqa: B004
        if dunder_call is not None:
            target = dunder_call

    return get_function_type_hints(
        unwrap_wrapped_function(target),
        include_keys={'return'},
        globalns=globalns,
        localns=localns,
    ).get('return', PydanticUndefined)
808
809
def count_positional_required_params(sig: Signature) -> int:
    """Count the parameters of `sig` that can be passed positionally and must be supplied.

    Meant only for inspecting the signatures of validation and serialization functions.
    The first parameter (the value being validated or serialized) is always counted,
    even if it has a default value.

    Args:
        sig: The signature to inspect.

    Returns:
        The number of positional arguments of a signature.
    """
    params = list(sig.parameters.values())
    total = 0
    for index, param in enumerate(params):
        if not can_be_positional(param):
            continue
        # Only the first parameter may have a default (e.g. `float`, whose signature is
        # `(x=0, /)`); later ones, such as the info argument, are expected to be required.
        if index == 0 or param.default is Parameter.empty:
            total += 1
    return total
830
831
def ensure_property(f: Any) -> Any:
    """Return `f` as a descriptor, wrapping it in `property` when it is not one already.

    Anything `inspect` recognises as a method descriptor or a data descriptor (this
    includes `property`, `cached_property`, `staticmethod` and `classmethod` objects)
    is returned unchanged. Anything else, such as a plain function, is wrapped in a new
    `property`; a `cached_property` is never created here.

    Args:
        f: The function or descriptor to check.

    Returns:
        `f` itself if it is already a descriptor, otherwise `property(f)`.
    """
    if ismethoddescriptor(f) or isdatadescriptor(f):
        return f
    return property(f)