Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/pydantic/_internal/_decorators.py: 62%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

325 statements  

1"""Logic related to validators applied to models etc. via the `@field_validator` and `@model_validator` decorators.""" 

2 

3from __future__ import annotations as _annotations 

4 

5import types 

6from collections import deque 

7from collections.abc import Iterable 

8from dataclasses import dataclass, field 

9from functools import cached_property, partial, partialmethod 

10from inspect import Parameter, Signature, isdatadescriptor, ismethoddescriptor, signature 

11from itertools import islice 

12from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generic, Literal, TypeVar, Union 

13 

14from pydantic_core import PydanticUndefined, PydanticUndefinedType, core_schema 

15from typing_extensions import TypeAlias, is_typeddict 

16 

17from ..errors import PydanticUserError 

18from ._core_utils import get_type_ref 

19from ._internal_dataclass import slots_true 

20from ._namespace_utils import GlobalsNamespace, MappingNamespace 

21from ._typing_extra import get_function_type_hints 

22from ._utils import can_be_positional 

23 

24if TYPE_CHECKING: 

25 from ..fields import ComputedFieldInfo 

26 from ..functional_validators import FieldValidatorModes 

27 

28 

@dataclass(**slots_true)
class ValidatorDecoratorInfo:
    """A container for data from `@validator` so that we can access it
    while building the pydantic-core schema.

    Attributes:
        decorator_repr: A class variable representing the decorator string, '@validator'.
        fields: A tuple of field names the validator should be called on.
        mode: The proposed validator mode, either `'before'` or `'after'`.
        each_item: For complex objects (sets, lists etc.) whether to validate individual
            elements rather than the whole object.
        always: Whether this method and other validators should be called even if the value is missing.
        check_fields: Whether to check that the fields actually exist on the model.
    """

    # Declared as a ClassVar, so it is a class constant rather than a dataclass field.
    decorator_repr: ClassVar[str] = '@validator'

    # The order of the fields below defines the generated `__init__` signature.
    fields: tuple[str, ...]
    mode: Literal['before', 'after']
    each_item: bool
    always: bool
    check_fields: bool | None

51 

52 

@dataclass(**slots_true)
class FieldValidatorDecoratorInfo:
    """A container for data from `@field_validator` so that we can access it
    while building the pydantic-core schema.

    Attributes:
        decorator_repr: A class variable representing the decorator string, '@field_validator'.
        fields: A tuple of field names the validator should be called on.
        mode: The proposed validator mode.
        check_fields: Whether to check that the fields actually exist on the model.
        json_schema_input_type: The input type of the function. This is only used to generate
            the appropriate JSON Schema (in validation mode) and can only be specified
            when `mode` is either `'before'`, `'plain'` or `'wrap'`.
    """

    # Declared as a ClassVar, so it is a class constant rather than a dataclass field.
    decorator_repr: ClassVar[str] = '@field_validator'

    # The order of the fields below defines the generated `__init__` signature.
    fields: tuple[str, ...]
    mode: FieldValidatorModes
    check_fields: bool | None
    json_schema_input_type: Any

74 

75 

@dataclass(**slots_true)
class RootValidatorDecoratorInfo:
    """A container for data from `@root_validator` so that we can access it
    while building the pydantic-core schema.

    Attributes:
        decorator_repr: A class variable representing the decorator string, '@root_validator'.
        mode: The proposed validator mode, either `'before'` or `'after'`.
    """

    # Declared as a ClassVar, so it is a class constant rather than a dataclass field.
    decorator_repr: ClassVar[str] = '@root_validator'
    mode: Literal['before', 'after']

88 

89 

@dataclass(**slots_true)
class FieldSerializerDecoratorInfo:
    """A container for data from `@field_serializer` so that we can access it
    while building the pydantic-core schema.

    Attributes:
        decorator_repr: A class variable representing the decorator string, '@field_serializer'.
        fields: A tuple of field names the serializer should be called on.
        mode: The proposed serializer mode, either `'plain'` or `'wrap'`.
        return_type: The type of the serializer's return value.
        when_used: The serialization condition. Accepts a string with values `'always'`, `'unless-none'`, `'json'`,
            and `'json-unless-none'`.
        check_fields: Whether to check that the fields actually exist on the model.
    """

    # Declared as a ClassVar, so it is a class constant rather than a dataclass field.
    decorator_repr: ClassVar[str] = '@field_serializer'
    # The order of the fields below defines the generated `__init__` signature.
    fields: tuple[str, ...]
    mode: Literal['plain', 'wrap']
    return_type: Any
    when_used: core_schema.WhenUsed
    check_fields: bool | None

111 

112 

@dataclass(**slots_true)
class ModelSerializerDecoratorInfo:
    """A container for data from `@model_serializer` so that we can access it
    while building the pydantic-core schema.

    Attributes:
        decorator_repr: A class variable representing the decorator string, '@model_serializer'.
        mode: The proposed serializer mode, either `'plain'` or `'wrap'`.
        return_type: The type of the serializer's return value.
        when_used: The serialization condition. Accepts a string with values `'always'`, `'unless-none'`, `'json'`,
            and `'json-unless-none'`.
    """

    # Declared as a ClassVar, so it is a class constant rather than a dataclass field.
    decorator_repr: ClassVar[str] = '@model_serializer'
    # The order of the fields below defines the generated `__init__` signature.
    mode: Literal['plain', 'wrap']
    return_type: Any
    when_used: core_schema.WhenUsed

130 

131 

@dataclass(**slots_true)
class ModelValidatorDecoratorInfo:
    """A container for data from `@model_validator` so that we can access it
    while building the pydantic-core schema.

    Attributes:
        decorator_repr: A class variable representing the decorator string, '@model_validator'.
        mode: The proposed validator mode: `'wrap'`, `'before'` or `'after'`.
    """

    # Declared as a ClassVar, so it is a class constant rather than a dataclass field.
    decorator_repr: ClassVar[str] = '@model_validator'
    mode: Literal['wrap', 'before', 'after']

144 

145 

# Union of every decorator-info container that a `PydanticDescriptorProxy` can carry.
# It is written as a string because `ComputedFieldInfo` is only imported under
# `TYPE_CHECKING`, so the name does not exist at runtime.
DecoratorInfo: TypeAlias = """Union[
    ValidatorDecoratorInfo,
    FieldValidatorDecoratorInfo,
    RootValidatorDecoratorInfo,
    FieldSerializerDecoratorInfo,
    ModelSerializerDecoratorInfo,
    ModelValidatorDecoratorInfo,
    ComputedFieldInfo,
]"""

155 

ReturnType = TypeVar('ReturnType')
DecoratedType: TypeAlias = (
    'Union[classmethod[Any, Any, ReturnType], staticmethod[Any, ReturnType], Callable[..., ReturnType], property]'
)


@dataclass  # can't use slots here since we set attributes on `__post_init__`
class PydanticDescriptorProxy(Generic[ReturnType]):
    """Wrap a classmethod, staticmethod, property or unbound function
    and act as a descriptor that allows us to detect decorated items
    from the class' attributes.

    This class' __get__ returns the wrapped item's __get__ result,
    which makes it transparent for classmethods and staticmethods.
    Wrapped objects without a `__get__` (e.g. `functools.partial`) are returned as-is.
    Any other attribute lookup that fails on the proxy is forwarded to the wrapped object.

    Attributes:
        wrapped: The decorated object being wrapped.
        decorator_info: The decorator info.
        shim: A wrapper function to wrap V1 style function.
    """

    wrapped: DecoratedType[ReturnType]
    decorator_info: DecoratorInfo
    shim: Callable[[Callable[..., Any]], Callable[..., Any]] | None = None

    def __post_init__(self):
        # Expose `setter`/`deleter` when the wrapped object has them (e.g. a property).
        # Chaining `@proxy.setter` then updates `self.wrapped` and returns this proxy.
        for attr in 'setter', 'deleter':
            if hasattr(self.wrapped, attr):
                f = partial(self._call_wrapped_attr, name=attr)
                setattr(self, attr, f)

    def _call_wrapped_attr(self, func: Callable[[Any], None], *, name: str) -> PydanticDescriptorProxy[ReturnType]:
        """Replace `wrapped` with `wrapped.<name>(func)` and return this proxy."""
        self.wrapped = getattr(self.wrapped, name)(func)
        if isinstance(self.wrapped, property):
            # update ComputedFieldInfo.wrapped_property
            from ..fields import ComputedFieldInfo

            if isinstance(self.decorator_info, ComputedFieldInfo):
                self.decorator_info.wrapped_property = self.wrapped
        return self

    def __get__(self, obj: object | None, obj_type: type[object] | None = None) -> PydanticDescriptorProxy[ReturnType]:
        """Delegate to the wrapped object's `__get__`, or return the wrapped object if it has none."""
        try:
            return self.wrapped.__get__(obj, obj_type)
        except AttributeError:
            # not a descriptor, e.g. a partial object
            # NOTE(review): this also swallows an AttributeError raised *inside* the wrapped
            # `__get__` (e.g. by a property getter) and returns the wrapped object instead.
            return self.wrapped  # type: ignore[return-value]

    def __set_name__(self, instance: Any, name: str) -> None:
        """Forward `__set_name__` to the wrapped object when it defines one."""
        if hasattr(self.wrapped, '__set_name__'):
            self.wrapped.__set_name__(instance, name)  # pyright: ignore[reportFunctionMemberAccess]

    def __getattr__(self, name: str, /) -> Any:
        """Forward checks for __isabstractmethod__ and such."""
        if name == 'wrapped':
            # `wrapped` only reaches `__getattr__` when it has not been set yet, e.g. on an
            # instance created via `__new__` by `copy`/`pickle`. Without this guard the
            # `self.wrapped` lookup below re-enters `__getattr__` forever (RecursionError),
            # which breaks `hasattr` probes such as `hasattr(obj, '__setstate__')`.
            raise AttributeError(name)
        return getattr(self.wrapped, name)

211 

212 

DecoratorInfoType = TypeVar('DecoratorInfoType', bound=DecoratorInfo)


@dataclass(**slots_true)
class Decorator(Generic[DecoratorInfoType]):
    """Pair a decorator's metadata with the function it decorates.

    The metadata (`info`) is available as soon as the decorator runs. The function
    resolved against a concrete class only exists once that class has been created.
    This container joins the two so they can be used while building the core schema.

    Attributes:
        cls_ref: The class reference string produced by `get_type_ref`.
        cls_var_name: The name of the decorated attribute in the class namespace.
        func: The decorated callable, looked up on the class (after `shim`, if any).
        shim: A wrapper function to wrap V1 style function.
        info: The decorator info.
    """

    cls_ref: str
    cls_var_name: str
    func: Callable[..., Any]
    shim: Callable[[Any], Any] | None
    info: DecoratorInfoType

    @staticmethod
    def build(
        cls_: Any,
        *,
        cls_var_name: str,
        shim: Callable[[Any], Any] | None,
        info: DecoratorInfoType,
    ) -> Decorator[DecoratorInfoType]:
        """Resolve `cls_var_name` on `cls_` and wrap it together with `info`.

        Args:
            cls_: The class.
            cls_var_name: The decorated function name.
            shim: A wrapper function to wrap V1 style function.
            info: The decorator info.

        Returns:
            The new decorator instance.
        """
        resolved = get_attribute_from_bases(cls_, cls_var_name)
        if shim is not None:
            resolved = shim(resolved)
        resolved = unwrap_wrapped_function(resolved, unwrap_partial=False)

        if not callable(resolved):
            # Classmethod properties land here. Read the raw namespace entry instead,
            # which avoids the `__get__` binding done by ordinary attribute access.
            raw_attribute = get_attribute_from_base_dicts(cls_, cls_var_name)
            if isinstance(raw_attribute, PydanticDescriptorProxy):
                resolved = unwrap_wrapped_function(raw_attribute.wrapped)

        type_ref = get_type_ref(cls_)
        return Decorator(cls_ref=type_ref, cls_var_name=cls_var_name, func=resolved, shim=shim, info=info)

    def bind_to_cls(self, cls: Any) -> Decorator[DecoratorInfoType]:
        """Rebuild this decorator against `cls`, keeping its name, shim and info.

        Args:
            cls: the class.

        Returns:
            The new decorator instance.
        """
        rebuild = self.build
        return rebuild(cls, cls_var_name=self.cls_var_name, shim=self.shim, info=self.info)

288 

289 

290def get_bases(tp: type[Any]) -> tuple[type[Any], ...]: 

291 """Get the base classes of a class or typeddict. 

292 

293 Args: 

294 tp: The type or class to get the bases. 

295 

296 Returns: 

297 The base classes. 

298 """ 

299 if is_typeddict(tp): 

300 return tp.__orig_bases__ # type: ignore 

301 try: 

302 return tp.__bases__ 

303 except AttributeError: 

304 return () 

305 

306 

307def mro(tp: type[Any]) -> tuple[type[Any], ...]: 

308 """Calculate the Method Resolution Order of bases using the C3 algorithm. 

309 

310 See https://www.python.org/download/releases/2.3/mro/ 

311 """ 

312 # try to use the existing mro, for performance mainly 

313 # but also because it helps verify the implementation below 

314 if not is_typeddict(tp): 

315 try: 

316 return tp.__mro__ 

317 except AttributeError: 

318 # GenericAlias and some other cases 

319 pass 

320 

321 bases = get_bases(tp) 

322 return (tp,) + mro_for_bases(bases) 

323 

324 

def mro_for_bases(bases: tuple[type[Any], ...]) -> tuple[type[Any], ...]:
    """Compute the C3 linearization of a tuple of base classes.

    Each base contributes its own MRO (via `mro`), and the tuple of bases itself is
    added as a final sequence to preserve the local precedence order.

    Args:
        bases: The base classes to linearize, in declaration order.

    Returns:
        The merged MRO of `bases` (not including the class that declares them).

    Raises:
        TypeError: If the hierarchy has no consistent C3 linearization.
    """

    def merge_seqs(seqs: list[deque[type[Any]]]) -> Iterable[type[Any]]:
        while True:
            non_empty = [seq for seq in seqs if seq]
            if not non_empty:
                # Nothing left to process, we're done.
                return
            candidate: type[Any] | None = None
            for seq in non_empty:  # Find merge candidates among seq heads.
                head = seq[0]
                # A head is acceptable only if it does not appear in the tail of any sequence.
                if not any(head in islice(s, 1, None) for s in non_empty):
                    candidate = head
                    break
            # Compare with None explicitly. A class whose metaclass defines `__bool__`/`__len__`
            # can be falsy, and must not be mistaken for "no candidate found".
            if candidate is None:
                raise TypeError('Inconsistent hierarchy, no C3 MRO is possible')
            yield candidate
            for seq in non_empty:
                # Remove candidate.
                if seq[0] == candidate:
                    seq.popleft()

    seqs = [deque(mro(base)) for base in bases] + [deque(bases)]
    return tuple(merge_seqs(seqs))

351 

352 

_sentinel = object()


def get_attribute_from_bases(tp: type[Any] | tuple[type[Any], ...], name: str) -> Any:
    """Get the attribute from the next class in the MRO that has it,
    aiming to simulate calling the method on the actual class.

    The MRO is walked explicitly, rather than relying on plain attribute access
    alone, to support TypedDict. TypedDict lacks a real `__mro__`, but one can be
    constructed from its bases (as done here).

    Args:
        tp: The type or class to search for the attribute. If a tuple, this is treated as a set of base classes.
        name: The name of the attribute to retrieve.

    Returns:
        Any: The attribute value, if found. For a tuple of bases, descriptors are bound
        by calling `__get__(None, tp)`.

    Raises:
        AttributeError: If the attribute is not found in any class in the MRO.
    """
    if not isinstance(tp, tuple):
        # Ordinary classes: try normal attribute access first, then fall back to the computed MRO.
        try:
            return getattr(tp, name)
        except AttributeError:
            return get_attribute_from_bases(mro(tp), name)

    # `tp` is a tuple of bases: linearize them and emulate class attribute lookup by hand.
    for klass in mro_for_bases(tp):
        value = klass.__dict__.get(name, _sentinel)
        if value is _sentinel:
            continue
        binder = getattr(value, '__get__', None)
        # NOTE(review): the owner passed to `__get__` is the tuple of bases, not a class.
        return value if binder is None else binder(None, tp)
    raise AttributeError(f'{name} not found in {tp}')

389 

390 

def get_attribute_from_base_dicts(tp: type[Any], name: str) -> Any:
    """Look `name` up directly in the class namespaces (`__dict__`) along the MRO of `tp`.

    Reading the namespace bypasses `__get__` on descriptors, which is what lets us
    recover the original function for classmethod properties.

    The MRO is walked in reverse, from the most distant ancestor towards `tp`, so
    the first (outermost) definition found is the one returned.

    Args:
        tp: The type or class to search for the attribute.
        name: The name of the attribute to retrieve.

    Returns:
        Any: The raw attribute value, if found.

    Raises:
        KeyError: If the attribute is not found in any class's `__dict__` in the MRO.
    """
    # NOTE(review): `reversed(mro(tp))` prefers the oldest ancestor's definition over an
    # override further down the hierarchy; confirm this ordering is intended.
    not_found = object()
    hit = next(
        (klass.__dict__[name] for klass in reversed(mro(tp)) if name in klass.__dict__),
        not_found,
    )
    if hit is not_found:
        # Index `tp`'s own namespace purely to raise the KeyError.
        return tp.__dict__[name]
    return hit

410 

411 

@dataclass(**slots_true)
class DecoratorInfos:
    """Mapping of name in the class namespace to decorator info.

    Note that the name in the class namespace is the function or attribute name,
    not the field name!
    """

    validators: dict[str, Decorator[ValidatorDecoratorInfo]] = field(default_factory=dict)
    field_validators: dict[str, Decorator[FieldValidatorDecoratorInfo]] = field(default_factory=dict)
    root_validators: dict[str, Decorator[RootValidatorDecoratorInfo]] = field(default_factory=dict)
    field_serializers: dict[str, Decorator[FieldSerializerDecoratorInfo]] = field(default_factory=dict)
    model_serializers: dict[str, Decorator[ModelSerializerDecoratorInfo]] = field(default_factory=dict)
    model_validators: dict[str, Decorator[ModelValidatorDecoratorInfo]] = field(default_factory=dict)
    computed_fields: dict[str, Decorator[ComputedFieldInfo]] = field(default_factory=dict)

    @staticmethod
    def build(model_dc: type[Any]) -> DecoratorInfos:  # noqa: C901 (ignore complexity)
        """Collect all decorator infos that apply to `model_dc` (a BaseModel or dataclass).

        We want to collect them in the order of the bases. So instead of getting them
        all from the leaf class (the class that called us), we traverse the bases from
        the root (the oldest ancestor class) to the leaf and collect all of the
        instances as we go. Any duplicate is replaced with the last one we see, to
        mimic how function overriding works with inheritance. A replacement keeps the
        position of the function it replaced, so the order is maintained.

        Args:
            model_dc: The class being processed.

        Returns:
            The collected `DecoratorInfos`. If `model_dc` itself defines at least one
            `PydanticDescriptorProxy`, the result is also stored on
            `model_dc.__pydantic_decorators__`, and the proxies are replaced by the
            objects they wrap.

        Raises:
            PydanticUserError: If two different field serializer functions target the same field.
        """
        # reminder: dicts are ordered and replacement does not alter the order
        res = DecoratorInfos()
        for base in reversed(mro(model_dc)[1:]):
            existing: DecoratorInfos | None = base.__dict__.get('__pydantic_decorators__')
            if existing is None:
                existing = DecoratorInfos.build(base)
            res.validators.update({k: v.bind_to_cls(model_dc) for k, v in existing.validators.items()})
            res.field_validators.update({k: v.bind_to_cls(model_dc) for k, v in existing.field_validators.items()})
            res.root_validators.update({k: v.bind_to_cls(model_dc) for k, v in existing.root_validators.items()})
            res.field_serializers.update({k: v.bind_to_cls(model_dc) for k, v in existing.field_serializers.items()})
            res.model_serializers.update({k: v.bind_to_cls(model_dc) for k, v in existing.model_serializers.items()})
            res.model_validators.update({k: v.bind_to_cls(model_dc) for k, v in existing.model_validators.items()})
            res.computed_fields.update({k: v.bind_to_cls(model_dc) for k, v in existing.computed_fields.items()})

        to_replace: list[tuple[str, Any]] = []

        for var_name, var_value in vars(model_dc).items():
            if isinstance(var_value, PydanticDescriptorProxy):
                info = var_value.decorator_info
                if isinstance(info, ValidatorDecoratorInfo):
                    res.validators[var_name] = Decorator.build(
                        model_dc, cls_var_name=var_name, shim=var_value.shim, info=info
                    )
                elif isinstance(info, FieldValidatorDecoratorInfo):
                    res.field_validators[var_name] = Decorator.build(
                        model_dc, cls_var_name=var_name, shim=var_value.shim, info=info
                    )
                elif isinstance(info, RootValidatorDecoratorInfo):
                    res.root_validators[var_name] = Decorator.build(
                        model_dc, cls_var_name=var_name, shim=var_value.shim, info=info
                    )
                elif isinstance(info, FieldSerializerDecoratorInfo):
                    # check whether a serializer function is already registered for fields
                    for field_serializer_decorator in res.field_serializers.values():
                        # check that each field has at most one serializer function.
                        # serializer functions for the same field in subclasses are allowed,
                        # and are treated as overrides
                        if field_serializer_decorator.cls_var_name == var_name:
                            continue
                        for f in info.fields:
                            if f in field_serializer_decorator.info.fields:
                                raise PydanticUserError(
                                    'Multiple field serializer functions were defined '
                                    f'for field {f!r}, this is not allowed.',
                                    code='multiple-field-serializers',
                                )
                    res.field_serializers[var_name] = Decorator.build(
                        model_dc, cls_var_name=var_name, shim=var_value.shim, info=info
                    )
                elif isinstance(info, ModelValidatorDecoratorInfo):
                    res.model_validators[var_name] = Decorator.build(
                        model_dc, cls_var_name=var_name, shim=var_value.shim, info=info
                    )
                elif isinstance(info, ModelSerializerDecoratorInfo):
                    res.model_serializers[var_name] = Decorator.build(
                        model_dc, cls_var_name=var_name, shim=var_value.shim, info=info
                    )
                else:
                    # By elimination `info` is a `ComputedFieldInfo`: it is the only remaining member
                    # of the `DecoratorInfo` union. (A former no-op `isinstance` expression here,
                    # which also tested the proxy rather than `info`, has been removed.)
                    res.computed_fields[var_name] = Decorator.build(
                        model_dc, cls_var_name=var_name, shim=None, info=info
                    )
                to_replace.append((var_name, var_value.wrapped))
        if to_replace:
            # If we can save `__pydantic_decorators__` on the class, we'll be able to check for it above,
            # so we don't need to re-process the type. That means we can discard our descriptor wrappers
            # and replace them with the thing they are wrapping (see the setattr call below),
            # which allows validator class methods to also function as regular class methods.
            model_dc.__pydantic_decorators__ = res
            for name, value in to_replace:
                setattr(model_dc, name, value)
        return res

516 

517 

def inspect_validator(validator: Callable[..., Any], mode: FieldValidatorModes) -> bool:
    """Look at a field or model validator function and determine whether it takes an info argument.

    An error is raised if the function has an invalid signature.

    Args:
        validator: The validator function to inspect.
        mode: The proposed validator mode.

    Returns:
        Whether the validator takes an info argument. `False` is also returned when
        `inspect.signature` cannot produce a signature for `validator`.

    Raises:
        PydanticUserError: If the number of required positional parameters is not
            2 or 3 for `mode='wrap'`, or not 1 or 2 for the other modes.
    """
    try:
        sig = signature(validator)
    except (ValueError, TypeError):
        # `inspect.signature` might not be able to infer a signature, e.g. with C objects.
        # In this case, we assume no info argument is present:
        return False
    n_positional = count_positional_required_params(sig)
    if mode == 'wrap':
        # Wrap validators take one more positional parameter than the other modes.
        if n_positional == 3:
            return True
        elif n_positional == 2:
            return False
    else:
        assert mode in {'before', 'after', 'plain'}, f"invalid mode: {mode!r}, expected 'before', 'after' or 'plain'"
        if n_positional == 2:
            return True
        elif n_positional == 1:
            return False

    raise PydanticUserError(
        f'Unrecognized field_validator function signature for {validator} with `mode={mode}`:{sig}',
        code='validator-signature',
    )

553 

554 

def inspect_field_serializer(serializer: Callable[..., Any], mode: Literal['plain', 'wrap']) -> tuple[bool, bool]:
    """Classify a field serializer function and detect whether it accepts an info argument.

    A function whose first parameter is named `self` is treated as a field serializer
    (method) and that parameter is excluded from the positional count. An invalid
    signature raises an error.

    Args:
        serializer: The serializer function to inspect.
        mode: The serializer mode, either 'plain' or 'wrap'.

    Returns:
        Tuple of (is_field_serializer, info_arg). `(False, False)` when no signature
        can be obtained.

    Raises:
        PydanticUserError: If the positional parameter count fits neither accepted form for `mode`.
    """
    try:
        sig = signature(serializer)
    except (ValueError, TypeError):
        # `inspect.signature` might not be able to infer a signature, e.g. with C objects.
        # In this case, we assume no info argument is present and this is not a method:
        return (False, False)

    leading = next(iter(sig.parameters.values()), None)
    takes_self = leading is not None and leading.name == 'self'

    positional = count_positional_required_params(sig)
    if takes_self:
        positional -= 1  # discount the `self` parameter
    info_arg = _serializer_info_arg(mode, positional)

    if info_arg is None:
        raise PydanticUserError(
            f'Unrecognized field_serializer function signature for {serializer} with `mode={mode}`:{sig}',
            code='field-serializer-signature',
        )

    return takes_self, info_arg

592 

593 

def inspect_annotated_serializer(serializer: Callable[..., Any], mode: Literal['plain', 'wrap']) -> bool:
    """Decide whether a serializer used via `Annotated` accepts an info argument.

    An invalid signature raises an error.

    Args:
        serializer: The serializer function to check.
        mode: The serializer mode, either 'plain' or 'wrap'.

    Returns:
        info_arg. `False` when no signature can be obtained.

    Raises:
        PydanticUserError: If the positional parameter count fits neither accepted form for `mode`.
    """
    try:
        sig = signature(serializer)
    except (ValueError, TypeError):
        # `inspect.signature` might not be able to infer a signature, e.g. with C objects.
        # In this case, we assume no info argument is present:
        return False
    info_arg = _serializer_info_arg(mode, count_positional_required_params(sig))
    if info_arg is not None:
        return info_arg
    raise PydanticUserError(
        f'Unrecognized field_serializer function signature for {serializer} with `mode={mode}`:{sig}',
        code='field-serializer-signature',
    )

620 

621 

def inspect_model_serializer(serializer: Callable[..., Any], mode: Literal['plain', 'wrap']) -> bool:
    """Decide whether a model serializer accepts an info argument.

    The serializer must be an instance method, i.e. not a `staticmethod`/`classmethod`,
    and its first parameter must be named `self`. An invalid signature raises an error.

    Args:
        serializer: The serializer function to check.
        mode: The serializer mode, either 'plain' or 'wrap'.

    Returns:
        `info_arg` - whether the function expects an info argument.

    Raises:
        PydanticUserError: If `serializer` is not an instance method, or its positional
            parameter count fits neither accepted form for `mode`.
    """
    is_static_or_class = isinstance(serializer, (staticmethod, classmethod))
    if is_static_or_class or not is_instance_method_from_sig(serializer):
        raise PydanticUserError(
            '`@model_serializer` must be applied to instance methods', code='model-serializer-instance-method'
        )

    sig = signature(serializer)
    info_arg = _serializer_info_arg(mode, count_positional_required_params(sig))
    if info_arg is not None:
        return info_arg
    raise PydanticUserError(
        f'Unrecognized model_serializer function signature for {serializer} with `mode={mode}`:{sig}',
        code='model-serializer-signature',
    )

648 

649 

650def _serializer_info_arg(mode: Literal['plain', 'wrap'], n_positional: int) -> bool | None: 

651 if mode == 'plain': 

652 if n_positional == 1: 

653 # (input_value: Any, /) -> Any 

654 return False 

655 elif n_positional == 2: 

656 # (model: Any, input_value: Any, /) -> Any 

657 return True 

658 else: 

659 assert mode == 'wrap', f"invalid mode: {mode!r}, expected 'plain' or 'wrap'" 

660 if n_positional == 2: 

661 # (input_value: Any, serializer: SerializerFunctionWrapHandler, /) -> Any 

662 return False 

663 elif n_positional == 3: 

664 # (input_value: Any, serializer: SerializerFunctionWrapHandler, info: SerializationInfo, /) -> Any 

665 return True 

666 

667 return None 

668 

669 

# Anything a validator/serializer decorator may be applied to: a classmethod, staticmethod,
# partialmethod, or plain callable. Kept as a string so it is only interpreted by type checkers.
AnyDecoratorCallable: TypeAlias = (
    'Union[classmethod[Any, Any, Any], staticmethod[Any, Any], partialmethod[Any], Callable[..., Any]]'
)

673 

674 

def is_instance_method_from_sig(function: AnyDecoratorCallable) -> bool:
    """Report whether `function` looks like an instance method.

    The check is purely by convention. After unwrapping decorators with
    `unwrap_wrapped_function`, the function counts as an instance method when
    its first parameter is named `self`.

    Args:
        function: The function to check.

    Returns:
        `True` if the function is an instance method, `False` otherwise.
    """
    # Iterating `Signature.parameters` yields parameter names in declaration order.
    parameter_names = iter(signature(unwrap_wrapped_function(function)).parameters)
    return next(parameter_names, None) == 'self'

692 

693 

def ensure_classmethod_based_on_signature(function: AnyDecoratorCallable) -> Any:
    """Wrap `function` in `classmethod` when its signature looks like a class method.

    Wrapping only happens when `function` is not already a `classmethod` and its
    first parameter is named `cls`. The "already a classmethod" check looks through
    partial and property wrappers but not through classmethod/staticmethod.

    Args:
        function: The function to apply the decorator on.

    Return:
        `classmethod(function)` when wrapping applies, otherwise `function` unchanged.
    """
    already_classmethod = isinstance(
        unwrap_wrapped_function(function, unwrap_class_static_method=False), classmethod
    )
    if already_classmethod or not _is_classmethod_from_sig(function):
        return function
    return classmethod(function)  # type: ignore[arg-type]

708 

709 

def _is_classmethod_from_sig(function: AnyDecoratorCallable) -> bool:
    """Return True when the (unwrapped) function's first parameter is named `cls`."""
    # Iterating `Signature.parameters` yields parameter names in declaration order.
    parameter_names = iter(signature(unwrap_wrapped_function(function)).parameters)
    return next(parameter_names, None) == 'cls'

716 

717 

def unwrap_wrapped_function(
    func: Any,
    *,
    unwrap_partial: bool = True,
    unwrap_class_static_method: bool = True,
) -> Any:
    """Peel decorator layers off `func` until an object of an unhandled type is reached.

    `property` and `cached_property` are always unwrapped. `property` yields its
    `fget`, an arbitrary but convenient choice for computed fields. The other
    layers are unwrapped only when the matching flag is set.

    Args:
        func: The function to unwrap.
        unwrap_partial: If True (default), unwrap partial and partialmethod decorators.
        unwrap_class_static_method: If True (default), also unwrap classmethod and staticmethod
            decorators. If False, classmethod/staticmethod layers stop the unwrapping.

    Returns:
        The underlying function of the wrapped function.
    """
    peelable: tuple[type[Any], ...] = (property, cached_property)
    if unwrap_partial:
        peelable += (partial, partialmethod)
    if unwrap_class_static_method:
        peelable += (staticmethod, classmethod)

    current = func
    while isinstance(current, peelable):
        if unwrap_class_static_method and isinstance(current, (classmethod, staticmethod)):
            current = current.__func__
        elif isinstance(current, (partial, partialmethod)):
            current = current.func
        elif isinstance(current, property):
            current = current.fget  # arbitrary choice, convenient for computed fields
        else:
            # Only `cached_property` can remain at this point.
            assert isinstance(current, cached_property)
            current = current.func  # type: ignore
    return current

756 

757 

# Callables whose return annotation is read from the object itself rather than from
# `type(obj).__call__`.
_function_like = (
    partial,
    partialmethod,
    types.FunctionType,
    types.BuiltinFunctionType,
    types.MethodType,
    types.WrapperDescriptorType,
    types.MethodWrapperType,
    types.MemberDescriptorType,
)


def get_callable_return_type(
    callable_obj: Any,
    globalns: GlobalsNamespace | None = None,
    localns: MappingNamespace | None = None,
) -> Any | PydanticUndefinedType:
    """Determine the annotated return type of a callable.

    Classes are their own return type. Other objects that are not function-like
    are inspected through their type's `__call__` when one exists. The result is
    unwrapped before its type hints are read.

    Args:
        callable_obj: The callable to analyze.
        globalns: The globals namespace to use during type annotation evaluation.
        localns: The locals namespace to use during type annotation evaluation.

    Returns:
        The function return type, or `PydanticUndefined` if no `'return'` hint is found.
    """
    if isinstance(callable_obj, type):
        # Calling a type produces an instance of it, e.g. `int()` -> `int`.
        return callable_obj

    target = callable_obj
    if not isinstance(target, _function_like):
        dunder_call = getattr(type(target), '__call__', None)  # noqa: B004
        if dunder_call is not None:
            target = dunder_call

    return_hints = get_function_type_hints(
        unwrap_wrapped_function(target),
        include_keys={'return'},
        globalns=globalns,
        localns=localns,
    )
    return return_hints.get('return', PydanticUndefined)

802 

803 

def count_positional_required_params(sig: Signature) -> int:
    """Count the positional parameters of a validation/serialization function signature.

    Only parameters that can be passed positionally are considered. The first one,
    the value being validated or serialized, always counts even if it has a default.
    Every later parameter counts only when it has no default.

    Returns:
        The number of positional arguments of a signature.
    """
    counted = 0
    for position, param in enumerate(sig.parameters.values()):
        if not can_be_positional(param):
            continue
        # The first parameter may carry a default (e.g. `float`, whose signature is `(x=0, /)`).
        # Others, such as the info argument, are expected to be required.
        if position == 0 or param.default is Parameter.empty:
            counted += 1
    return counted

824 

825 

826def ensure_property(f: Any) -> Any: 

827 """Ensure that a function is a `property` or `cached_property`, or is a valid descriptor. 

828 

829 Args: 

830 f: The function to check. 

831 

832 Returns: 

833 The function, or a `property` or `cached_property` instance wrapping the function. 

834 """ 

835 if ismethoddescriptor(f) or isdatadescriptor(f): 

836 return f 

837 else: 

838 return property(f)