1""" 
    2The main purpose is to enhance stdlib dataclasses by adding validation 
    3A pydantic dataclass can be generated from scratch or from a stdlib one. 
    4 
    5Behind the scene, a pydantic dataclass is just like a regular one on which we attach 
    6a `BaseModel` and magic methods to trigger the validation of the data. 
    7`__init__` and `__post_init__` are hence overridden and have extra logic to be 
    8able to validate input data. 
    9 
    10When a pydantic dataclass is generated from scratch, it's just a plain dataclass 
    11with validation triggered at initialization 
    12 
    13The tricky part if for stdlib dataclasses that are converted after into pydantic ones e.g. 
    14 
    15```py 
    16@dataclasses.dataclass 
    17class M: 
    18    x: int 
    19 
    20ValidatedM = pydantic.dataclasses.dataclass(M) 
    21``` 
    22 
    23We indeed still want to support equality, hashing, repr, ... as if it was the stdlib one! 
    24 
    25```py 
    26assert isinstance(ValidatedM(x=1), M) 
    27assert ValidatedM(x=1) == M(x=1) 
    28``` 
    29 
    30This means we **don't want to create a new dataclass that inherits from it** 
    31The trick is to create a wrapper around `M` that will act as a proxy to trigger 
    32validation without altering default `M` behaviour. 
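
For example, only construction through the wrapper triggers validation, while `M`
itself keeps its stdlib behaviour (a minimal sketch of the resulting behaviour):

```py
ValidatedM(x='not an int')  # raises a ValidationError
M(x='not an int')           # stdlib behaviour preserved, no validation performed
```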
    33""" 
import copy
import dataclasses
import sys
from contextlib import contextmanager
from functools import wraps

try:
    from functools import cached_property
except ImportError:
    # cached_property available only for python3.8+
    pass

from typing import TYPE_CHECKING, Any, Callable, ClassVar, Dict, Generator, Optional, Type, TypeVar, Union, overload

from typing_extensions import dataclass_transform

from pydantic.v1.class_validators import gather_all_validators
from pydantic.v1.config import BaseConfig, ConfigDict, Extra, get_config
from pydantic.v1.error_wrappers import ValidationError
from pydantic.v1.errors import DataclassTypeError
from pydantic.v1.fields import Field, FieldInfo, Required, Undefined
from pydantic.v1.main import create_model, validate_model
from pydantic.v1.utils import ClassAttribute

if TYPE_CHECKING:
    from pydantic.v1.main import BaseModel
    from pydantic.v1.typing import CallableGenerator, NoArgAnyCallable

    DataclassT = TypeVar('DataclassT', bound='Dataclass')

    DataclassClassOrWrapper = Union[Type['Dataclass'], 'DataclassProxy']

    class Dataclass:
        # stdlib attributes
        __dataclass_fields__: ClassVar[Dict[str, Any]]
        __dataclass_params__: ClassVar[Any]  # in reality `dataclasses._DataclassParams`
        __post_init__: ClassVar[Callable[..., None]]

        # Added by pydantic
        __pydantic_run_validation__: ClassVar[bool]
        __post_init_post_parse__: ClassVar[Callable[..., None]]
        __pydantic_initialised__: ClassVar[bool]
        __pydantic_model__: ClassVar[Type[BaseModel]]
        __pydantic_validate_values__: ClassVar[Callable[['Dataclass'], None]]
        __pydantic_has_field_info_default__: ClassVar[bool]  # whether a `pydantic.Field` is used as default value

        def __init__(self, *args: object, **kwargs: object) -> None:
            pass

        @classmethod
        def __get_validators__(cls: Type['Dataclass']) -> 'CallableGenerator':
            pass

        @classmethod
        def __validate__(cls: Type['DataclassT'], v: Any) -> 'DataclassT':
            pass


__all__ = [
    'dataclass',
    'set_validation',
    'create_pydantic_model_from_dataclass',
    'is_builtin_dataclass',
    'make_dataclass_validator',
]

_T = TypeVar('_T')

if sys.version_info >= (3, 10):

    @dataclass_transform(field_specifiers=(dataclasses.field, Field))
    @overload
    def dataclass(
        *,
        init: bool = True,
        repr: bool = True,
        eq: bool = True,
        order: bool = False,
        unsafe_hash: bool = False,
        frozen: bool = False,
        config: Union[ConfigDict, Type[object], None] = None,
        validate_on_init: Optional[bool] = None,
        use_proxy: Optional[bool] = None,
        kw_only: bool = ...,
    ) -> Callable[[Type[_T]], 'DataclassClassOrWrapper']:
        ...

    @dataclass_transform(field_specifiers=(dataclasses.field, Field))
    @overload
    def dataclass(
        _cls: Type[_T],
        *,
        init: bool = True,
        repr: bool = True,
        eq: bool = True,
        order: bool = False,
        unsafe_hash: bool = False,
        frozen: bool = False,
        config: Union[ConfigDict, Type[object], None] = None,
        validate_on_init: Optional[bool] = None,
        use_proxy: Optional[bool] = None,
        kw_only: bool = ...,
    ) -> 'DataclassClassOrWrapper':
        ...

else:

    @dataclass_transform(field_specifiers=(dataclasses.field, Field))
    @overload
    def dataclass(
        *,
        init: bool = True,
        repr: bool = True,
        eq: bool = True,
        order: bool = False,
        unsafe_hash: bool = False,
        frozen: bool = False,
        config: Union[ConfigDict, Type[object], None] = None,
        validate_on_init: Optional[bool] = None,
        use_proxy: Optional[bool] = None,
    ) -> Callable[[Type[_T]], 'DataclassClassOrWrapper']:
        ...

    @dataclass_transform(field_specifiers=(dataclasses.field, Field))
    @overload
    def dataclass(
        _cls: Type[_T],
        *,
        init: bool = True,
        repr: bool = True,
        eq: bool = True,
        order: bool = False,
        unsafe_hash: bool = False,
        frozen: bool = False,
        config: Union[ConfigDict, Type[object], None] = None,
        validate_on_init: Optional[bool] = None,
        use_proxy: Optional[bool] = None,
    ) -> 'DataclassClassOrWrapper':
        ...


@dataclass_transform(field_specifiers=(dataclasses.field, Field))
def dataclass(
    _cls: Optional[Type[_T]] = None,
    *,
    init: bool = True,
    repr: bool = True,
    eq: bool = True,
    order: bool = False,
    unsafe_hash: bool = False,
    frozen: bool = False,
    config: Union[ConfigDict, Type[object], None] = None,
    validate_on_init: Optional[bool] = None,
    use_proxy: Optional[bool] = None,
    kw_only: bool = False,
) -> Union[Callable[[Type[_T]], 'DataclassClassOrWrapper'], 'DataclassClassOrWrapper']:
    """
    Like the Python standard library `dataclasses.dataclass`, but with type validation.
    The result is either a pydantic dataclass that will validate input data,
    or a wrapper that will trigger validation around a stdlib dataclass
    to avoid modifying it directly.
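
    For example (a minimal sketch; the class and field names are just illustrative):

    ```py
    @dataclass
    class User:
        id: int
        name: str = 'John Doe'

    user = User(id='42')  # with the default config, '42' is validated and coerced to 42
    ```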
    """
    the_config = get_config(config)

    def wrap(cls: Type[Any]) -> 'DataclassClassOrWrapper':
        should_use_proxy = (
            use_proxy
            if use_proxy is not None
            else (
                is_builtin_dataclass(cls)
                and (cls.__bases__[0] is object or set(dir(cls)) == set(dir(cls.__bases__[0])))
            )
        )
        if should_use_proxy:
            dc_cls_doc = ''
            dc_cls = DataclassProxy(cls)
            default_validate_on_init = False
        else:
            dc_cls_doc = cls.__doc__ or ''  # needs to be done before generating dataclass
            if sys.version_info >= (3, 10):
                dc_cls = dataclasses.dataclass(
                    cls,
                    init=init,
                    repr=repr,
                    eq=eq,
                    order=order,
                    unsafe_hash=unsafe_hash,
                    frozen=frozen,
                    kw_only=kw_only,
                )
            else:
                dc_cls = dataclasses.dataclass(  # type: ignore
                    cls, init=init, repr=repr, eq=eq, order=order, unsafe_hash=unsafe_hash, frozen=frozen
                )
            default_validate_on_init = True

        should_validate_on_init = default_validate_on_init if validate_on_init is None else validate_on_init
        _add_pydantic_validation_attributes(cls, the_config, should_validate_on_init, dc_cls_doc)
        dc_cls.__pydantic_model__.__try_update_forward_refs__(**{cls.__name__: cls})
        return dc_cls

    if _cls is None:
        return wrap

    return wrap(_cls)


@contextmanager
def set_validation(cls: Type['DataclassT'], value: bool) -> Generator[Type['DataclassT'], None, None]:
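    """
    Context manager that temporarily sets `__pydantic_run_validation__` on `cls` to
    `value` and restores the previous value on exit. It is used internally (e.g. by
    `DataclassProxy.__call__` and `_validate_dataclass`) to force validation of
    stdlib dataclasses that are not validated on init by default.
    """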
    original_run_validation = cls.__pydantic_run_validation__
    try:
        cls.__pydantic_run_validation__ = value
        yield cls
    finally:
        cls.__pydantic_run_validation__ = original_run_validation


class DataclassProxy:
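    """
    Thin wrapper around a stdlib dataclass: calling the proxy triggers validation,
    while attribute access and `isinstance` checks are delegated to the wrapped class
    so that its default behaviour (equality, hashing, repr, ...) is left untouched.
    """
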
    __slots__ = '__dataclass__'

    def __init__(self, dc_cls: Type['Dataclass']) -> None:
        object.__setattr__(self, '__dataclass__', dc_cls)

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        with set_validation(self.__dataclass__, True):
            return self.__dataclass__(*args, **kwargs)

    def __getattr__(self, name: str) -> Any:
        return getattr(self.__dataclass__, name)

    def __setattr__(self, __name: str, __value: Any) -> None:
        return setattr(self.__dataclass__, __name, __value)

    def __instancecheck__(self, instance: Any) -> bool:
        return isinstance(instance, self.__dataclass__)

    def __copy__(self) -> 'DataclassProxy':
        return DataclassProxy(copy.copy(self.__dataclass__))

    def __deepcopy__(self, memo: Any) -> 'DataclassProxy':
        return DataclassProxy(copy.deepcopy(self.__dataclass__, memo))


def _add_pydantic_validation_attributes(  # noqa: C901 (ignore complexity)
    dc_cls: Type['Dataclass'],
    config: Type[BaseConfig],
    validate_on_init: bool,
    dc_cls_doc: str,
) -> None:
    """
    We need to replace the right method. If no `__post_init__` has been set in the stdlib dataclass,
    it won't even exist (the code is generated on the fly by `dataclasses`).
    By default, we run validation after `__init__`, or after `__post_init__` if one is defined.
    """
    init = dc_cls.__init__

    @wraps(init)
    def handle_extra_init(self: 'Dataclass', *args: Any, **kwargs: Any) -> None:
        if config.extra == Extra.ignore:
            init(self, *args, **{k: v for k, v in kwargs.items() if k in self.__dataclass_fields__})

        elif config.extra == Extra.allow:
            for k, v in kwargs.items():
                self.__dict__.setdefault(k, v)
            init(self, *args, **{k: v for k, v in kwargs.items() if k in self.__dataclass_fields__})

        else:
            init(self, *args, **kwargs)

    if hasattr(dc_cls, '__post_init__'):
        try:
            post_init = dc_cls.__post_init__.__wrapped__  # type: ignore[attr-defined]
        except AttributeError:
            post_init = dc_cls.__post_init__

        @wraps(post_init)
        def new_post_init(self: 'Dataclass', *args: Any, **kwargs: Any) -> None:
            if config.post_init_call == 'before_validation':
                post_init(self, *args, **kwargs)

            if self.__class__.__pydantic_run_validation__:
                self.__pydantic_validate_values__()
                if hasattr(self, '__post_init_post_parse__'):
                    self.__post_init_post_parse__(*args, **kwargs)

            if config.post_init_call == 'after_validation':
                post_init(self, *args, **kwargs)

        setattr(dc_cls, '__init__', handle_extra_init)
        setattr(dc_cls, '__post_init__', new_post_init)

    else:

        @wraps(init)
        def new_init(self: 'Dataclass', *args: Any, **kwargs: Any) -> None:
            handle_extra_init(self, *args, **kwargs)

            if self.__class__.__pydantic_run_validation__:
                self.__pydantic_validate_values__()

            if hasattr(self, '__post_init_post_parse__'):
                # We need to find the InitVars again. To do that we use `__dataclass_fields__` instead of
                # the public `dataclasses.fields` helper, which excludes InitVars

                # get all initvars and their default values
                initvars_and_values: Dict[str, Any] = {}
                for i, f in enumerate(self.__class__.__dataclass_fields__.values()):
                    if f._field_type is dataclasses._FIELD_INITVAR:  # type: ignore[attr-defined]
                        try:
                            # prefer the positional argument value if one was passed
                            initvars_and_values[f.name] = args[i]
                        except IndexError:
                            initvars_and_values[f.name] = kwargs.get(f.name, f.default)

                self.__post_init_post_parse__(**initvars_and_values)

        setattr(dc_cls, '__init__', new_init)

    setattr(dc_cls, '__pydantic_run_validation__', ClassAttribute('__pydantic_run_validation__', validate_on_init))
    setattr(dc_cls, '__pydantic_initialised__', False)
    setattr(dc_cls, '__pydantic_model__', create_pydantic_model_from_dataclass(dc_cls, config, dc_cls_doc))
    setattr(dc_cls, '__pydantic_validate_values__', _dataclass_validate_values)
    setattr(dc_cls, '__validate__', classmethod(_validate_dataclass))
    setattr(dc_cls, '__get_validators__', classmethod(_get_validators))

    if dc_cls.__pydantic_model__.__config__.validate_assignment and not dc_cls.__dataclass_params__.frozen:
        setattr(dc_cls, '__setattr__', _dataclass_validate_assignment_setattr)


def _get_validators(cls: 'DataclassClassOrWrapper') -> 'CallableGenerator':
    yield cls.__validate__


def _validate_dataclass(cls: Type['DataclassT'], v: Any) -> 'DataclassT':
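    # Validator used when a pydantic dataclass appears as a field type on a model:
    # an existing instance is (re)validated in place, a sequence is unpacked as
    # positional arguments and a dict as keyword arguments.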
    with set_validation(cls, True):
        if isinstance(v, cls):
            v.__pydantic_validate_values__()
            return v
        elif isinstance(v, (list, tuple)):
            return cls(*v)
        elif isinstance(v, dict):
            return cls(**v)
        else:
            raise DataclassTypeError(class_name=cls.__name__)


def create_pydantic_model_from_dataclass(
    dc_cls: Type['Dataclass'],
    config: Type[Any] = BaseConfig,
    dc_cls_doc: Optional[str] = None,
) -> Type['BaseModel']:
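    """
    Build the `BaseModel` attached to a pydantic dataclass: every dataclass field is
    turned into a pydantic field (keeping defaults, default factories and any `Field`
    used as default value) and the validators declared on the dataclass are collected.
    """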
    field_definitions: Dict[str, Any] = {}
    for field in dataclasses.fields(dc_cls):
        default: Any = Undefined
        default_factory: Optional['NoArgAnyCallable'] = None
        field_info: FieldInfo

        if field.default is not dataclasses.MISSING:
            default = field.default
        elif field.default_factory is not dataclasses.MISSING:
            default_factory = field.default_factory
        else:
            default = Required

        if isinstance(default, FieldInfo):
            field_info = default
            dc_cls.__pydantic_has_field_info_default__ = True
        else:
            field_info = Field(default=default, default_factory=default_factory, **field.metadata)

        field_definitions[field.name] = (field.type, field_info)

    validators = gather_all_validators(dc_cls)
    model: Type['BaseModel'] = create_model(
        dc_cls.__name__,
        __config__=config,
        __module__=dc_cls.__module__,
        __validators__=validators,
        __cls_kwargs__={'__resolve_forward_refs__': False},
        **field_definitions,
    )
    model.__doc__ = dc_cls_doc if dc_cls_doc is not None else dc_cls.__doc__ or ''
    return model


if sys.version_info >= (3, 8):

    def _is_field_cached_property(obj: 'Dataclass', k: str) -> bool:
        return isinstance(getattr(type(obj), k, None), cached_property)

else:

    def _is_field_cached_property(obj: 'Dataclass', k: str) -> bool:
        return False


def _dataclass_validate_values(self: 'Dataclass') -> None:
    # Validation errors can occur if this function is called twice on an already initialised dataclass:
    # for example, if Extra.forbid is enabled, `__pydantic_initialised__` would be considered an invalid extra property
    if getattr(self, '__pydantic_initialised__'):
        return
    if getattr(self, '__pydantic_has_field_info_default__', False):
        # We need to remove `FieldInfo` values since they are not valid as input.
        # It's ok to do that because they are obviously the default values!
        input_data = {
            k: v
            for k, v in self.__dict__.items()
            if not (isinstance(v, FieldInfo) or _is_field_cached_property(self, k))
        }
    else:
        input_data = {k: v for k, v in self.__dict__.items() if not _is_field_cached_property(self, k)}
    d, _, validation_error = validate_model(self.__pydantic_model__, input_data, cls=self.__class__)
    if validation_error:
        raise validation_error
    self.__dict__.update(d)
    object.__setattr__(self, '__pydantic_initialised__', True)


def _dataclass_validate_assignment_setattr(self: 'Dataclass', name: str, value: Any) -> None:
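    # Installed as `__setattr__` only when `validate_assignment` is enabled on a
    # non-frozen dataclass: once the instance is initialised, every assignment to a
    # known field is validated before being applied.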
    if self.__pydantic_initialised__:
        d = dict(self.__dict__)
        d.pop(name, None)
        known_field = self.__pydantic_model__.__fields__.get(name, None)
        if known_field:
            value, error_ = known_field.validate(value, d, loc=name, cls=self.__class__)
            if error_:
                raise ValidationError([error_], self.__class__)

    object.__setattr__(self, name, value)


def is_builtin_dataclass(_cls: Type[Any]) -> bool:
    """
    Whether a class is a stdlib dataclass
    (useful to discriminate a pydantic dataclass that is actually a wrapper around a stdlib dataclass).

    We check that:
    - `_cls` is a dataclass
    - `_cls` is not a processed pydantic dataclass (with a `BaseModel` attached)
    - `_cls` is not a pydantic dataclass inheriting directly from a stdlib dataclass
    e.g.
    ```
    @dataclasses.dataclass
    class A:
        x: int

    @pydantic.dataclasses.dataclass
    class B(A):
        y: int
    ```
    In this case, when we first check `B`, we make an extra check and look at the annotations ('y'),
    which won't be a subset of the inherited dataclass fields (only the stdlib fields, i.e. 'x')
    """
    return (
        dataclasses.is_dataclass(_cls)
        and not hasattr(_cls, '__pydantic_model__')
        and set(_cls.__dataclass_fields__).issuperset(set(getattr(_cls, '__annotations__', {})))
    )


def make_dataclass_validator(dc_cls: Type['Dataclass'], config: Type[BaseConfig]) -> 'CallableGenerator':
    """
    Create a pydantic.dataclass from a builtin dataclass to add type validation
    and yield the validators.
    It retrieves the parameters of the dataclass and forwards them to the newly created dataclass.
    """
    yield from _get_validators(dataclass(dc_cls, config=config, use_proxy=True))