Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/pydantic/_internal/_decorators.py: 66%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

343 statements  

1"""Logic related to validators applied to models etc. via the `@field_validator` and `@model_validator` decorators.""" 

2 

3from __future__ import annotations as _annotations 

4 

5import types 

6from collections import deque 

7from collections.abc import Iterable 

8from copy import copy 

9from dataclasses import dataclass, field 

10from functools import cached_property, partial, partialmethod 

11from inspect import Parameter, Signature, isdatadescriptor, ismethoddescriptor 

12from itertools import islice 

13from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generic, Literal, TypeVar, Union 

14 

15from pydantic_core import PydanticUndefined, PydanticUndefinedType, core_schema 

16from typing_extensions import Self, TypeAlias, is_typeddict 

17 

18from ..errors import PydanticUserError 

19from ._core_utils import get_type_ref 

20from ._internal_dataclass import slots_true 

21from ._namespace_utils import GlobalsNamespace, MappingNamespace 

22from ._typing_extra import get_function_type_hints, signature_no_eval 

23from ._utils import can_be_positional 

24 

25if TYPE_CHECKING: 

26 from ..fields import ComputedFieldInfo 

27 from ..functional_validators import FieldValidatorModes 

28 from ._config import ConfigWrapper 

29 

30 

@dataclass(**slots_true)
class ValidatorDecoratorInfo:
    """Metadata recorded by the `@validator` decorator.

    The values are captured when the decorator is applied and read back later,
    while the pydantic-core schema is being built.

    Attributes:
        decorator_repr: Class-level label for the decorator, `'@validator'`.
        fields: Names of the fields the validator applies to.
        mode: The proposed validator mode, `'before'` or `'after'`.
        each_item: For container values (lists, sets, ...), whether each element is
            validated individually instead of the value as a whole.
        always: Whether this method and other validators should be called even if the value is missing.
        check_fields: Whether to verify that the named fields exist on the model.
    """

    decorator_repr: ClassVar[str] = '@validator'

    # Instance fields -- their order defines the generated `__init__` signature.
    fields: tuple[str, ...]
    mode: Literal['before', 'after']
    each_item: bool
    always: bool
    check_fields: bool | None

53 

54 

@dataclass(**slots_true)
class FieldValidatorDecoratorInfo:
    """A container for data from `@field_validator` so that we can access it
    while building the pydantic-core schema.

    Attributes:
        decorator_repr: A class variable representing the decorator string, '@field_validator'.
        fields: A tuple of field names the validator should be called on.
        mode: The proposed validator mode.
        check_fields: Whether to check that the fields actually exist on the model.
        json_schema_input_type: The input type of the function. This is only used to generate
            the appropriate JSON Schema (in validation mode) and can only be specified
            when `mode` is either `'before'`, `'plain'` or `'wrap'`.
    """

    decorator_repr: ClassVar[str] = '@field_validator'

    fields: tuple[str, ...]
    mode: FieldValidatorModes
    check_fields: bool | None
    json_schema_input_type: Any

76 

77 

@dataclass(**slots_true)
class RootValidatorDecoratorInfo:
    """Metadata recorded by the `@root_validator` decorator.

    Stored at decoration time and consulted when the pydantic-core schema is built.

    Attributes:
        decorator_repr: Class-level label for the decorator, `'@root_validator'`.
        mode: The proposed validator mode, `'before'` or `'after'`.
    """

    decorator_repr: ClassVar[str] = '@root_validator'

    # The only instance field.
    mode: Literal['before', 'after']

90 

91 

@dataclass(**slots_true)
class FieldSerializerDecoratorInfo:
    """Metadata recorded by the `@field_serializer` decorator.

    Stored at decoration time and consulted when the pydantic-core schema is built.

    Attributes:
        decorator_repr: Class-level label for the decorator, `'@field_serializer'`.
        fields: Names of the fields the serializer applies to.
        mode: The proposed serializer mode, `'plain'` or `'wrap'`.
        return_type: The type of the serializer's return value.
        when_used: When the serializer applies. One of `'always'`, `'unless-none'`,
            `'json'` or `'json-unless-none'`.
        check_fields: Whether to verify that the named fields exist on the model.
    """

    decorator_repr: ClassVar[str] = '@field_serializer'

    # Instance fields -- their order defines the generated `__init__` signature.
    fields: tuple[str, ...]
    mode: Literal['plain', 'wrap']
    return_type: Any
    when_used: core_schema.WhenUsed
    check_fields: bool | None

113 

114 

@dataclass(**slots_true)
class ModelSerializerDecoratorInfo:
    """Metadata recorded by the `@model_serializer` decorator.

    Stored at decoration time and consulted when the pydantic-core schema is built.

    Attributes:
        decorator_repr: Class-level label for the decorator, `'@model_serializer'`.
        mode: The proposed serializer mode, `'plain'` or `'wrap'`.
        return_type: The type of the serializer's return value.
        when_used: When the serializer applies. One of `'always'`, `'unless-none'`,
            `'json'` or `'json-unless-none'`.
    """

    decorator_repr: ClassVar[str] = '@model_serializer'

    # Instance fields -- their order defines the generated `__init__` signature.
    mode: Literal['plain', 'wrap']
    return_type: Any
    when_used: core_schema.WhenUsed

132 

133 

@dataclass(**slots_true)
class ModelValidatorDecoratorInfo:
    """A container for data from `@model_validator` so that we can access it
    while building the pydantic-core schema.

    Attributes:
        decorator_repr: A class variable representing the decorator string, '@model_validator'.
        mode: The proposed validator mode, one of `'wrap'`, `'before'` or `'after'`.
    """

    decorator_repr: ClassVar[str] = '@model_validator'
    mode: Literal['wrap', 'before', 'after']

146 

147 

# Union of every decorator-metadata container defined in this module, plus `ComputedFieldInfo`.
# Written as a string alias: at module level `ComputedFieldInfo` is only imported under
# `TYPE_CHECKING`, so the name does not exist at runtime.
DecoratorInfo: TypeAlias = """Union[
    ValidatorDecoratorInfo,
    FieldValidatorDecoratorInfo,
    RootValidatorDecoratorInfo,
    FieldSerializerDecoratorInfo,
    ModelSerializerDecoratorInfo,
    ModelValidatorDecoratorInfo,
    ComputedFieldInfo,
]"""

157 

# Type variable for the value a wrapped callable/descriptor produces; parameterizes
# `DecoratedType` and `PydanticDescriptorProxy`.
ReturnType = TypeVar('ReturnType')
# The kinds of objects a `PydanticDescriptorProxy` may wrap (its `wrapped` field).
# NOTE(review): kept as a string, presumably so `classmethod[...]`/`staticmethod[...]`
# subscription is never evaluated at runtime -- confirm against supported Python versions.
DecoratedType: TypeAlias = (
    'Union[classmethod[Any, Any, ReturnType], staticmethod[Any, ReturnType], Callable[..., ReturnType], property]'
)

162 

163 

@dataclass  # can't use slots here since we set attributes on `__post_init__`
class PydanticDescriptorProxy(Generic[ReturnType]):
    """Wrap a classmethod, staticmethod, property or unbound function
    and act as a descriptor that allows us to detect decorated items
    from the class' attributes.

    This class' __get__ returns the wrapped item's __get__ result,
    which makes it transparent for classmethods and staticmethods.

    Attributes:
        wrapped: The decorator that has to be wrapped.
        decorator_info: The decorator info.
        shim: A wrapper function to wrap V1 style function.
    """

    wrapped: DecoratedType[ReturnType]
    decorator_info: DecoratorInfo
    shim: Callable[[Callable[..., Any]], Callable[..., Any]] | None = None

    def __post_init__(self):
        # If the wrapped object exposes `setter`/`deleter` (e.g. a `property`), expose them
        # here too, so chained decoration such as `@x.setter` keeps returning this proxy.
        for attr in 'setter', 'deleter':
            if hasattr(self.wrapped, attr):
                f = partial(self._call_wrapped_attr, name=attr)
                setattr(self, attr, f)

    def _call_wrapped_attr(self, func: Callable[[Any], None], *, name: str) -> PydanticDescriptorProxy[ReturnType]:
        """Apply `self.wrapped.<name>(func)`, store the result as the new wrapped object and return `self`."""
        self.wrapped = getattr(self.wrapped, name)(func)
        if isinstance(self.wrapped, property):
            # update ComputedFieldInfo.wrapped_property
            from ..fields import ComputedFieldInfo

            if isinstance(self.decorator_info, ComputedFieldInfo):
                self.decorator_info.wrapped_property = self.wrapped
        return self

    def __get__(self, obj: object | None, obj_type: type[object] | None = None) -> PydanticDescriptorProxy[ReturnType]:
        """Delegate attribute binding to the wrapped object's `__get__`, if it has one."""
        try:
            return self.wrapped.__get__(obj, obj_type)  # pyright: ignore[reportReturnType]
        except AttributeError:
            # not a descriptor, e.g. a partial object
            return self.wrapped  # type: ignore[return-value]

    def __set_name__(self, instance: Any, name: str) -> None:
        """Forward `__set_name__` to the wrapped object when it defines one."""
        if hasattr(self.wrapped, '__set_name__'):
            self.wrapped.__set_name__(instance, name)  # pyright: ignore[reportFunctionMemberAccess]

    def __getattr__(self, name: str, /) -> Any:
        """Forward checks for __isabstractmethod__ and such to the wrapped object."""
        # `__getattr__` only runs after normal lookup has failed. If `wrapped` itself is
        # missing (e.g. on the bare instance `copy.copy`/`pickle` create before restoring
        # state and probing for `__setstate__`), `self.wrapped` would re-enter this method
        # and recurse until `RecursionError`. Report a plain `AttributeError` instead.
        try:
            wrapped = object.__getattribute__(self, 'wrapped')
        except AttributeError:
            raise AttributeError(f'{type(self).__name__!r} object has no attribute {name!r}') from None
        return getattr(wrapped, name)

213 

214 

# Type variable ranging over the `DecoratorInfo` variants; parameterizes `Decorator`.
DecoratorInfoType = TypeVar('DecoratorInfoType', bound=DecoratorInfo)

216 

217 

@dataclass(**slots_true)
class Decorator(Generic[DecoratorInfoType]):
    """Pairs a decorator's metadata with the function it decorates.

    The metadata is available as soon as the decorator runs. The bound function only
    becomes available once the owning class exists. This container joins the two so
    the core-schema builder has everything in one place.

    Attributes:
        cls_ref: The class ref.
        cls_var_name: Name of the decorated attribute in the class namespace.
        func: The decorated function.
        shim: Optional wrapper used to adapt V1 style functions.
        info: The decorator info.
    """

    cls_ref: str
    cls_var_name: str
    func: Callable[..., Any]
    shim: Callable[[Any], Any] | None
    info: DecoratorInfoType

    @staticmethod
    def build(
        cls_: Any,
        *,
        cls_var_name: str,
        shim: Callable[[Any], Any] | None,
        info: DecoratorInfoType,
    ) -> Decorator[DecoratorInfoType]:
        """Create a `Decorator` for attribute `cls_var_name` of `cls_`.

        Args:
            cls_: The class owning the decorated attribute.
            cls_var_name: Name of the decorated attribute.
            shim: Optional wrapper applied to the looked-up attribute before unwrapping.
            info: The decorator info.

        Returns:
            The new decorator instance.
        """
        resolved = get_attribute_from_bases(cls_, cls_var_name)
        if shim is not None:
            resolved = shim(resolved)
        resolved = unwrap_wrapped_function(resolved, unwrap_partial=False)

        if not callable(resolved):
            # TODO most likely this branch can be removed when we drop support for Python 3.12:
            # classmethod properties land here. Read the raw `__dict__` entry so that
            # `__get__` is not invoked, and unwrap the proxy's target instead.
            raw_attribute = get_attribute_from_base_dicts(cls_, cls_var_name)
            if isinstance(raw_attribute, PydanticDescriptorProxy):
                resolved = unwrap_wrapped_function(raw_attribute.wrapped)

        return Decorator(
            cls_ref=get_type_ref(cls_),
            cls_var_name=cls_var_name,
            func=resolved,
            shim=shim,
            info=info,
        )

    def bind_to_cls(self, cls: Any) -> Decorator[DecoratorInfoType]:
        """Rebuild this decorator against another class, using a shallow copy of `info`.

        Args:
            cls: the class.

        Returns:
            The new decorator instance.
        """
        name = self.cls_var_name
        wrapper = self.shim
        info_copy = copy(self.info)
        return self.build(cls, cls_var_name=name, shim=wrapper, info=info_copy)

291 

292 

293def get_bases(tp: type[Any]) -> tuple[type[Any], ...]: 

294 """Get the base classes of a class or typeddict. 

295 

296 Args: 

297 tp: The type or class to get the bases. 

298 

299 Returns: 

300 The base classes. 

301 """ 

302 if is_typeddict(tp): 

303 return tp.__orig_bases__ # type: ignore 

304 try: 

305 return tp.__bases__ 

306 except AttributeError: 

307 return () 

308 

309 

310def mro(tp: type[Any]) -> tuple[type[Any], ...]: 

311 """Calculate the Method Resolution Order of bases using the C3 algorithm. 

312 

313 See https://www.python.org/download/releases/2.3/mro/ 

314 """ 

315 # try to use the existing mro, for performance mainly 

316 # but also because it helps verify the implementation below 

317 if not is_typeddict(tp): 

318 try: 

319 return tp.__mro__ 

320 except AttributeError: 

321 # GenericAlias and some other cases 

322 pass 

323 

324 bases = get_bases(tp) 

325 return (tp,) + mro_for_bases(bases) 

326 

327 

def mro_for_bases(bases: tuple[type[Any], ...]) -> tuple[type[Any], ...]:
    """Compute the C3 linearization of `bases`.

    This is the MRO a class with these direct bases would have, without the class itself.

    Args:
        bases: The direct base classes, in declaration order.

    Returns:
        The merged method resolution order of `bases`.

    Raises:
        TypeError: If no consistent C3 linearization exists.
    """

    def merge_seqs(seqs: list[deque[type[Any]]]) -> Iterable[type[Any]]:
        while True:
            non_empty = [seq for seq in seqs if seq]
            if not non_empty:
                # Nothing left to process, we're done.
                return
            candidate: type[Any] | None = None
            for seq in non_empty:  # Find merge candidates among seq heads.
                candidate = seq[0]
                if any(candidate in islice(s, 1, None) for s in non_empty):
                    # Reject the candidate: it appears in the tail of another sequence.
                    candidate = None
                else:
                    break
            # Compare against `None` explicitly: a class whose metaclass defines
            # `__bool__`/`__len__` can be falsy and must not be mistaken for "no candidate".
            if candidate is None:
                raise TypeError('Inconsistent hierarchy, no C3 MRO is possible')
            yield candidate
            for seq in non_empty:
                # Remove candidate.
                if seq[0] == candidate:
                    seq.popleft()

    seqs = [deque(mro(base)) for base in bases] + [deque(bases)]
    return tuple(merge_seqs(seqs))

354 

355 

# Unique marker used by `get_attribute_from_bases` to distinguish "name not present in a
# class `__dict__`" from a present attribute whose value is `None`.
_sentinel = object()

357 

358 

def get_attribute_from_bases(tp: type[Any] | tuple[type[Any], ...], name: str) -> Any:
    """Look up `name` the way calling it on the class would, walking the MRO ourselves.

    Walking the MRO by hand, rather than relying solely on `getattr`, is what makes
    `TypedDict` work: it has no real `__mro__`, but a virtual one can be built from its bases.

    Args:
        tp: The type or class to search for the attribute. If a tuple, this is treated as a set of base classes.
        name: The name of the attribute to retrieve.

    Returns:
        Any: The attribute value, if found.

    Raises:
        AttributeError: If the attribute is not found in any class in the MRO.
    """
    if not isinstance(tp, tuple):
        try:
            return getattr(tp, name)
        except AttributeError:
            return get_attribute_from_bases(mro(tp), name)

    for klass in mro_for_bases(tp):
        found = klass.__dict__.get(name, _sentinel)
        if found is _sentinel:
            continue
        # Bind descriptors the same way class-level access would (instance=None).
        binder = getattr(found, '__get__', None)
        return found if binder is None else binder(None, tp)
    raise AttributeError(f'{name} not found in {tp}')

392 

393 

def get_attribute_from_base_dicts(tp: type[Any], name: str) -> Any:
    """Get an attribute out of the `__dict__` following the MRO.

    This prevents the call to `__get__` on the descriptor, and allows
    us to get the original function for classmethod properties.

    Classes are searched in MRO order, starting with `tp` itself. The first class
    whose `__dict__` defines `name` wins, so a subclass override takes precedence
    over a definition on one of its bases.

    Args:
        tp: The type or class to search for the attribute.
        name: The name of the attribute to retrieve.

    Returns:
        Any: The attribute value, if found.

    Raises:
        KeyError: If the attribute is not found in any class's `__dict__` in the MRO.
    """
    for base in mro(tp):
        if name in base.__dict__:
            return base.__dict__[name]
    return tp.__dict__[name]  # raise the error

413 

414 

@dataclass(**slots_true)
class DecoratorInfos:
    """Decorator metadata of a class, grouped by decorator kind.

    Each dict is keyed by the name in the class namespace, i.e. the function or
    attribute name -- *not* the field name!
    """

    validators: dict[str, Decorator[ValidatorDecoratorInfo]] = field(default_factory=dict)
    field_validators: dict[str, Decorator[FieldValidatorDecoratorInfo]] = field(default_factory=dict)
    root_validators: dict[str, Decorator[RootValidatorDecoratorInfo]] = field(default_factory=dict)
    field_serializers: dict[str, Decorator[FieldSerializerDecoratorInfo]] = field(default_factory=dict)
    model_serializers: dict[str, Decorator[ModelSerializerDecoratorInfo]] = field(default_factory=dict)
    model_validators: dict[str, Decorator[ModelValidatorDecoratorInfo]] = field(default_factory=dict)
    computed_fields: dict[str, Decorator[ComputedFieldInfo]] = field(default_factory=dict)

    @classmethod
    def build(
        cls,
        typ: type[Any],
        # Default to `True` for backwards compatibility:
        replace_wrapped_methods: bool = True,
    ) -> Self:
        """Collect the decorators of a model, dataclass or `TypedDict` type.

        Decorators defined on parent classes are included as well. This covers "bare"
        parents too: when `typ` is a Pydantic model, non-Pydantic parent classes are
        also inspected. Classes are processed in MRO order, base-most first, so that
        definitions closer to `typ` override inherited ones.

        A base that carries an `__pydantic_decorators__` attribute is assumed to hold a
        ready-made `DecoratorInfos` instance and is reused as is. This method does
        *not* set `__pydantic_decorators__` on `typ`.

        Args:
            typ: The model, dataclass or `TypedDict` type to collect decorators for.
            replace_wrapped_methods: Whether to replace the decorator's wrapped methods on `typ`.
                This is useful e.g. for field validators which are initially class methods. Only
                set this to `True` when `typ` is a Pydantic model or dataclass; otherwise classes
                Pydantic doesn't "own" get mutated.
        """
        # Every per-kind mapping on this dataclass, in a fixed processing order.
        categories = (
            'validators',
            'field_validators',
            'root_validators',
            'field_serializers',
            'model_serializers',
            'model_validators',
            'computed_fields',
        )
        # reminder: dicts are ordered and replacement does not alter the order
        collected = cls()

        # Skip `typ` itself (index 0) and the final `object`/`TypedDict` entry.
        for parent in reversed(mro(typ)[1:-1]):
            parent_infos: DecoratorInfos | None = parent.__dict__.get('__pydantic_decorators__')
            if parent_infos is None:
                parent_infos, _ = _decorator_infos_for_class(parent, collect_to_replace=False)
            for category in categories:
                rebound = {key: dec.bind_to_cls(typ) for key, dec in getattr(parent_infos, category).items()}
                getattr(collected, category).update(rebound)

        own_infos, replacements = _decorator_infos_for_class(typ, collect_to_replace=True)
        for category in categories:
            getattr(collected, category).update(getattr(own_infos, category))

        if replace_wrapped_methods and replacements:
            for attr_name, unwrapped in replacements:
                setattr(typ, attr_name, unwrapped)

        collected._validate()
        return collected

    def _validate(self) -> None:
        """Raise if two field serializers target the same field."""
        claimed: set[str] = set()
        for serializer in self.field_serializers.values():
            for field_name in serializer.info.fields:
                if field_name in claimed:
                    raise PydanticUserError(
                        f'Multiple field serializer functions were defined for field {field_name!r}, this is not allowed.',
                        code='multiple-field-serializers',
                    )
                claimed.add(field_name)

    def update_from_config(self, config_wrapper: ConfigWrapper) -> None:
        """Update the decorator infos from the configuration of the class they are attached to."""
        for attr_name, computed in self.computed_fields.items():
            computed.info._update_from_config(config_wrapper, attr_name)

503 

504 

def _decorator_infos_for_class(
    typ: type[Any],
    *,
    collect_to_replace: bool,
) -> tuple[DecoratorInfos, list[tuple[str, Any]]]:
    """Collect a `DecoratorInfos` for class, without looking into bases.

    Only attributes defined directly on `typ` (its `vars()`) that are
    `PydanticDescriptorProxy` instances are considered.

    Args:
        typ: The class to inspect.
        collect_to_replace: Whether to also record, for each proxy found, the
            `(attribute_name, wrapped_object)` pair.

    Returns:
        A tuple of the collected `DecoratorInfos` and the list of `(name, wrapped)`
        pairs. The list is empty when `collect_to_replace` is false.
    """
    res = DecoratorInfos()
    to_replace: list[tuple[str, Any]] = []

    for var_name, var_value in vars(typ).items():
        if not isinstance(var_value, PydanticDescriptorProxy):
            continue
        info = var_value.decorator_info
        if isinstance(info, ValidatorDecoratorInfo):
            res.validators[var_name] = Decorator.build(typ, cls_var_name=var_name, shim=var_value.shim, info=info)
        elif isinstance(info, FieldValidatorDecoratorInfo):
            res.field_validators[var_name] = Decorator.build(
                typ, cls_var_name=var_name, shim=var_value.shim, info=info
            )
        elif isinstance(info, RootValidatorDecoratorInfo):
            res.root_validators[var_name] = Decorator.build(
                typ, cls_var_name=var_name, shim=var_value.shim, info=info
            )
        elif isinstance(info, FieldSerializerDecoratorInfo):
            res.field_serializers[var_name] = Decorator.build(
                typ, cls_var_name=var_name, shim=var_value.shim, info=info
            )
        elif isinstance(info, ModelValidatorDecoratorInfo):
            res.model_validators[var_name] = Decorator.build(
                typ, cls_var_name=var_name, shim=var_value.shim, info=info
            )
        elif isinstance(info, ModelSerializerDecoratorInfo):
            res.model_serializers[var_name] = Decorator.build(
                typ, cls_var_name=var_name, shim=var_value.shim, info=info
            )
        else:
            # Imported lazily, as elsewhere in this module.
            from ..fields import ComputedFieldInfo

            # Every other `DecoratorInfo` variant is handled above, so this must be a computed field.
            # Check the info object, not the proxy that carries it.
            assert isinstance(info, ComputedFieldInfo), f'unexpected decorator info: {info!r}'
            res.computed_fields[var_name] = Decorator.build(typ, cls_var_name=var_name, shim=None, info=info)
        if collect_to_replace:
            to_replace.append((var_name, var_value.wrapped))

    return res, to_replace

548 

549 

def inspect_validator(
    validator: Callable[..., Any], *, mode: FieldValidatorModes, type: Literal['field', 'model']
) -> bool:
    """Look at a field or model validator function and determine whether it takes an info argument.

    An error is raised if the function has an invalid signature.

    Args:
        validator: The validator function to inspect.
        mode: The proposed validator mode.
        type: The type of validator, either 'field' or 'model'.

    Returns:
        Whether the validator takes an info argument.

    Raises:
        PydanticUserError: If the number of required positional parameters does not match
            any accepted form for `mode`.
    """
    try:
        sig = signature_no_eval(validator)
    except (ValueError, TypeError):
        # `inspect.signature` might not be able to infer a signature, e.g. with C objects.
        # In this case, we assume no info argument is present:
        return False
    n_positional = count_positional_required_params(sig)
    if mode == 'wrap':
        # (value, handler) or (value, handler, info)
        if n_positional == 3:
            return True
        elif n_positional == 2:
            return False
    else:
        assert mode in {'before', 'after', 'plain'}, f"invalid mode: {mode!r}, expected 'before', 'after' or 'plain'"
        # (value) or (value, info)
        if n_positional == 2:
            return True
        elif n_positional == 1:
            return False

    raise PydanticUserError(
        f'Unrecognized {type} validator function signature for {validator} with `mode={mode}`: {sig}',
        code='validator-signature',
    )

588 

589 

def inspect_field_serializer(serializer: Callable[..., Any], mode: Literal['plain', 'wrap']) -> tuple[bool, bool]:
    """Classify a field serializer function.

    Determines whether it is written as an instance method (first parameter named
    `self`) and whether it expects an info argument. Unsupported signatures are rejected.

    Args:
        serializer: The serializer function to inspect.
        mode: The serializer mode, either 'plain' or 'wrap'.

    Returns:
        Tuple of (is_field_serializer, info_arg).

    Raises:
        PydanticUserError: If the signature matches no accepted form for `mode`.
    """
    try:
        sig = signature_no_eval(serializer)
    except (ValueError, TypeError):
        # No signature could be obtained (e.g. some C-implemented callables); treat it as
        # a plain function without an info argument.
        return (False, False)

    params = list(sig.parameters.values())
    takes_self = bool(params) and params[0].name == 'self'

    # A leading `self` does not count towards the serializer's own arity.
    arity = count_positional_required_params(sig) - int(takes_self)
    info_arg = _serializer_info_arg(mode, arity)

    if info_arg is None:
        raise PydanticUserError(
            f'Unrecognized field_serializer function signature for {serializer} with `mode={mode}`:{sig}',
            code='field-serializer-signature',
        )
    return takes_self, info_arg

627 

628 

def inspect_annotated_serializer(serializer: Callable[..., Any], mode: Literal['plain', 'wrap']) -> bool:
    """Decide whether a serializer attached through `Annotated` expects an info argument.

    Unsupported signatures are rejected with an error.

    Args:
        serializer: The serializer function to check.
        mode: The serializer mode, either 'plain' or 'wrap'.

    Returns:
        Whether the serializer takes an info argument.

    Raises:
        PydanticUserError: If the signature matches no accepted form for `mode`.
    """
    try:
        sig = signature_no_eval(serializer)
    except (ValueError, TypeError):
        # Signature not introspectable (e.g. C objects): assume there is no info argument.
        return False

    info_arg = _serializer_info_arg(mode, count_positional_required_params(sig))
    if info_arg is not None:
        return info_arg
    raise PydanticUserError(
        f'Unrecognized field_serializer function signature for {serializer} with `mode={mode}`:{sig}',
        code='field-serializer-signature',
    )

655 

656 

def inspect_model_serializer(serializer: Callable[..., Any], mode: Literal['plain', 'wrap']) -> bool:
    """Decide whether a model serializer expects an info argument.

    The serializer must be an instance method (not a `staticmethod`/`classmethod`, and
    with `self` as first parameter). Unsupported signatures are rejected.

    Args:
        serializer: The serializer function to check.
        mode: The serializer mode, either 'plain' or 'wrap'.

    Returns:
        `info_arg` - whether the function expects an info argument.

    Raises:
        PydanticUserError: If the serializer is not an instance method, or its signature
            matches no accepted form for `mode`.
    """
    looks_like_instance_method = not isinstance(
        serializer, (staticmethod, classmethod)
    ) and is_instance_method_from_sig(serializer)
    if not looks_like_instance_method:
        raise PydanticUserError(
            '`@model_serializer` must be applied to instance methods', code='model-serializer-instance-method'
        )

    sig = signature_no_eval(serializer)
    info_arg = _serializer_info_arg(mode, count_positional_required_params(sig))
    if info_arg is not None:
        return info_arg
    raise PydanticUserError(
        f'Unrecognized model_serializer function signature for {serializer} with `mode={mode}`:{sig}',
        code='model-serializer-signature',
    )

683 

684 

def _serializer_info_arg(mode: Literal['plain', 'wrap'], n_positional: int) -> bool | None:
    """Map a serializer's positional arity to "takes an info argument".

    Accepted shapes:
      * plain: `(input_value, /)` or `(input_value, info, /)`
      * wrap: `(input_value, serializer, /)` or `(input_value, serializer, info, /)`

    Returns:
        `False` for the shape without info, `True` for the shape with info, and `None`
        when `n_positional` matches neither.
    """
    if mode != 'plain':
        assert mode == 'wrap', f"invalid mode: {mode!r}, expected 'plain' or 'wrap'"
    # Number of positional parameters in the shape without an info argument.
    without_info = 1 if mode == 'plain' else 2
    if n_positional == without_info:
        return False
    if n_positional == without_info + 1:
        return True
    return None

703 

704 

# Callable shapes accepted by the signature-based helpers below
# (`is_instance_method_from_sig`, `ensure_classmethod_based_on_signature`, ...):
# class/static methods, `partialmethod` objects, or any plain callable.
AnyDecoratorCallable: TypeAlias = (
    'Union[classmethod[Any, Any, Any], staticmethod[Any, Any], partialmethod[Any], Callable[..., Any]]'
)

708 

709 

def is_instance_method_from_sig(function: AnyDecoratorCallable) -> bool:
    """Report whether `function` looks like an instance method.

    The check is purely name-based. After unwrapping decorators, the function counts
    as an instance method when its first parameter is called `self`.

    Args:
        function: The function to check.

    Returns:
        `True` if the function is an instance method, `False` otherwise.
    """
    parameters = signature_no_eval(unwrap_wrapped_function(function)).parameters
    leading = next(iter(parameters.values()), None)
    return leading is not None and leading.name == 'self'

727 

728 

def ensure_classmethod_based_on_signature(function: AnyDecoratorCallable) -> Any:
    """Wrap `function` in `classmethod` when its signature asks for it.

    A function whose first parameter is named `cls` is wrapped, unless it already is a
    `classmethod` (possibly beneath `partial`/`partialmethod` layers). Anything else is
    returned unchanged.

    Args:
        function: The function to apply the decorator on.

    Return:
        The `@classmethod` decorator applied function, or `function` itself.
    """
    already_classmethod = isinstance(
        unwrap_wrapped_function(function, unwrap_class_static_method=False), classmethod
    )
    if already_classmethod or not _is_classmethod_from_sig(function):
        return function
    return classmethod(function)  # type: ignore[arg-type]

743 

744 

def _is_classmethod_from_sig(function: AnyDecoratorCallable) -> bool:
    """Return whether the unwrapped function's first parameter is named `cls`."""
    parameters = signature_no_eval(unwrap_wrapped_function(function)).parameters
    leading = next(iter(parameters.values()), None)
    return leading is not None and leading.name == 'cls'

751 

752 

def unwrap_wrapped_function(
    func: Any,
    *,
    unwrap_partial: bool = True,
    unwrap_class_static_method: bool = True,
) -> Any:
    """Peel wrapper layers off `func` until an object of an unhandled kind remains.

    Layers are removed repeatedly, so stacked wrappers (for example a `classmethod`
    around a `property`) are all stripped. `property` and `cached_property` are
    always unwrapped.

    Args:
        func: The possibly-wrapped object.
        unwrap_partial: When true (the default), `functools.partial` and
            `functools.partialmethod` layers are removed too.
        unwrap_class_static_method: When true (the default), `classmethod` and
            `staticmethod` layers are removed too.

    Returns:
        The innermost object reached. A `property` resolves to its getter and a
        `cached_property` to the function it wraps.
    """
    peelable: tuple[type[Any], ...] = (property, cached_property)
    if unwrap_partial:
        peelable += (partial, partialmethod)
    if unwrap_class_static_method:
        peelable += (staticmethod, classmethod)

    current = func
    while isinstance(current, peelable):
        if unwrap_class_static_method and isinstance(current, (classmethod, staticmethod)):
            current = current.__func__
        elif isinstance(current, (partial, partialmethod)):
            current = current.func
        elif isinstance(current, property):
            # Taking the getter is an arbitrary choice that happens to suit computed fields.
            current = current.fget
        else:
            # The only wrapper kind left at this point.
            assert isinstance(current, cached_property)
            current = current.func  # type: ignore
    return current

791 

792 

# Objects that `get_callable_return_type` inspects directly. Any other callable
# instance is inspected through its type's `__call__` instead.
_function_like = (
    partial,
    partialmethod,
    types.FunctionType,
    types.BuiltinFunctionType,
    types.MethodType,
    types.WrapperDescriptorType,
    types.MethodWrapperType,
    types.MemberDescriptorType,
)

803 

804 

def get_callable_return_type(
    callable_obj: Any,
    globalns: GlobalsNamespace | None = None,
    localns: MappingNamespace | None = None,
) -> Any | PydanticUndefinedType:
    """Resolve the return annotation of a callable.

    Classes are treated as returning themselves. For callable instances that are not
    function-like, the type's `__call__` is inspected instead.

    Args:
        callable_obj: The callable to analyze.
        globalns: The globals namespace to use during type annotation evaluation.
        localns: The locals namespace to use during type annotation evaluation.

    Returns:
        The evaluated return annotation, or `PydanticUndefined` when there is none.
    """
    if isinstance(callable_obj, type):
        # Calling a class produces an instance of it (e.g. `int()` -> `int`).
        return callable_obj

    target = callable_obj
    if not isinstance(target, _function_like):
        dunder_call = getattr(type(target), '__call__', None)  # noqa: B004
        if dunder_call is not None:
            target = dunder_call

    return_hints = get_function_type_hints(
        unwrap_wrapped_function(target),
        include_keys={'return'},
        globalns=globalns,
        localns=localns,
    )
    return return_hints.get('return', PydanticUndefined)

837 

838 

def count_positional_required_params(sig: Signature) -> int:
    """Count the parameters of `sig` that can be passed positionally and are required.

    Intended only for signatures of validation and serialization functions. The first
    parameter (the value being validated/serialized) is always counted, even when it
    has a default.

    Returns:
        The number of positional arguments of a signature.
    """
    total = 0
    for index, param in enumerate(sig.parameters.values()):
        if not can_be_positional(param):
            continue
        # The leading value parameter may carry a default (e.g. `float`, whose signature
        # is `(x=0, /)`). Any later parameter (such as the info argument) only counts when
        # it has no default.
        if index == 0 or param.default is Parameter.empty:
            total += 1
    return total

859 

860 

def ensure_property(f: Any) -> Any:
    """Return `f` unchanged if it is already a descriptor, else wrap it in `property`.

    "Descriptor" here means anything `inspect.ismethoddescriptor` or
    `inspect.isdatadescriptor` accepts (e.g. `property`, `cached_property`, `classmethod`).

    Args:
        f: The function to check.

    Returns:
        `f` itself, or a `property` instance wrapping it.
    """
    is_descriptor = ismethoddescriptor(f) or isdatadescriptor(f)
    return f if is_descriptor else property(f)