1"""
2The main purpose is to enhance stdlib dataclasses by adding validation
3A pydantic dataclass can be generated from scratch or from a stdlib one.
4
5Behind the scene, a pydantic dataclass is just like a regular one on which we attach
6a `BaseModel` and magic methods to trigger the validation of the data.
7`__init__` and `__post_init__` are hence overridden and have extra logic to be
8able to validate input data.
9
10When a pydantic dataclass is generated from scratch, it's just a plain dataclass
11with validation triggered at initialization
12
13The tricky part if for stdlib dataclasses that are converted after into pydantic ones e.g.
14
15```py
16@dataclasses.dataclass
17class M:
18 x: int
19
20ValidatedM = pydantic.dataclasses.dataclass(M)
21```
22
23We indeed still want to support equality, hashing, repr, ... as if it was the stdlib one!
24
25```py
26assert isinstance(ValidatedM(x=1), M)
27assert ValidatedM(x=1) == M(x=1)
28```
29
30This means we **don't want to create a new dataclass that inherits from it**
31The trick is to create a wrapper around `M` that will act as a proxy to trigger
32validation without altering default `M` behaviour.
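
Calling the wrapper is what triggers validation; the original class is left untouched.
A minimal illustrative sketch (the string input is hypothetical):

```py
assert ValidatedM(x='1').x == 1  # input is validated (and coerced) through the proxy
assert M(x='1').x == '1'         # the plain stdlib dataclass still does not validate
```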
33"""
34import copy
35import dataclasses
36import sys
37from contextlib import contextmanager
38from functools import wraps
39
40try:
41 from functools import cached_property
42except ImportError:
43 # cached_property available only for python3.8+
44 pass
45
46from typing import TYPE_CHECKING, Any, Callable, ClassVar, Dict, Generator, Optional, Type, TypeVar, Union, overload
47
48from typing_extensions import dataclass_transform
49
50from pydantic.v1.class_validators import gather_all_validators
51from pydantic.v1.config import BaseConfig, ConfigDict, Extra, get_config
52from pydantic.v1.error_wrappers import ValidationError
53from pydantic.v1.errors import DataclassTypeError
54from pydantic.v1.fields import Field, FieldInfo, Required, Undefined
55from pydantic.v1.main import create_model, validate_model
56from pydantic.v1.utils import ClassAttribute

if TYPE_CHECKING:
    from pydantic.v1.main import BaseModel
    from pydantic.v1.typing import CallableGenerator, NoArgAnyCallable

    DataclassT = TypeVar('DataclassT', bound='Dataclass')

    DataclassClassOrWrapper = Union[Type['Dataclass'], 'DataclassProxy']

    class Dataclass:
        # stdlib attributes
        __dataclass_fields__: ClassVar[Dict[str, Any]]
        __dataclass_params__: ClassVar[Any]  # in reality `dataclasses._DataclassParams`
        __post_init__: ClassVar[Callable[..., None]]

        # Added by pydantic
        __pydantic_run_validation__: ClassVar[bool]
        __post_init_post_parse__: ClassVar[Callable[..., None]]
        __pydantic_initialised__: ClassVar[bool]
        __pydantic_model__: ClassVar[Type[BaseModel]]
        __pydantic_validate_values__: ClassVar[Callable[['Dataclass'], None]]
        __pydantic_has_field_info_default__: ClassVar[bool]  # whether a `pydantic.Field` is used as default value

        def __init__(self, *args: object, **kwargs: object) -> None:
            pass

        @classmethod
        def __get_validators__(cls: Type['Dataclass']) -> 'CallableGenerator':
            pass

        @classmethod
        def __validate__(cls: Type['DataclassT'], v: Any) -> 'DataclassT':
            pass


__all__ = [
    'dataclass',
    'set_validation',
    'create_pydantic_model_from_dataclass',
    'is_builtin_dataclass',
    'make_dataclass_validator',
]

_T = TypeVar('_T')

if sys.version_info >= (3, 10):

    @dataclass_transform(field_specifiers=(dataclasses.field, Field))
    @overload
    def dataclass(
        *,
        init: bool = True,
        repr: bool = True,
        eq: bool = True,
        order: bool = False,
        unsafe_hash: bool = False,
        frozen: bool = False,
        config: Union[ConfigDict, Type[object], None] = None,
        validate_on_init: Optional[bool] = None,
        use_proxy: Optional[bool] = None,
        kw_only: bool = ...,
    ) -> Callable[[Type[_T]], 'DataclassClassOrWrapper']:
        ...

    @dataclass_transform(field_specifiers=(dataclasses.field, Field))
    @overload
    def dataclass(
        _cls: Type[_T],
        *,
        init: bool = True,
        repr: bool = True,
        eq: bool = True,
        order: bool = False,
        unsafe_hash: bool = False,
        frozen: bool = False,
        config: Union[ConfigDict, Type[object], None] = None,
        validate_on_init: Optional[bool] = None,
        use_proxy: Optional[bool] = None,
        kw_only: bool = ...,
    ) -> 'DataclassClassOrWrapper':
        ...

else:

    @dataclass_transform(field_specifiers=(dataclasses.field, Field))
    @overload
    def dataclass(
        *,
        init: bool = True,
        repr: bool = True,
        eq: bool = True,
        order: bool = False,
        unsafe_hash: bool = False,
        frozen: bool = False,
        config: Union[ConfigDict, Type[object], None] = None,
        validate_on_init: Optional[bool] = None,
        use_proxy: Optional[bool] = None,
    ) -> Callable[[Type[_T]], 'DataclassClassOrWrapper']:
        ...

    @dataclass_transform(field_specifiers=(dataclasses.field, Field))
    @overload
    def dataclass(
        _cls: Type[_T],
        *,
        init: bool = True,
        repr: bool = True,
        eq: bool = True,
        order: bool = False,
        unsafe_hash: bool = False,
        frozen: bool = False,
        config: Union[ConfigDict, Type[object], None] = None,
        validate_on_init: Optional[bool] = None,
        use_proxy: Optional[bool] = None,
    ) -> 'DataclassClassOrWrapper':
        ...


@dataclass_transform(field_specifiers=(dataclasses.field, Field))
def dataclass(
    _cls: Optional[Type[_T]] = None,
    *,
    init: bool = True,
    repr: bool = True,
    eq: bool = True,
    order: bool = False,
    unsafe_hash: bool = False,
    frozen: bool = False,
    config: Union[ConfigDict, Type[object], None] = None,
    validate_on_init: Optional[bool] = None,
    use_proxy: Optional[bool] = None,
    kw_only: bool = False,
) -> Union[Callable[[Type[_T]], 'DataclassClassOrWrapper'], 'DataclassClassOrWrapper']:
190 """
191 Like the python standard lib dataclasses but with type validation.
192 The result is either a pydantic dataclass that will validate input data
193 or a wrapper that will trigger validation around a stdlib dataclass
194 to avoid modifying it directly
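
    For example (an illustrative sketch; `User` and its fields are hypothetical):

    ```
    @dataclass
    class User:
        id: int
        name: str = 'John Doe'

    user = User(id='42')   # '42' is validated and coerced to the int 42
    User(id='not an int')  # raises a ValidationError
    ```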
195 """
196 the_config = get_config(config)
197
198 def wrap(cls: Type[Any]) -> 'DataclassClassOrWrapper':
199 should_use_proxy = (
200 use_proxy
201 if use_proxy is not None
202 else (
203 is_builtin_dataclass(cls)
204 and (cls.__bases__[0] is object or set(dir(cls)) == set(dir(cls.__bases__[0])))
205 )
206 )
207 if should_use_proxy:
208 dc_cls_doc = ''
209 dc_cls = DataclassProxy(cls)
210 default_validate_on_init = False
211 else:
212 dc_cls_doc = cls.__doc__ or '' # needs to be done before generating dataclass
213 if sys.version_info >= (3, 10):
214 dc_cls = dataclasses.dataclass(
215 cls,
216 init=init,
217 repr=repr,
218 eq=eq,
219 order=order,
220 unsafe_hash=unsafe_hash,
221 frozen=frozen,
222 kw_only=kw_only,
223 )
224 else:
225 dc_cls = dataclasses.dataclass( # type: ignore
226 cls, init=init, repr=repr, eq=eq, order=order, unsafe_hash=unsafe_hash, frozen=frozen
227 )
228 default_validate_on_init = True
229
230 should_validate_on_init = default_validate_on_init if validate_on_init is None else validate_on_init
231 _add_pydantic_validation_attributes(cls, the_config, should_validate_on_init, dc_cls_doc)
232 dc_cls.__pydantic_model__.__try_update_forward_refs__(**{cls.__name__: cls})
233 return dc_cls
234
235 if _cls is None:
236 return wrap
237
238 return wrap(_cls)


@contextmanager
def set_validation(cls: Type['DataclassT'], value: bool) -> Generator[Type['DataclassT'], None, None]:
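    """
    Context manager that temporarily sets `__pydantic_run_validation__` on `cls`
    and restores the previous value on exit, e.g. (an illustrative sketch,
    `MyDataclass` being a hypothetical pydantic dataclass):

    ```
    with set_validation(MyDataclass, True):
        MyDataclass(x='1')  # validated here even if validate_on_init was False
    ```
    """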
    original_run_validation = cls.__pydantic_run_validation__
    try:
        cls.__pydantic_run_validation__ = value
        yield cls
    finally:
        cls.__pydantic_run_validation__ = original_run_validation


class DataclassProxy:
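    """
    Thin wrapper around a stdlib dataclass: calling the proxy creates an instance of
    the wrapped class with validation enabled, while attribute access and assignment
    are delegated to the wrapped class so equality, hashing, repr, ... keep their
    stdlib behaviour.
    """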
    __slots__ = '__dataclass__'

    def __init__(self, dc_cls: Type['Dataclass']) -> None:
        object.__setattr__(self, '__dataclass__', dc_cls)

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        with set_validation(self.__dataclass__, True):
            return self.__dataclass__(*args, **kwargs)

    def __getattr__(self, name: str) -> Any:
        return getattr(self.__dataclass__, name)

    def __setattr__(self, __name: str, __value: Any) -> None:
        return setattr(self.__dataclass__, __name, __value)

    def __instancecheck__(self, instance: Any) -> bool:
        return isinstance(instance, self.__dataclass__)

    def __copy__(self) -> 'DataclassProxy':
        return DataclassProxy(copy.copy(self.__dataclass__))

    def __deepcopy__(self, memo: Any) -> 'DataclassProxy':
        return DataclassProxy(copy.deepcopy(self.__dataclass__, memo))


def _add_pydantic_validation_attributes(  # noqa: C901 (ignore complexity)
    dc_cls: Type['Dataclass'],
    config: Type[BaseConfig],
    validate_on_init: bool,
    dc_cls_doc: str,
) -> None:
283 """
284 We need to replace the right method. If no `__post_init__` has been set in the stdlib dataclass
285 it won't even exist (code is generated on the fly by `dataclasses`)
286 By default, we run validation after `__init__` or `__post_init__` if defined
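
    The ordering can be sketched as follows (class and hook bodies are hypothetical):

    ```
    @dataclass
    class Birth:
        year: int

        def __post_init__(self):
            ...  # runs before validation by default (see `Config.post_init_call`)

        def __post_init_post_parse__(self):
            ...  # runs after validation, so `self.year` is already an `int` here
    ```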
287 """
288 init = dc_cls.__init__
289
290 @wraps(init)
291 def handle_extra_init(self: 'Dataclass', *args: Any, **kwargs: Any) -> None:
292 if config.extra == Extra.ignore:
293 init(self, *args, **{k: v for k, v in kwargs.items() if k in self.__dataclass_fields__})
294
295 elif config.extra == Extra.allow:
296 for k, v in kwargs.items():
297 self.__dict__.setdefault(k, v)
298 init(self, *args, **{k: v for k, v in kwargs.items() if k in self.__dataclass_fields__})
299
300 else:
301 init(self, *args, **kwargs)
302
303 if hasattr(dc_cls, '__post_init__'):
304 try:
305 post_init = dc_cls.__post_init__.__wrapped__ # type: ignore[attr-defined]
306 except AttributeError:
307 post_init = dc_cls.__post_init__
308
309 @wraps(post_init)
310 def new_post_init(self: 'Dataclass', *args: Any, **kwargs: Any) -> None:
311 if config.post_init_call == 'before_validation':
312 post_init(self, *args, **kwargs)
313
314 if self.__class__.__pydantic_run_validation__:
315 self.__pydantic_validate_values__()
316 if hasattr(self, '__post_init_post_parse__'):
317 self.__post_init_post_parse__(*args, **kwargs)
318
319 if config.post_init_call == 'after_validation':
320 post_init(self, *args, **kwargs)
321
322 setattr(dc_cls, '__init__', handle_extra_init)
323 setattr(dc_cls, '__post_init__', new_post_init)
324
325 else:
326
327 @wraps(init)
328 def new_init(self: 'Dataclass', *args: Any, **kwargs: Any) -> None:
329 handle_extra_init(self, *args, **kwargs)
330
331 if self.__class__.__pydantic_run_validation__:
332 self.__pydantic_validate_values__()
333
334 if hasattr(self, '__post_init_post_parse__'):
335 # We need to find again the initvars. To do that we use `__dataclass_fields__` instead of
336 # public method `dataclasses.fields`
337
338 # get all initvars and their default values
339 initvars_and_values: Dict[str, Any] = {}
340 for i, f in enumerate(self.__class__.__dataclass_fields__.values()):
341 if f._field_type is dataclasses._FIELD_INITVAR: # type: ignore[attr-defined]
342 try:
343 # set arg value by default
344 initvars_and_values[f.name] = args[i]
345 except IndexError:
346 initvars_and_values[f.name] = kwargs.get(f.name, f.default)
347
348 self.__post_init_post_parse__(**initvars_and_values)
349
350 setattr(dc_cls, '__init__', new_init)
351
352 setattr(dc_cls, '__pydantic_run_validation__', ClassAttribute('__pydantic_run_validation__', validate_on_init))
353 setattr(dc_cls, '__pydantic_initialised__', False)
354 setattr(dc_cls, '__pydantic_model__', create_pydantic_model_from_dataclass(dc_cls, config, dc_cls_doc))
355 setattr(dc_cls, '__pydantic_validate_values__', _dataclass_validate_values)
356 setattr(dc_cls, '__validate__', classmethod(_validate_dataclass))
357 setattr(dc_cls, '__get_validators__', classmethod(_get_validators))
358
359 if dc_cls.__pydantic_model__.__config__.validate_assignment and not dc_cls.__dataclass_params__.frozen:
360 setattr(dc_cls, '__setattr__', _dataclass_validate_assignment_setattr)


def _get_validators(cls: 'DataclassClassOrWrapper') -> 'CallableGenerator':
    yield cls.__validate__


def _validate_dataclass(cls: Type['DataclassT'], v: Any) -> 'DataclassT':
    with set_validation(cls, True):
        if isinstance(v, cls):
            v.__pydantic_validate_values__()
            return v
        elif isinstance(v, (list, tuple)):
            return cls(*v)
        elif isinstance(v, dict):
            return cls(**v)
        else:
            raise DataclassTypeError(class_name=cls.__name__)


def create_pydantic_model_from_dataclass(
    dc_cls: Type['Dataclass'],
    config: Type[Any] = BaseConfig,
    dc_cls_doc: Optional[str] = None,
) -> Type['BaseModel']:
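    """
    Build the `BaseModel` used behind the scenes to validate the dataclass fields,
    forwarding defaults, default factories and `pydantic.Field` metadata.
    An illustrative sketch (`MyDataclass`, with a single `x: int` field, is hypothetical):

    ```
    Model = create_pydantic_model_from_dataclass(MyDataclass)
    Model(x='1')  # validates like the corresponding pydantic dataclass would
    ```
    """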
    field_definitions: Dict[str, Any] = {}
    for field in dataclasses.fields(dc_cls):
        default: Any = Undefined
        default_factory: Optional['NoArgAnyCallable'] = None
        field_info: FieldInfo

        if field.default is not dataclasses.MISSING:
            default = field.default
        elif field.default_factory is not dataclasses.MISSING:
            default_factory = field.default_factory
        else:
            default = Required

        if isinstance(default, FieldInfo):
            field_info = default
            dc_cls.__pydantic_has_field_info_default__ = True
        else:
            field_info = Field(default=default, default_factory=default_factory, **field.metadata)

        field_definitions[field.name] = (field.type, field_info)

    validators = gather_all_validators(dc_cls)
    model: Type['BaseModel'] = create_model(
        dc_cls.__name__,
        __config__=config,
        __module__=dc_cls.__module__,
        __validators__=validators,
        __cls_kwargs__={'__resolve_forward_refs__': False},
        **field_definitions,
    )
    model.__doc__ = dc_cls_doc if dc_cls_doc is not None else dc_cls.__doc__ or ''
    return model


if sys.version_info >= (3, 8):

    def _is_field_cached_property(obj: 'Dataclass', k: str) -> bool:
        return isinstance(getattr(type(obj), k, None), cached_property)

else:

    def _is_field_cached_property(obj: 'Dataclass', k: str) -> bool:
        return False


def _dataclass_validate_values(self: 'Dataclass') -> None:
    # Validation errors can occur if this function is called twice on an already initialised dataclass.
    # For example, if Extra.forbid is enabled it would consider __pydantic_initialised__ an invalid extra property.
    if getattr(self, '__pydantic_initialised__'):
        return
    if getattr(self, '__pydantic_has_field_info_default__', False):
        # We need to remove `FieldInfo` values since they are not valid as input.
        # It's ok to do that because they are obviously the default values!
        input_data = {
            k: v
            for k, v in self.__dict__.items()
            if not (isinstance(v, FieldInfo) or _is_field_cached_property(self, k))
        }
    else:
        input_data = {k: v for k, v in self.__dict__.items() if not _is_field_cached_property(self, k)}
    d, _, validation_error = validate_model(self.__pydantic_model__, input_data, cls=self.__class__)
    if validation_error:
        raise validation_error
    self.__dict__.update(d)
    object.__setattr__(self, '__pydantic_initialised__', True)


def _dataclass_validate_assignment_setattr(self: 'Dataclass', name: str, value: Any) -> None:
    if self.__pydantic_initialised__:
        d = dict(self.__dict__)
        d.pop(name, None)
        known_field = self.__pydantic_model__.__fields__.get(name, None)
        if known_field:
            value, error_ = known_field.validate(value, d, loc=name, cls=self.__class__)
            if error_:
                raise ValidationError([error_], self.__class__)

    object.__setattr__(self, name, value)


def is_builtin_dataclass(_cls: Type[Any]) -> bool:
    """
    Whether a class is a stdlib dataclass
    (useful to discriminate a pydantic dataclass that is actually a wrapper around a stdlib dataclass).

    We check that
    - `_cls` is a dataclass
    - `_cls` is not a processed pydantic dataclass (with a basemodel attached)
    - `_cls` is not a pydantic dataclass inheriting directly from a stdlib dataclass
    e.g.
    ```
    @dataclasses.dataclass
    class A:
        x: int

    @pydantic.dataclasses.dataclass
    class B(A):
        y: int
    ```
    In this case, when we first check `B`, we make an extra check and look at the annotations ('y'):
    the inherited dataclass fields (only the stdlib field 'x') are not a superset of them,
    so `B` is not treated as a stdlib dataclass.
    """
    return (
        dataclasses.is_dataclass(_cls)
        and not hasattr(_cls, '__pydantic_model__')
        and set(_cls.__dataclass_fields__).issuperset(set(getattr(_cls, '__annotations__', {})))
    )


def make_dataclass_validator(dc_cls: Type['Dataclass'], config: Type[BaseConfig]) -> 'CallableGenerator':
    """
    Create a pydantic.dataclass from a builtin dataclass to add type validation
    and yield the validators.
    It retrieves the parameters of the dataclass and forwards them to the newly created dataclass.
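
    This is typically used when a stdlib dataclass appears as a field type on a
    `BaseModel`; an illustrative sketch (names are hypothetical):

    ```
    class Model(BaseModel):
        item: MyStdlibDataclass  # the validators yielded here validate `item`
    ```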
499 """
500 yield from _get_validators(dataclass(dc_cls, config=config, use_proxy=True))