Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/pydantic/v1/dataclasses.py: 17%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

218 statements  

1""" 

2The main purpose is to enhance stdlib dataclasses by adding validation 

3A pydantic dataclass can be generated from scratch or from a stdlib one. 

4 

5Behind the scenes, a pydantic dataclass is just like a regular one on which we attach 

6a `BaseModel` and magic methods to trigger the validation of the data. 

7`__init__` and `__post_init__` are hence overridden and have extra logic to be 

8able to validate input data. 

9 

10When a pydantic dataclass is generated from scratch, it's just a plain dataclass 

11with validation triggered at initialization 

12 

13The tricky part is for stdlib dataclasses that are converted afterwards into pydantic ones e.g. 

14 

15```py 

16@dataclasses.dataclass 

17class M: 

18 x: int 

19 

20ValidatedM = pydantic.dataclasses.dataclass(M) 

21``` 

22 

23We indeed still want to support equality, hashing, repr, ... as if it was the stdlib one! 

24 

25```py 

26assert isinstance(ValidatedM(x=1), M) 

27assert ValidatedM(x=1) == M(x=1) 

28``` 

29 

30This means we **don't want to create a new dataclass that inherits from it** 

31The trick is to create a wrapper around `M` that will act as a proxy to trigger 

32validation without altering default `M` behaviour. 

33""" 

34import copy 

35import dataclasses 

36import sys 

37from contextlib import contextmanager 

38from functools import wraps 

39 

40try: 

41 from functools import cached_property 

42except ImportError: 

43 # cached_property available only for python3.8+ 

44 pass 

45 

46from typing import TYPE_CHECKING, Any, Callable, ClassVar, Dict, Generator, Optional, Type, TypeVar, Union, overload 

47 

48from typing_extensions import dataclass_transform 

49 

50from pydantic.v1.class_validators import gather_all_validators 

51from pydantic.v1.config import BaseConfig, ConfigDict, Extra, get_config 

52from pydantic.v1.error_wrappers import ValidationError 

53from pydantic.v1.errors import DataclassTypeError 

54from pydantic.v1.fields import Field, FieldInfo, Required, Undefined 

55from pydantic.v1.main import create_model, validate_model 

56from pydantic.v1.utils import ClassAttribute 

57 

if TYPE_CHECKING:
    from pydantic.v1.main import BaseModel
    from pydantic.v1.typing import CallableGenerator, NoArgAnyCallable

    # Type variable bound to the structural stub below, used in validator signatures
    DataclassT = TypeVar('DataclassT', bound='Dataclass')

    # A pydantic dataclass is either the (mutated) class itself or a proxy around a stdlib one
    DataclassClassOrWrapper = Union[Type['Dataclass'], 'DataclassProxy']

    class Dataclass:
        # Type-checking-only structural stub describing the attributes and methods
        # a pydantic dataclass gains at runtime (see `_add_pydantic_validation_attributes`).

        # stdlib attributes
        __dataclass_fields__: ClassVar[Dict[str, Any]]
        __dataclass_params__: ClassVar[Any]  # in reality `dataclasses._DataclassParams`
        __post_init__: ClassVar[Callable[..., None]]

        # Added by pydantic
        __pydantic_run_validation__: ClassVar[bool]
        __post_init_post_parse__: ClassVar[Callable[..., None]]
        __pydantic_initialised__: ClassVar[bool]
        __pydantic_model__: ClassVar[Type[BaseModel]]
        __pydantic_validate_values__: ClassVar[Callable[['Dataclass'], None]]
        __pydantic_has_field_info_default__: ClassVar[bool]  # whether a `pydantic.Field` is used as default value

        def __init__(self, *args: object, **kwargs: object) -> None:
            pass

        @classmethod
        def __get_validators__(cls: Type['Dataclass']) -> 'CallableGenerator':
            pass

        @classmethod
        def __validate__(cls: Type['DataclassT'], v: Any) -> 'DataclassT':
            pass

90 

91 

# Public API of this module
__all__ = [
    'dataclass',
    'set_validation',
    'create_pydantic_model_from_dataclass',
    'is_builtin_dataclass',
    'make_dataclass_validator',
]

# Type variable for the class being decorated by `dataclass`
_T = TypeVar('_T')

101 

# Typing-only overloads for the `dataclass` decorator. Two variants are needed because
# stdlib dataclasses only accept `kw_only` from python 3.10 onwards.
if sys.version_info >= (3, 10):

    @dataclass_transform(field_specifiers=(dataclasses.field, Field))
    @overload
    def dataclass(
        *,
        init: bool = True,
        repr: bool = True,
        eq: bool = True,
        order: bool = False,
        unsafe_hash: bool = False,
        frozen: bool = False,
        config: Union[ConfigDict, Type[object], None] = None,
        validate_on_init: Optional[bool] = None,
        use_proxy: Optional[bool] = None,
        kw_only: bool = ...,
    ) -> Callable[[Type[_T]], 'DataclassClassOrWrapper']:
        # called with parens: `@dataclass(...)` -> returns a decorator
        ...

    @dataclass_transform(field_specifiers=(dataclasses.field, Field))
    @overload
    def dataclass(
        _cls: Type[_T],
        *,
        init: bool = True,
        repr: bool = True,
        eq: bool = True,
        order: bool = False,
        unsafe_hash: bool = False,
        frozen: bool = False,
        config: Union[ConfigDict, Type[object], None] = None,
        validate_on_init: Optional[bool] = None,
        use_proxy: Optional[bool] = None,
        kw_only: bool = ...,
    ) -> 'DataclassClassOrWrapper':
        # called bare: `@dataclass` -> returns the processed class (or proxy)
        ...

else:

    @dataclass_transform(field_specifiers=(dataclasses.field, Field))
    @overload
    def dataclass(
        *,
        init: bool = True,
        repr: bool = True,
        eq: bool = True,
        order: bool = False,
        unsafe_hash: bool = False,
        frozen: bool = False,
        config: Union[ConfigDict, Type[object], None] = None,
        validate_on_init: Optional[bool] = None,
        use_proxy: Optional[bool] = None,
    ) -> Callable[[Type[_T]], 'DataclassClassOrWrapper']:
        # called with parens: `@dataclass(...)` -> returns a decorator
        ...

    @dataclass_transform(field_specifiers=(dataclasses.field, Field))
    @overload
    def dataclass(
        _cls: Type[_T],
        *,
        init: bool = True,
        repr: bool = True,
        eq: bool = True,
        order: bool = False,
        unsafe_hash: bool = False,
        frozen: bool = False,
        config: Union[ConfigDict, Type[object], None] = None,
        validate_on_init: Optional[bool] = None,
        use_proxy: Optional[bool] = None,
    ) -> 'DataclassClassOrWrapper':
        # called bare: `@dataclass` -> returns the processed class (or proxy)
        ...

173 

174 

@dataclass_transform(field_specifiers=(dataclasses.field, Field))
def dataclass(
    _cls: Optional[Type[_T]] = None,
    *,
    init: bool = True,
    repr: bool = True,
    eq: bool = True,
    order: bool = False,
    unsafe_hash: bool = False,
    frozen: bool = False,
    config: Union[ConfigDict, Type[object], None] = None,
    validate_on_init: Optional[bool] = None,
    use_proxy: Optional[bool] = None,
    kw_only: bool = False,
) -> Union[Callable[[Type[_T]], 'DataclassClassOrWrapper'], 'DataclassClassOrWrapper']:
    """
    Like the python standard lib dataclasses but with type validation.
    The result is either a pydantic dataclass that will validate input data
    or a wrapper that will trigger validation around a stdlib dataclass
    to avoid modifying it directly

    `init`/`repr`/`eq`/`order`/`unsafe_hash`/`frozen`/`kw_only` are forwarded to
    `dataclasses.dataclass` (when no proxy is used); `config` is any pydantic config
    source accepted by `get_config`; `validate_on_init` and `use_proxy` override the
    defaults inferred in `wrap` below.
    """
    the_config = get_config(config)

    def wrap(cls: Type[Any]) -> 'DataclassClassOrWrapper':
        # Use a proxy (leaving `cls` itself callable without validation) when explicitly
        # requested, or when `cls` is a plain stdlib dataclass that adds no attributes
        # of its own on top of its first base class.
        should_use_proxy = (
            use_proxy
            if use_proxy is not None
            else (
                is_builtin_dataclass(cls)
                and (cls.__bases__[0] is object or set(dir(cls)) == set(dir(cls.__bases__[0])))
            )
        )
        if should_use_proxy:
            dc_cls_doc = ''
            dc_cls = DataclassProxy(cls)
            # proxied stdlib dataclasses keep stdlib behaviour by default: no validation on init
            default_validate_on_init = False
        else:
            dc_cls_doc = cls.__doc__ or ''  # needs to be done before generating dataclass
            if sys.version_info >= (3, 10):
                # `kw_only` only exists on stdlib `dataclasses.dataclass` from 3.10
                dc_cls = dataclasses.dataclass(
                    cls,
                    init=init,
                    repr=repr,
                    eq=eq,
                    order=order,
                    unsafe_hash=unsafe_hash,
                    frozen=frozen,
                    kw_only=kw_only,
                )
            else:
                dc_cls = dataclasses.dataclass(  # type: ignore
                    cls, init=init, repr=repr, eq=eq, order=order, unsafe_hash=unsafe_hash, frozen=frozen
                )
            default_validate_on_init = True

        should_validate_on_init = default_validate_on_init if validate_on_init is None else validate_on_init
        # note: attributes are always attached to `cls` itself, even when a proxy is returned
        _add_pydantic_validation_attributes(cls, the_config, should_validate_on_init, dc_cls_doc)
        dc_cls.__pydantic_model__.__try_update_forward_refs__(**{cls.__name__: cls})
        return dc_cls

    if _cls is None:
        # called with parens: `@dataclass(...)` -> return the decorator
        return wrap

    return wrap(_cls)

239 

240 

@contextmanager
def set_validation(cls: Type['DataclassT'], value: bool) -> Generator[Type['DataclassT'], None, None]:
    """Temporarily force `cls.__pydantic_run_validation__` to `value`.

    Yields the class itself; the previous flag is always restored on exit,
    even if the body raises.
    """
    previous_flag = cls.__pydantic_run_validation__
    try:
        cls.__pydantic_run_validation__ = value
        yield cls
    finally:
        cls.__pydantic_run_validation__ = previous_flag

249 

250 

class DataclassProxy:
    """
    Wrapper around a stdlib dataclass that enables pydantic validation only when
    instances are created *through the proxy*, while delegating attribute access,
    `isinstance` checks and copying to the wrapped class so that the original
    stdlib behaviour is left untouched.
    """

    __slots__ = '__dataclass__'

    def __init__(self, dc_cls: Type['Dataclass']) -> None:
        # use `object.__setattr__` because our own `__setattr__` delegates to the wrapped class
        object.__setattr__(self, '__dataclass__', dc_cls)

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        # instantiating through the proxy temporarily turns validation on
        with set_validation(self.__dataclass__, True):
            return self.__dataclass__(*args, **kwargs)

    def __getattr__(self, name: str) -> Any:
        # any attribute not found on the proxy is looked up on the wrapped dataclass
        return getattr(self.__dataclass__, name)

    def __setattr__(self, __name: str, __value: Any) -> None:
        # attribute writes go straight to the wrapped dataclass
        return setattr(self.__dataclass__, __name, __value)

    def __instancecheck__(self, instance: Any) -> bool:
        # `isinstance(x, proxy)` behaves like `isinstance(x, wrapped_class)`
        return isinstance(instance, self.__dataclass__)

    def __copy__(self) -> 'DataclassProxy':
        return DataclassProxy(copy.copy(self.__dataclass__))

    def __deepcopy__(self, memo: Any) -> 'DataclassProxy':
        return DataclassProxy(copy.deepcopy(self.__dataclass__, memo))

275 

276 

def _add_pydantic_validation_attributes(  # noqa: C901 (ignore complexity)
    dc_cls: Type['Dataclass'],
    config: Type[BaseConfig],
    validate_on_init: bool,
    dc_cls_doc: str,
) -> None:
    """
    We need to replace the right method. If no `__post_init__` has been set in the stdlib dataclass
    it won't even exist (code is generated on the fly by `dataclasses`)
    By default, we run validation after `__init__` or `__post_init__` if defined

    Mutates `dc_cls` in place: replaces `__init__` (and `__post_init__` when present)
    and attaches the `__pydantic_*` attributes plus `__validate__`/`__get_validators__`.
    """
    init = dc_cls.__init__

    # Wrapper around the original `__init__` implementing the `Config.extra` policy
    # for keyword arguments that are not dataclass fields.
    @wraps(init)
    def handle_extra_init(self: 'Dataclass', *args: Any, **kwargs: Any) -> None:
        if config.extra == Extra.ignore:
            # silently drop unknown keyword arguments
            init(self, *args, **{k: v for k, v in kwargs.items() if k in self.__dataclass_fields__})

        elif config.extra == Extra.allow:
            # keep unknown keyword arguments on the instance, then init with the known ones
            for k, v in kwargs.items():
                self.__dict__.setdefault(k, v)
            init(self, *args, **{k: v for k, v in kwargs.items() if k in self.__dataclass_fields__})

        else:
            # Extra.forbid (or anything else): pass everything through and let it fail naturally
            init(self, *args, **kwargs)

    if hasattr(dc_cls, '__post_init__'):
        # unwrap a previously-wrapped `__post_init__` so re-processing a class doesn't nest wrappers
        try:
            post_init = dc_cls.__post_init__.__wrapped__  # type: ignore[attr-defined]
        except AttributeError:
            post_init = dc_cls.__post_init__

        @wraps(post_init)
        def new_post_init(self: 'Dataclass', *args: Any, **kwargs: Any) -> None:
            # `Config.post_init_call` decides whether the user's `__post_init__` runs
            # before or after pydantic validation
            if config.post_init_call == 'before_validation':
                post_init(self, *args, **kwargs)

            if self.__class__.__pydantic_run_validation__:
                self.__pydantic_validate_values__()
                if hasattr(self, '__post_init_post_parse__'):
                    self.__post_init_post_parse__(*args, **kwargs)

            if config.post_init_call == 'after_validation':
                post_init(self, *args, **kwargs)

        setattr(dc_cls, '__init__', handle_extra_init)
        setattr(dc_cls, '__post_init__', new_post_init)

    else:
        # No `__post_init__`: validation has to be triggered at the end of `__init__` itself

        @wraps(init)
        def new_init(self: 'Dataclass', *args: Any, **kwargs: Any) -> None:
            handle_extra_init(self, *args, **kwargs)

            if self.__class__.__pydantic_run_validation__:
                self.__pydantic_validate_values__()

            if hasattr(self, '__post_init_post_parse__'):
                # We need to find again the initvars. To do that we use `__dataclass_fields__` instead of
                # public method `dataclasses.fields`

                # get all initvars and their default values
                initvars_and_values: Dict[str, Any] = {}
                for i, f in enumerate(self.__class__.__dataclass_fields__.values()):
                    if f._field_type is dataclasses._FIELD_INITVAR:  # type: ignore[attr-defined]
                        try:
                            # set arg value by default
                            initvars_and_values[f.name] = args[i]
                        except IndexError:
                            initvars_and_values[f.name] = kwargs.get(f.name, f.default)

                self.__post_init_post_parse__(**initvars_and_values)

        setattr(dc_cls, '__init__', new_init)

    # `ClassAttribute` prevents the flag from being inherited by subclasses unintentionally
    setattr(dc_cls, '__pydantic_run_validation__', ClassAttribute('__pydantic_run_validation__', validate_on_init))
    setattr(dc_cls, '__pydantic_initialised__', False)
    setattr(dc_cls, '__pydantic_model__', create_pydantic_model_from_dataclass(dc_cls, config, dc_cls_doc))
    setattr(dc_cls, '__pydantic_validate_values__', _dataclass_validate_values)
    setattr(dc_cls, '__validate__', classmethod(_validate_dataclass))
    setattr(dc_cls, '__get_validators__', classmethod(_get_validators))

    # assignment validation only makes sense on mutable (non-frozen) dataclasses
    if dc_cls.__pydantic_model__.__config__.validate_assignment and not dc_cls.__dataclass_params__.frozen:
        setattr(dc_cls, '__setattr__', _dataclass_validate_assignment_setattr)

361 

362 

363def _get_validators(cls: 'DataclassClassOrWrapper') -> 'CallableGenerator': 

364 yield cls.__validate__ 

365 

366 

def _validate_dataclass(cls: Type['DataclassT'], v: Any) -> 'DataclassT':
    """Coerce `v` into a validated instance of dataclass `cls`.

    Accepts an existing instance (re-validated in place), a list/tuple
    (positional args) or a dict (keyword args); anything else raises
    `DataclassTypeError`. Validation is forced on for the duration.
    """
    with set_validation(cls, True):
        if isinstance(v, cls):
            # already an instance: just (re-)run the value validation on it
            v.__pydantic_validate_values__()
            return v
        if isinstance(v, (list, tuple)):
            return cls(*v)
        if isinstance(v, dict):
            return cls(**v)
        raise DataclassTypeError(class_name=cls.__name__)

378 

379 

def create_pydantic_model_from_dataclass(
    dc_cls: Type['Dataclass'],
    config: Type[Any] = BaseConfig,
    dc_cls_doc: Optional[str] = None,
) -> Type['BaseModel']:
    """
    Build the `BaseModel` mirroring dataclass `dc_cls`: one pydantic field per
    dataclass field (same name/type/default), using `config` as model config and
    `dc_cls_doc` (or the dataclass docstring) as the model docstring.
    """
    field_definitions: Dict[str, Any] = {}
    for field in dataclasses.fields(dc_cls):
        default: Any = Undefined
        default_factory: Optional['NoArgAnyCallable'] = None
        field_info: FieldInfo

        # a dataclass field has either a default, a default_factory, or neither (required)
        if field.default is not dataclasses.MISSING:
            default = field.default
        elif field.default_factory is not dataclasses.MISSING:
            default_factory = field.default_factory
        else:
            default = Required

        if isinstance(default, FieldInfo):
            # the default is a `pydantic.Field(...)`: use it directly and remember that
            # (see `_dataclass_validate_values`, which strips these from input data)
            field_info = default
            dc_cls.__pydantic_has_field_info_default__ = True
        else:
            # `field.metadata` entries are forwarded as extra `Field` keyword arguments
            field_info = Field(default=default, default_factory=default_factory, **field.metadata)

        field_definitions[field.name] = (field.type, field_info)

    validators = gather_all_validators(dc_cls)
    model: Type['BaseModel'] = create_model(
        dc_cls.__name__,
        __config__=config,
        __module__=dc_cls.__module__,
        __validators__=validators,
        # forward refs are resolved later by the caller (see `dataclass.wrap`)
        __cls_kwargs__={'__resolve_forward_refs__': False},
        **field_definitions,
    )
    model.__doc__ = dc_cls_doc if dc_cls_doc is not None else dc_cls.__doc__ or ''
    return model

417 

418 

419if sys.version_info >= (3, 8): 

420 

421 def _is_field_cached_property(obj: 'Dataclass', k: str) -> bool: 

422 return isinstance(getattr(type(obj), k, None), cached_property) 

423 

424else: 

425 

426 def _is_field_cached_property(obj: 'Dataclass', k: str) -> bool: 

427 return False 

428 

429 

def _dataclass_validate_values(self: 'Dataclass') -> None:
    """
    Run the attached `__pydantic_model__` over the instance's current `__dict__`,
    write back the parsed values, and mark the instance as initialised.
    Installed on pydantic dataclasses as `__pydantic_validate_values__`.
    """
    # validation errors can occur if this function is called twice on an already initialised dataclass.
    # for example if Extra.forbid is enabled, it would consider __pydantic_initialised__ an invalid extra property
    if getattr(self, '__pydantic_initialised__'):
        return
    if getattr(self, '__pydantic_has_field_info_default__', False):
        # We need to remove `FieldInfo` values since they are not valid as input
        # It's ok to do that because they are obviously the default values!
        input_data = {
            k: v
            for k, v in self.__dict__.items()
            if not (isinstance(v, FieldInfo) or _is_field_cached_property(self, k))
        }
    else:
        # cached properties are excluded either way: they are computed, not input
        input_data = {k: v for k, v in self.__dict__.items() if not _is_field_cached_property(self, k)}
    d, _, validation_error = validate_model(self.__pydantic_model__, input_data, cls=self.__class__)
    if validation_error:
        raise validation_error
    # replace raw values with their parsed/coerced counterparts
    self.__dict__.update(d)
    # `object.__setattr__` so this also works on frozen dataclasses
    object.__setattr__(self, '__pydantic_initialised__', True)

450 

451 

452def _dataclass_validate_assignment_setattr(self: 'Dataclass', name: str, value: Any) -> None: 

453 if self.__pydantic_initialised__: 

454 d = dict(self.__dict__) 

455 d.pop(name, None) 

456 known_field = self.__pydantic_model__.__fields__.get(name, None) 

457 if known_field: 

458 value, error_ = known_field.validate(value, d, loc=name, cls=self.__class__) 

459 if error_: 

460 raise ValidationError([error_], self.__class__) 

461 

462 object.__setattr__(self, name, value) 

463 

464 

def is_builtin_dataclass(_cls: Type[Any]) -> bool:
    """
    Whether a class is a stdlib dataclass
    (useful to discriminate a pydantic dataclass that is actually a wrapper around a stdlib dataclass)

    we check that
    - `_cls` is a dataclass
    - `_cls` is not a processed pydantic dataclass (with a basemodel attached)
    - `_cls` is not a pydantic dataclass inheriting directly from a stdlib dataclass
    e.g.
    ```
    @dataclasses.dataclass
    class A:
        x: int

    @pydantic.dataclasses.dataclass
    class B(A):
        y: int
    ```
    In this case, when we first check `B`, we make an extra check and look at the annotations ('y'),
    which won't be a superset of all the dataclass fields (only the stdlib fields i.e. 'x')
    """
    if not dataclasses.is_dataclass(_cls):
        return False
    if hasattr(_cls, '__pydantic_model__'):
        return False
    own_annotations = set(getattr(_cls, '__annotations__', {}))
    return set(_cls.__dataclass_fields__) >= own_annotations

492 

493 

def make_dataclass_validator(dc_cls: Type['Dataclass'], config: Type[BaseConfig]) -> 'CallableGenerator':
    """
    Create a pydantic.dataclass from a builtin dataclass to add type validation
    and yield the validators
    It retrieves the parameters of the dataclass and forwards them to the newly created dataclass

    `use_proxy=True` keeps the original stdlib class untouched (see `DataclassProxy`).
    """
    yield from _get_validators(dataclass(dc_cls, config=config, use_proxy=True))