Coverage for /pythoncovmergedfiles/medio/medio/src/pydantic/pydantic/_internal/_decorators.py: 49%

227 statements  

« prev     ^ index     » next       coverage.py v7.2.3, created at 2023-04-27 07:38 +0000

1""" 

2Logic related to validators applied to models etc. via the `@validator` and `@root_validator` decorators. 

3""" 

4from __future__ import annotations as _annotations 

5 

6from dataclasses import dataclass, field 

7from functools import partial, partialmethod 

8from inspect import Parameter, Signature, isdatadescriptor, ismethoddescriptor, signature 

9from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generic, TypeVar, Union, cast 

10 

11from pydantic_core import core_schema 

12from typing_extensions import Literal, TypeAlias 

13 

14from ..errors import PydanticUserError 

15from ..fields import ComputedFieldInfo 

16from ._core_utils import get_type_ref 

17from ._internal_dataclass import slots_dataclass 

18 

19if TYPE_CHECKING: 

20 from ..decorators import FieldValidatorModes 

21 

22try: 

23 from functools import cached_property # type: ignore 

24except ImportError: 

25 # python 3.7 

26 cached_property = None 

27 

28 

@slots_dataclass
class ValidatorDecoratorInfo:
    """
    Record of the arguments given to a `@validator` decorator.

    The values are stored at decoration time so they are still available
    later, when the pydantic-core schema is being built.
    """

    # Declared as a ClassVar, so it is shared by all instances and is not a dataclass field.
    decorator_repr: ClassVar[str] = '@validator'

    # NOTE(review): presumably the names of the fields the validator applies to - confirm against the decorator.
    fields: tuple[str, ...]
    mode: Literal['before', 'after']
    each_item: bool
    always: bool
    # tri-state: True, False, or None (meaning of None is decided by the consumer of this info)
    check_fields: bool | None

43 

44 

@slots_dataclass
class FieldValidatorDecoratorInfo:
    """
    Record of the arguments given to a `@field_validator` decorator.

    The values are stored at decoration time so they are still available
    later, when the pydantic-core schema is being built.
    """

    # Declared as a ClassVar, so it is shared by all instances and is not a dataclass field.
    decorator_repr: ClassVar[str] = '@field_validator'

    # NOTE(review): presumably the names of the fields the validator applies to - confirm against the decorator.
    fields: tuple[str, ...]
    mode: FieldValidatorModes
    # tri-state: True, False, or None (meaning of None is decided by the consumer of this info)
    check_fields: bool | None

57 

58 

@slots_dataclass
class RootValidatorDecoratorInfo:
    """
    Record of the arguments given to a `@root_validator` decorator.

    The values are stored at decoration time so they are still available
    later, when the pydantic-core schema is being built.
    """

    # Declared as a ClassVar, so it is shared by all instances and is not a dataclass field.
    decorator_repr: ClassVar[str] = '@root_validator'
    mode: Literal['before', 'after']

68 

69 

@slots_dataclass
class FieldSerializerDecoratorInfo:
    """
    Record of the arguments given to a `@field_serializer` decorator.

    The values are stored at decoration time so they are still available
    later, when the pydantic-core schema is being built.
    """

    # Declared as a ClassVar, so it is shared by all instances and is not a dataclass field.
    decorator_repr: ClassVar[str] = '@field_serializer'
    # Field names covered by this serializer; `DecoratorInfos.build` compares these
    # between serializers to reject two serializers for the same field.
    fields: tuple[str, ...]
    mode: Literal['plain', 'wrap']
    json_return_type: core_schema.JsonReturnTypes | None
    when_used: core_schema.WhenUsed
    # tri-state: True, False, or None (meaning of None is decided by the consumer of this info)
    check_fields: bool | None

83 

84 

@slots_dataclass
class ModelSerializerDecoratorInfo:
    """
    Record of the arguments given to a `@model_serializer` decorator.

    The values are stored at decoration time so they are still available
    later, when the pydantic-core schema is being built.
    """

    # Declared as a ClassVar, so it is shared by all instances and is not a dataclass field.
    decorator_repr: ClassVar[str] = '@model_serializer'
    mode: Literal['plain', 'wrap']
    json_return_type: core_schema.JsonReturnTypes | None

95 

96 

@slots_dataclass
class ModelValidatorDecoratorInfo:
    """
    Record of the arguments given to a `@model_validator` decorator.

    The values are stored at decoration time so they are still available
    later, when the pydantic-core schema is being built.
    """

    # Declared as a ClassVar, so it is shared by all instances and is not a dataclass field.
    decorator_repr: ClassVar[str] = '@model_validator'
    mode: Literal['wrap', 'before', 'after']

106 

107 

# Every kind of decorator metadata a `PydanticDescriptorProxy` can carry.
DecoratorInfo = Union[
    ValidatorDecoratorInfo,
    FieldValidatorDecoratorInfo,
    RootValidatorDecoratorInfo,
    FieldSerializerDecoratorInfo,
    ModelSerializerDecoratorInfo,
    ModelValidatorDecoratorInfo,
    ComputedFieldInfo,
]

# Return type of the decorated callable.
ReturnType = TypeVar('ReturnType')

# The objects a decorator may be applied to. Kept as a string so the subscripted
# `classmethod[...]` / `staticmethod[...]` forms are never evaluated at runtime.
DecoratedType: TypeAlias = (
    'Union[classmethod[Any, Any, ReturnType], staticmethod[Any, ReturnType], Callable[..., ReturnType], property]'
)

122 

123 

@dataclass  # can't use slots here since we set attributes on `__post_init__`
class PydanticDescriptorProxy(Generic[ReturnType]):
    """
    Wrap a classmethod, staticmethod, property or unbound function
    and act as a descriptor that allows us to detect decorated items
    from the class' attributes.

    This class' __get__ returns the wrapped item's __get__ result,
    which makes it transparent for classmethods and staticmethods.

    Attributes:
        wrapped: The decorated object being proxied.
        decorator_info: The metadata recorded by the decorator that created this proxy.
        shim: Optional callable applied to the bound function when a `Decorator` is built from this proxy.
    """

    wrapped: DecoratedType[ReturnType]
    decorator_info: DecoratorInfo
    shim: Callable[[Callable[..., Any]], Callable[..., Any]] | None = None

    def __post_init__(self):
        # If the wrapped object offers `setter` / `deleter` (as `property` does), expose
        # same-named instance attributes so `@proxy.setter` keeps working on the proxy.
        for attr in 'setter', 'deleter':
            if hasattr(self.wrapped, attr):
                f = partial(self._call_wrapped_attr, name=attr)
                setattr(self, attr, f)

    def _call_wrapped_attr(self, func: Callable[[Any], None], *, name: str) -> PydanticDescriptorProxy[ReturnType]:
        """Call `self.wrapped.<name>(func)`, store the result as the new wrapped object and return the proxy."""
        self.wrapped = getattr(self.wrapped, name)(func)
        return self

    def __get__(self, obj: object | None, obj_type: type[object] | None = None) -> PydanticDescriptorProxy[ReturnType]:
        """
        Delegate attribute access to the wrapped object's own `__get__`.

        Only the *lookup* of `__get__` is guarded: if the wrapped object is not a
        descriptor it is returned unchanged. An `AttributeError` raised while the
        wrapped descriptor runs (for example inside a property getter) propagates
        to the caller instead of being mistaken for "not a descriptor".
        """
        descriptor_get = getattr(self.wrapped, '__get__', None)
        if descriptor_get is None:
            # not a descriptor, e.g. a partial object
            return self.wrapped  # type: ignore[return-value]
        return descriptor_get(obj, obj_type)

    def __set_name__(self, instance: Any, name: str) -> None:
        # forward the notification only when the wrapped object implements the hook itself
        if hasattr(self.wrapped, '__set_name__'):
            self.wrapped.__set_name__(instance, name)

    def __getattr__(self, __name: str) -> Any:
        """Forward checks for __isabstractmethod__ and such"""
        return getattr(self.wrapped, __name)

163 

164 

# Type variable constrained (via `bound`) to the members of the `DecoratorInfo` union.
DecoratorInfoType = TypeVar('DecoratorInfoType', bound=DecoratorInfo)

166 

167 

@slots_dataclass
class Decorator(Generic[DecoratorInfoType]):
    """
    Pairs decorator metadata with the function it decorates.

    The metadata is known when the decorator is called, but not while the
    core-schema is being built; the bound function only exists once the class
    has been created. This container holds both together.
    """

    cls_ref: str
    cls_var_name: str
    func: Callable[..., Any]
    shim: Callable[[Any], Any] | None
    info: DecoratorInfoType

    @staticmethod
    def build(
        cls_: Any,
        *,
        cls_var_name: str,
        shim: Callable[[Any], Any] | None,
        info: DecoratorInfoType,
    ) -> Decorator[DecoratorInfoType]:
        """
        Create a `Decorator` for the attribute `cls_var_name` of `cls_`.

        Args:
            cls_: The class the attribute is read from.
            cls_var_name: Name of the attribute in the class namespace.
            shim: Optional wrapper applied to the looked-up attribute before it is stored.
            info: The decorator metadata to store alongside the function.

        Returns:
            A new `Decorator` whose `cls_ref` comes from `get_type_ref(cls_)`.
        """
        bound = getattr(cls_, cls_var_name)
        func = bound if shim is None else shim(bound)
        return Decorator(
            cls_ref=get_type_ref(cls_),
            cls_var_name=cls_var_name,
            func=func,
            shim=shim,
            info=info,
        )

    def bind_to_cls(self, cls: Any) -> Decorator[DecoratorInfoType]:
        """Rebuild this decorator against `cls`, reusing the stored name, shim and info."""
        return self.build(
            cls,
            cls_var_name=self.cls_var_name,
            shim=self.shim,
            info=self.info,
        )

209 

210 

@slots_dataclass
class DecoratorInfos:
    """
    The decorators collected for one class, grouped by decorator kind.

    Each attribute maps a name in the class namespace to the `Decorator` built for it.
    """

    # mapping of name in the class namespace to decorator info
    # note that the name in the class namespace is the function or attribute name
    # not the field name!
    # TODO these all need to be renamed to plural
    validator: dict[str, Decorator[ValidatorDecoratorInfo]] = field(default_factory=dict)
    field_validator: dict[str, Decorator[FieldValidatorDecoratorInfo]] = field(default_factory=dict)
    root_validator: dict[str, Decorator[RootValidatorDecoratorInfo]] = field(default_factory=dict)
    field_serializer: dict[str, Decorator[FieldSerializerDecoratorInfo]] = field(default_factory=dict)
    model_serializer: dict[str, Decorator[ModelSerializerDecoratorInfo]] = field(default_factory=dict)
    model_validator: dict[str, Decorator[ModelValidatorDecoratorInfo]] = field(default_factory=dict)
    computed_fields: dict[str, Decorator[ComputedFieldInfo]] = field(default_factory=dict)

    @staticmethod
    def build(model_dc: type[Any]) -> DecoratorInfos:  # noqa: C901 (ignore complexity)
        """
        We want to collect all `PydanticDescriptorProxy` instances that exist as
        attributes in the namespace of the class (a BaseModel or dataclass)
        that called us
        But we want to collect these in the order of the bases
        So instead of getting them all from the leaf class (the class that called us),
        we traverse the bases from root (the oldest ancestor class) to leaf
        and collect all of the instances as we go, taking care to replace
        any duplicate ones with the last one we see to mimic how function overriding
        works with inheritance.
        If we do replace any functions we put the replacement into the position
        the replaced function was in; that is, we maintain the order.

        As a side effect, every proxy found directly on `model_dc` is replaced on the
        class by the object it wraps.

        Args:
            model_dc: The class whose decorators are collected.

        Returns:
            The decorators inherited from the bases plus those defined on `model_dc` itself.

        Raises:
            PydanticUserError: If a field serializer defined on `model_dc` targets a field that
                a differently-named, already collected field serializer also targets.
        """

        # reminder: dicts are ordered and replacement does not alter the order
        res = DecoratorInfos()
        for base in model_dc.__bases__[::-1]:
            existing = cast(Union[DecoratorInfos, None], getattr(base, '__pydantic_decorators__', None))
            if existing is not None:
                # inherited decorators are re-bound so that `func` is looked up on `model_dc`
                res.validator.update({k: v.bind_to_cls(model_dc) for k, v in existing.validator.items()})
                res.field_validator.update({k: v.bind_to_cls(model_dc) for k, v in existing.field_validator.items()})
                res.root_validator.update({k: v.bind_to_cls(model_dc) for k, v in existing.root_validator.items()})
                res.field_serializer.update({k: v.bind_to_cls(model_dc) for k, v in existing.field_serializer.items()})
                res.model_serializer.update({k: v.bind_to_cls(model_dc) for k, v in existing.model_serializer.items()})
                res.model_validator.update({k: v.bind_to_cls(model_dc) for k, v in existing.model_validator.items()})
                res.computed_fields.update({k: v.bind_to_cls(model_dc) for k, v in existing.computed_fields.items()})

        for var_name, var_value in vars(model_dc).items():
            if isinstance(var_value, PydanticDescriptorProxy):
                info = var_value.decorator_info
                if isinstance(info, ValidatorDecoratorInfo):
                    res.validator[var_name] = Decorator.build(
                        model_dc, cls_var_name=var_name, shim=var_value.shim, info=info
                    )
                elif isinstance(info, FieldValidatorDecoratorInfo):
                    res.field_validator[var_name] = Decorator.build(
                        model_dc, cls_var_name=var_name, shim=var_value.shim, info=info
                    )
                elif isinstance(info, RootValidatorDecoratorInfo):
                    res.root_validator[var_name] = Decorator.build(
                        model_dc, cls_var_name=var_name, shim=var_value.shim, info=info
                    )
                elif isinstance(info, FieldSerializerDecoratorInfo):
                    # check whether a serializer function is already registered for fields
                    for field_serializer_decorator in res.field_serializer.values():
                        # check that each field has at most one serializer function.
                        # serializer functions for the same field in subclasses are allowed,
                        # and are treated as overrides
                        if field_serializer_decorator.cls_var_name == var_name:
                            continue
                        for f in info.fields:
                            if f in field_serializer_decorator.info.fields:
                                raise PydanticUserError(
                                    'Multiple field serializer functions were defined '
                                    f'for field {f!r}, this is not allowed.',
                                    code='multiple-field-serializers',
                                )
                    res.field_serializer[var_name] = Decorator.build(
                        model_dc, cls_var_name=var_name, shim=var_value.shim, info=info
                    )
                elif isinstance(info, ModelValidatorDecoratorInfo):
                    res.model_validator[var_name] = Decorator.build(
                        model_dc, cls_var_name=var_name, shim=var_value.shim, info=info
                    )
                elif isinstance(info, ModelSerializerDecoratorInfo):
                    res.model_serializer[var_name] = Decorator.build(
                        model_dc, cls_var_name=var_name, shim=var_value.shim, info=info
                    )
                else:
                    # the only member of `DecoratorInfo` not handled above is `ComputedFieldInfo`;
                    # computed fields are always built without a shim
                    res.computed_fields[var_name] = Decorator.build(
                        model_dc, cls_var_name=var_name, shim=None, info=info
                    )
                # swap the proxy for the object it wraps; an existing key is overwritten,
                # so the namespace being iterated does not change size
                setattr(model_dc, var_name, var_value.wrapped)
        return res

302 

303 

def inspect_validator(validator: Callable[..., Any], mode: FieldValidatorModes) -> bool:
    """
    Look at a field or model validator function and determine whether it takes an info argument.

    An error is raised if the function has an invalid signature.

    Args:
        validator: The validator function to inspect.
        mode: The proposed validator mode.

    Returns:
        Whether the validator takes an info argument.

    Raises:
        PydanticUserError: If the number of positional parameters does not match
            either arity accepted for `mode`.
    """
    sig = signature(validator)
    n_positional = count_positional_params(sig)
    if mode == 'wrap':
        # a wrap validator has one more positional parameter than the other modes
        if n_positional == 3:
            return True
        elif n_positional == 2:
            return False
    else:
        assert mode in {'before', 'after', 'plain'}, f"invalid mode: {mode!r}, expected 'before', 'after' or 'plain'"
        if n_positional == 2:
            return True
        elif n_positional == 1:
            return False

    # any other arity falls through to here
    raise PydanticUserError(
        f'Unrecognized field_validator function signature for {validator} with `mode={mode}`:{sig}',
        code='field-validator-signature',
    )

335 

336 

def inspect_field_serializer(serializer: Callable[..., Any], mode: Literal['plain', 'wrap']) -> tuple[bool, bool]:
    """
    Inspect a field serializer function.

    Works out two things from the signature: whether the function is written as
    a method (its first parameter is named `self`), and whether it accepts an
    info argument.

    Args:
        serializer: The serializer function to inspect.
        mode: The serializer mode, either 'plain' or 'wrap'.

    Returns:
        Tuple of (is_field_serializer, info_arg)

    Raises:
        PydanticUserError: If the positional parameter count fits neither arity accepted for `mode`.
    """
    sig = signature(serializer)
    params = list(sig.parameters.values())
    is_field_serializer = bool(params) and params[0].name == 'self'

    n_positional = count_positional_params(sig)
    # `self` is not one of the serializer's own arguments, so leave it out of the count
    effective_arity = n_positional - 1 if is_field_serializer else n_positional
    info_arg = _serializer_info_arg(mode, effective_arity)

    if info_arg is not None:
        return is_field_serializer, info_arg

    raise PydanticUserError(
        f'Unrecognized field_serializer function signature for {serializer} with `mode={mode}`:{sig}',
        code='field-serializer-signature',
    )

370 

371 

def inspect_annotated_serializer(serializer: Callable[..., Any], mode: Literal['plain', 'wrap']) -> bool:
    """
    Inspect a serializer function supplied through `Annotated` and report whether it accepts an info argument.

    Args:
        serializer: The serializer function to check.
        mode: The serializer mode, either 'plain' or 'wrap'.

    Returns:
        info_arg

    Raises:
        PydanticUserError: If the positional parameter count fits neither arity accepted for `mode`.
    """
    sig = signature(serializer)
    n_positional = count_positional_params(sig)
    info_arg = _serializer_info_arg(mode, n_positional)
    if info_arg is not None:
        return info_arg

    raise PydanticUserError(
        f'Unrecognized field_serializer function signature for {serializer} with `mode={mode}`:{sig}',
        code='field-serializer-signature',
    )

394 

395 

def inspect_model_serializer(serializer: Callable[..., Any], mode: Literal['plain', 'wrap']) -> bool:
    """
    Inspect a model serializer function and report whether it accepts an info argument.

    Args:
        serializer: The serializer function to check.
        mode: The serializer mode, either 'plain' or 'wrap'.

    Returns:
        `info_arg` - whether the function expects an info argument

    Raises:
        PydanticUserError: If `serializer` is a staticmethod/classmethod or its first parameter
            is not named `self`, or if its positional parameter count fits neither arity
            accepted for `mode`.
    """
    # explicit staticmethod/classmethod wrappers are rejected before the signature is looked at
    wrapped_as_non_instance = isinstance(serializer, (staticmethod, classmethod))
    if wrapped_as_non_instance or not is_instance_method_from_sig(serializer):
        raise PydanticUserError(
            '`@model_serializer` must be applied to instance methods', code='model-serializer-instance-method'
        )

    sig = signature(serializer)
    n_positional = count_positional_params(sig)
    info_arg = _serializer_info_arg(mode, n_positional)
    if info_arg is not None:
        return info_arg

    raise PydanticUserError(
        f'Unrecognized model_serializer function signature for {serializer} with `mode={mode}`:{sig}',
        code='model-serializer-signature',
    )

424 

425 

426def _serializer_info_arg(mode: Literal['plain', 'wrap'], n_positional: int) -> bool | None: 

427 if mode == 'plain': 

428 if n_positional == 1: 

429 # (__input_value: Any) -> Any 

430 return False 

431 elif n_positional == 2: 

432 # (__model: Any, __input_value: Any) -> Any 

433 return True 

434 else: 

435 assert mode == 'wrap', f"invalid mode: {mode!r}, expected 'plain' or 'wrap'" 

436 if n_positional == 2: 

437 # (__input_value: Any, __serializer: SerializerFunctionWrapHandler) -> Any 

438 return False 

439 elif n_positional == 3: 

440 # (__input_value: Any, __serializer: SerializerFunctionWrapHandler, __info: SerializationInfo) -> Any 

441 return True 

442 

443 return None 

444 

445 

# Everything a decorator in this module may receive. Kept as a string so the
# subscripted `classmethod[...]` / `staticmethod[...]` forms are never evaluated at runtime.
AnyDecoratorCallable: TypeAlias = (
    'Union[classmethod[Any, Any, Any], staticmethod[Any, Any], partialmethod[Any], Callable[..., Any]]'
)

449 

450 

def is_instance_method_from_sig(function: AnyDecoratorCallable) -> bool:
    """
    Report whether the unwrapped function's first parameter is named `self`.

    Args:
        function: The (possibly wrapped) function to check.

    Returns:
        `True` if the first parameter is named `self`, `False` otherwise (including when there are no parameters).
    """
    parameters = signature(unwrap_wrapped_function(function)).parameters
    leading = next(iter(parameters.values()), None)
    return leading is not None and leading.name == 'self'

457 

458 

def ensure_classmethod_based_on_signature(function: AnyDecoratorCallable) -> Any:
    """
    Wrap `function` in `classmethod` when its signature says it should be one.

    Args:
        function: The (possibly wrapped) function to check.

    Returns:
        `classmethod(function)` if the function is not already a classmethod (after removing
        partial wrappers only) and its first parameter is named `cls`; otherwise `function` unchanged.
    """
    without_partials = unwrap_wrapped_function(function, unwrap_class_static_method=False)
    if isinstance(without_partials, classmethod):
        return function
    if _is_classmethod_from_sig(function):
        return classmethod(function)  # type: ignore[arg-type]
    return function

465 

466 

def _is_classmethod_from_sig(function: AnyDecoratorCallable) -> bool:
    """
    Report whether the unwrapped function's first parameter is named `cls`.

    Returns:
        `True` if the first parameter is named `cls`, `False` otherwise (including when there are no parameters).
    """
    parameters = signature(unwrap_wrapped_function(function)).parameters
    leading = next(iter(parameters.values()), None)
    return leading is not None and leading.name == 'cls'

473 

474 

def unwrap_wrapped_function(
    func: Any,
    *,
    unwrap_class_static_method: bool = True,
) -> Any:
    """
    Repeatedly unwraps a wrapped function until the underlying function is reached.
    This handles functools.partial, functools.partialmethod, staticmethod and classmethod.

    Args:
        func: The function to unwrap.
        unwrap_class_static_method: If True (default), also unwrap classmethod and staticmethod
            decorators. If False, only unwrap partial and partialmethod decorators.

    Returns:
        The underlying function of the wrapped function.
    """
    # named `unwrap_types` rather than `all` so the builtin is not shadowed
    unwrap_types: tuple[Any, ...]
    if unwrap_class_static_method:
        unwrap_types = (
            staticmethod,
            classmethod,
            partial,
            partialmethod,
        )
    else:
        unwrap_types = (partial, partialmethod)

    # peel off one wrapper layer per iteration; wrappers may be nested in any order
    while isinstance(func, unwrap_types):
        if unwrap_class_static_method and isinstance(func, (classmethod, staticmethod)):
            func = func.__func__
        elif isinstance(func, (partial, partialmethod)):
            func = func.func

    return func

510 

511 

def count_positional_params(sig: Signature) -> int:
    """Return how many parameters of `sig` are positional-only or positional-or-keyword."""
    positional_kinds = (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD)
    return len([param for param in sig.parameters.values() if param.kind in positional_kinds])

514 

515 

def can_be_positional(param: Parameter) -> bool:
    """Return whether `param` is positional-only or positional-or-keyword."""
    kind = param.kind
    return kind == Parameter.POSITIONAL_ONLY or kind == Parameter.POSITIONAL_OR_KEYWORD

518 

519 

def ensure_property(f: Any) -> Any:
    """
    Return `f` as something usable as a descriptor.

    Args:
        f: The function (or descriptor) to check.

    Returns:
        `f` itself when `inspect` reports it as a method descriptor or a data descriptor,
        otherwise `property(f)`.
    """
    already_descriptor = ismethoddescriptor(f) or isdatadescriptor(f)
    return f if already_descriptor else property(f)