1"""Validator functions for standard library types. 
    2 
    3Import of this module is deferred since it contains imports of many standard library modules. 
    4""" 
    5 
    6from __future__ import annotations as _annotations 
    7 
    8import collections.abc 
    9import math 
    10import re 
    11import typing 
    12from collections.abc import Sequence 
    13from decimal import Decimal 
    14from fractions import Fraction 
    15from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network 
    16from typing import Any, Callable, TypeVar, Union, cast 
    17from zoneinfo import ZoneInfo, ZoneInfoNotFoundError 
    18 
    19import typing_extensions 
    20from pydantic_core import PydanticCustomError, PydanticKnownError, core_schema 
    21from typing_extensions import get_args, get_origin 
    22from typing_inspection import typing_objects 
    23 
    24from pydantic._internal._import_utils import import_cached_field_info 
    25from pydantic.errors import PydanticSchemaGenerationError 
    26 
    27 
    28def sequence_validator( 
    29    input_value: Sequence[Any], 
    30    /, 
    31    validator: core_schema.ValidatorFunctionWrapHandler, 
    32) -> Sequence[Any]: 
    33    """Validator for `Sequence` types, isinstance(v, Sequence) has already been called.""" 
    34    value_type = type(input_value) 
    35 
    36    # We don't accept any plain string as a sequence 
    37    # Relevant issue: https://github.com/pydantic/pydantic/issues/5595 
    38    if issubclass(value_type, (str, bytes)): 
    39        raise PydanticCustomError( 
    40            'sequence_str', 
    41            "'{type_name}' instances are not allowed as a Sequence value", 
    42            {'type_name': value_type.__name__}, 
    43        ) 
    44 
    45    # TODO: refactor sequence validation to validate with either a list or a tuple 
    46    # schema, depending on the type of the value. 
    47    # Additionally, we should be able to remove one of either this validator or the 
    48    # SequenceValidator in _std_types_schema.py (preferably this one, while porting over some logic). 
    49    # Effectively, a refactor for sequence validation is needed. 
    50    if value_type is tuple: 
    51        input_value = list(input_value) 
    52 
    53    v_list = validator(input_value) 
    54 
    55    # the rest of the logic is just re-creating the original type from `v_list` 
    56    if value_type is list: 
    57        return v_list 
    58    elif issubclass(value_type, range): 
    59        # return the list as we probably can't re-create the range 
    60        return v_list 
    61    elif value_type is tuple: 
    62        return tuple(v_list) 
    63    else: 
    64        # best guess at how to re-create the original type, more custom construction logic might be required 
    65        return value_type(v_list)  # type: ignore[call-arg] 
    66 
    67 
    68def import_string(value: Any) -> Any: 
    69    if isinstance(value, str): 
    70        try: 
    71            return _import_string_logic(value) 
    72        except ImportError as e: 
    73            raise PydanticCustomError('import_error', 'Invalid python path: {error}', {'error': str(e)}) from e 
    74    else: 
    75        # otherwise we just return the value and let the next validator do the rest of the work 
    76        return value 
    77 
    78 
    79def _import_string_logic(dotted_path: str) -> Any: 
    80    """Inspired by uvicorn — dotted paths should include a colon before the final item if that item is not a module. 
    81    (This is necessary to distinguish between a submodule and an attribute when there is a conflict.). 
    82 
    83    If the dotted path does not include a colon and the final item is not a valid module, importing as an attribute 
    84    rather than a submodule will be attempted automatically. 
    85 
    86    So, for example, the following values of `dotted_path` result in the following returned values: 
    87    * 'collections': <module 'collections'> 
    88    * 'collections.abc': <module 'collections.abc'> 
    89    * 'collections.abc:Mapping': <class 'collections.abc.Mapping'> 
    90    * `collections.abc.Mapping`: <class 'collections.abc.Mapping'> (though this is a bit slower than the previous line) 
    91 
    92    An error will be raised under any of the following scenarios: 
    93    * `dotted_path` contains more than one colon (e.g., 'collections:abc:Mapping') 
    94    * the substring of `dotted_path` before the colon is not a valid module in the environment (e.g., '123:Mapping') 
    95    * the substring of `dotted_path` after the colon is not an attribute of the module (e.g., 'collections:abc123') 
    96    """ 
    97    from importlib import import_module 
    98 
    99    components = dotted_path.strip().split(':') 
    100    if len(components) > 2: 
    101        raise ImportError(f"Import strings should have at most one ':'; received {dotted_path!r}") 
    102 
    103    module_path = components[0] 
    104    if not module_path: 
    105        raise ImportError(f'Import strings should have a nonempty module name; received {dotted_path!r}') 
    106 
    107    try: 
    108        module = import_module(module_path) 
    109    except ModuleNotFoundError as e: 
    110        if '.' in module_path: 
    111            # Check if it would be valid if the final item was separated from its module with a `:` 
    112            maybe_module_path, maybe_attribute = dotted_path.strip().rsplit('.', 1) 
    113            try: 
    114                return _import_string_logic(f'{maybe_module_path}:{maybe_attribute}') 
    115            except ImportError: 
    116                pass 
    117            raise ImportError(f'No module named {module_path!r}') from e 
    118        raise e 
    119 
    120    if len(components) > 1: 
    121        attribute = components[1] 
    122        try: 
    123            return getattr(module, attribute) 
    124        except AttributeError as e: 
    125            raise ImportError(f'cannot import name {attribute!r} from {module_path!r}') from e 
    126    else: 
    127        return module 
    128 
    129 
    130def pattern_either_validator(input_value: Any, /) -> re.Pattern[Any]: 
    131    if isinstance(input_value, re.Pattern): 
    132        return input_value 
    133    elif isinstance(input_value, (str, bytes)): 
    134        # todo strict mode 
    135        return compile_pattern(input_value)  # type: ignore 
    136    else: 
    137        raise PydanticCustomError('pattern_type', 'Input should be a valid pattern') 
    138 
    139 
    140def pattern_str_validator(input_value: Any, /) -> re.Pattern[str]: 
    141    if isinstance(input_value, re.Pattern): 
    142        if isinstance(input_value.pattern, str): 
    143            return input_value 
    144        else: 
    145            raise PydanticCustomError('pattern_str_type', 'Input should be a string pattern') 
    146    elif isinstance(input_value, str): 
    147        return compile_pattern(input_value) 
    148    elif isinstance(input_value, bytes): 
    149        raise PydanticCustomError('pattern_str_type', 'Input should be a string pattern') 
    150    else: 
    151        raise PydanticCustomError('pattern_type', 'Input should be a valid pattern') 
    152 
    153 
    154def pattern_bytes_validator(input_value: Any, /) -> re.Pattern[bytes]: 
    155    if isinstance(input_value, re.Pattern): 
    156        if isinstance(input_value.pattern, bytes): 
    157            return input_value 
    158        else: 
    159            raise PydanticCustomError('pattern_bytes_type', 'Input should be a bytes pattern') 
    160    elif isinstance(input_value, bytes): 
    161        return compile_pattern(input_value) 
    162    elif isinstance(input_value, str): 
    163        raise PydanticCustomError('pattern_bytes_type', 'Input should be a bytes pattern') 
    164    else: 
    165        raise PydanticCustomError('pattern_type', 'Input should be a valid pattern') 
    166 
    167 
# Constrained to exactly `str` or `bytes`, so `compile_pattern` is typed to return a pattern of the same kind as its input.
PatternType = TypeVar('PatternType', str, bytes)
    169 
    170 
    171def compile_pattern(pattern: PatternType) -> re.Pattern[PatternType]: 
    172    try: 
    173        return re.compile(pattern) 
    174    except re.error: 
    175        raise PydanticCustomError('pattern_regex', 'Input should be a valid regular expression') 
    176 
    177 
    178def ip_v4_address_validator(input_value: Any, /) -> IPv4Address: 
    179    if isinstance(input_value, IPv4Address): 
    180        return input_value 
    181 
    182    try: 
    183        return IPv4Address(input_value) 
    184    except ValueError: 
    185        raise PydanticCustomError('ip_v4_address', 'Input is not a valid IPv4 address') 
    186 
    187 
    188def ip_v6_address_validator(input_value: Any, /) -> IPv6Address: 
    189    if isinstance(input_value, IPv6Address): 
    190        return input_value 
    191 
    192    try: 
    193        return IPv6Address(input_value) 
    194    except ValueError: 
    195        raise PydanticCustomError('ip_v6_address', 'Input is not a valid IPv6 address') 
    196 
    197 
    198def ip_v4_network_validator(input_value: Any, /) -> IPv4Network: 
    199    """Assume IPv4Network initialised with a default `strict` argument. 
    200 
    201    See more: 
    202    https://docs.python.org/library/ipaddress.html#ipaddress.IPv4Network 
    203    """ 
    204    if isinstance(input_value, IPv4Network): 
    205        return input_value 
    206 
    207    try: 
    208        return IPv4Network(input_value) 
    209    except ValueError: 
    210        raise PydanticCustomError('ip_v4_network', 'Input is not a valid IPv4 network') 
    211 
    212 
    213def ip_v6_network_validator(input_value: Any, /) -> IPv6Network: 
    214    """Assume IPv6Network initialised with a default `strict` argument. 
    215 
    216    See more: 
    217    https://docs.python.org/library/ipaddress.html#ipaddress.IPv6Network 
    218    """ 
    219    if isinstance(input_value, IPv6Network): 
    220        return input_value 
    221 
    222    try: 
    223        return IPv6Network(input_value) 
    224    except ValueError: 
    225        raise PydanticCustomError('ip_v6_network', 'Input is not a valid IPv6 network') 
    226 
    227 
    228def ip_v4_interface_validator(input_value: Any, /) -> IPv4Interface: 
    229    if isinstance(input_value, IPv4Interface): 
    230        return input_value 
    231 
    232    try: 
    233        return IPv4Interface(input_value) 
    234    except ValueError: 
    235        raise PydanticCustomError('ip_v4_interface', 'Input is not a valid IPv4 interface') 
    236 
    237 
    238def ip_v6_interface_validator(input_value: Any, /) -> IPv6Interface: 
    239    if isinstance(input_value, IPv6Interface): 
    240        return input_value 
    241 
    242    try: 
    243        return IPv6Interface(input_value) 
    244    except ValueError: 
    245        raise PydanticCustomError('ip_v6_interface', 'Input is not a valid IPv6 interface') 
    246 
    247 
    248def fraction_validator(input_value: Any, /) -> Fraction: 
    249    if isinstance(input_value, Fraction): 
    250        return input_value 
    251 
    252    try: 
    253        return Fraction(input_value) 
    254    except ValueError: 
    255        raise PydanticCustomError('fraction_parsing', 'Input is not a valid fraction') 
    256 
    257 
    258def forbid_inf_nan_check(x: Any) -> Any: 
    259    if not math.isfinite(x): 
    260        raise PydanticKnownError('finite_number') 
    261    return x 
    262 
    263 
    264def _safe_repr(v: Any) -> int | float | str: 
    265    """The context argument for `PydanticKnownError` requires a number or str type, so we do a simple repr() coercion for types like timedelta. 
    266 
    267    See tests/test_types.py::test_annotated_metadata_any_order for some context. 
    268    """ 
    269    if isinstance(v, (int, float, str)): 
    270        return v 
    271    return repr(v) 
    272 
    273 
    274def greater_than_validator(x: Any, gt: Any) -> Any: 
    275    try: 
    276        if not (x > gt): 
    277            raise PydanticKnownError('greater_than', {'gt': _safe_repr(gt)}) 
    278        return x 
    279    except TypeError: 
    280        raise TypeError(f"Unable to apply constraint 'gt' to supplied value {x}") 
    281 
    282 
    283def greater_than_or_equal_validator(x: Any, ge: Any) -> Any: 
    284    try: 
    285        if not (x >= ge): 
    286            raise PydanticKnownError('greater_than_equal', {'ge': _safe_repr(ge)}) 
    287        return x 
    288    except TypeError: 
    289        raise TypeError(f"Unable to apply constraint 'ge' to supplied value {x}") 
    290 
    291 
    292def less_than_validator(x: Any, lt: Any) -> Any: 
    293    try: 
    294        if not (x < lt): 
    295            raise PydanticKnownError('less_than', {'lt': _safe_repr(lt)}) 
    296        return x 
    297    except TypeError: 
    298        raise TypeError(f"Unable to apply constraint 'lt' to supplied value {x}") 
    299 
    300 
    301def less_than_or_equal_validator(x: Any, le: Any) -> Any: 
    302    try: 
    303        if not (x <= le): 
    304            raise PydanticKnownError('less_than_equal', {'le': _safe_repr(le)}) 
    305        return x 
    306    except TypeError: 
    307        raise TypeError(f"Unable to apply constraint 'le' to supplied value {x}") 
    308 
    309 
    310def multiple_of_validator(x: Any, multiple_of: Any) -> Any: 
    311    try: 
    312        if x % multiple_of: 
    313            raise PydanticKnownError('multiple_of', {'multiple_of': _safe_repr(multiple_of)}) 
    314        return x 
    315    except TypeError: 
    316        raise TypeError(f"Unable to apply constraint 'multiple_of' to supplied value {x}") 
    317 
    318 
    319def min_length_validator(x: Any, min_length: Any) -> Any: 
    320    try: 
    321        if not (len(x) >= min_length): 
    322            raise PydanticKnownError( 
    323                'too_short', {'field_type': 'Value', 'min_length': min_length, 'actual_length': len(x)} 
    324            ) 
    325        return x 
    326    except TypeError: 
    327        raise TypeError(f"Unable to apply constraint 'min_length' to supplied value {x}") 
    328 
    329 
    330def max_length_validator(x: Any, max_length: Any) -> Any: 
    331    try: 
    332        if len(x) > max_length: 
    333            raise PydanticKnownError( 
    334                'too_long', 
    335                {'field_type': 'Value', 'max_length': max_length, 'actual_length': len(x)}, 
    336            ) 
    337        return x 
    338    except TypeError: 
    339        raise TypeError(f"Unable to apply constraint 'max_length' to supplied value {x}") 
    340 
    341 
    342def _extract_decimal_digits_info(decimal: Decimal) -> tuple[int, int]: 
    343    """Compute the total number of digits and decimal places for a given [`Decimal`][decimal.Decimal] instance. 
    344 
    345    This function handles both normalized and non-normalized Decimal instances. 
    346    Example: Decimal('1.230') -> 4 digits, 3 decimal places 
    347 
    348    Args: 
    349        decimal (Decimal): The decimal number to analyze. 
    350 
    351    Returns: 
    352        tuple[int, int]: A tuple containing the number of decimal places and total digits. 
    353 
    354    Though this could be divided into two separate functions, the logic is easier to follow if we couple the computation 
    355    of the number of decimals and digits together. 
    356    """ 
    357    try: 
    358        decimal_tuple = decimal.as_tuple() 
    359 
    360        assert isinstance(decimal_tuple.exponent, int) 
    361 
    362        exponent = decimal_tuple.exponent 
    363        num_digits = len(decimal_tuple.digits) 
    364 
    365        if exponent >= 0: 
    366            # A positive exponent adds that many trailing zeros 
    367            # Ex: digit_tuple=(1, 2, 3), exponent=2 -> 12300 -> 0 decimal places, 5 digits 
    368            num_digits += exponent 
    369            decimal_places = 0 
    370        else: 
    371            # If the absolute value of the negative exponent is larger than the 
    372            # number of digits, then it's the same as the number of digits, 
    373            # because it'll consume all the digits in digit_tuple and then 
    374            # add abs(exponent) - len(digit_tuple) leading zeros after the decimal point. 
    375            # Ex: digit_tuple=(1, 2, 3), exponent=-2 -> 1.23 -> 2 decimal places, 3 digits 
    376            # Ex: digit_tuple=(1, 2, 3), exponent=-4 -> 0.0123 -> 4 decimal places, 4 digits 
    377            decimal_places = abs(exponent) 
    378            num_digits = max(num_digits, decimal_places) 
    379 
    380        return decimal_places, num_digits 
    381    except (AssertionError, AttributeError): 
    382        raise TypeError(f'Unable to extract decimal digits info from supplied value {decimal}') 
    383 
    384 
    385def max_digits_validator(x: Any, max_digits: Any) -> Any: 
    386    try: 
    387        _, num_digits = _extract_decimal_digits_info(x) 
    388        _, normalized_num_digits = _extract_decimal_digits_info(x.normalize()) 
    389        if (num_digits > max_digits) and (normalized_num_digits > max_digits): 
    390            raise PydanticKnownError( 
    391                'decimal_max_digits', 
    392                {'max_digits': max_digits}, 
    393            ) 
    394        return x 
    395    except TypeError: 
    396        raise TypeError(f"Unable to apply constraint 'max_digits' to supplied value {x}") 
    397 
    398 
    399def decimal_places_validator(x: Any, decimal_places: Any) -> Any: 
    400    try: 
    401        decimal_places_, _ = _extract_decimal_digits_info(x) 
    402        if decimal_places_ > decimal_places: 
    403            normalized_decimal_places, _ = _extract_decimal_digits_info(x.normalize()) 
    404            if normalized_decimal_places > decimal_places: 
    405                raise PydanticKnownError( 
    406                    'decimal_max_places', 
    407                    {'decimal_places': decimal_places}, 
    408                ) 
    409        return x 
    410    except TypeError: 
    411        raise TypeError(f"Unable to apply constraint 'decimal_places' to supplied value {x}") 
    412 
    413 
    414def deque_validator(input_value: Any, handler: core_schema.ValidatorFunctionWrapHandler) -> collections.deque[Any]: 
    415    return collections.deque(handler(input_value), maxlen=getattr(input_value, 'maxlen', None)) 
    416 
    417 
    418def defaultdict_validator( 
    419    input_value: Any, handler: core_schema.ValidatorFunctionWrapHandler, default_default_factory: Callable[[], Any] 
    420) -> collections.defaultdict[Any, Any]: 
    421    if isinstance(input_value, collections.defaultdict): 
    422        default_factory = input_value.default_factory 
    423        return collections.defaultdict(default_factory, handler(input_value)) 
    424    else: 
    425        return collections.defaultdict(default_default_factory, handler(input_value)) 
    426 
    427 
    428def get_defaultdict_default_default_factory(values_source_type: Any) -> Callable[[], Any]: 
    429    FieldInfo = import_cached_field_info() 
    430 
    431    values_type_origin = get_origin(values_source_type) 
    432 
    433    def infer_default() -> Callable[[], Any]: 
    434        allowed_default_types: dict[Any, Any] = { 
    435            tuple: tuple, 
    436            collections.abc.Sequence: tuple, 
    437            collections.abc.MutableSequence: list, 
    438            list: list, 
    439            typing.Sequence: list, 
    440            set: set, 
    441            typing.MutableSet: set, 
    442            collections.abc.MutableSet: set, 
    443            collections.abc.Set: frozenset, 
    444            typing.MutableMapping: dict, 
    445            typing.Mapping: dict, 
    446            collections.abc.Mapping: dict, 
    447            collections.abc.MutableMapping: dict, 
    448            float: float, 
    449            int: int, 
    450            str: str, 
    451            bool: bool, 
    452        } 
    453        values_type = values_type_origin or values_source_type 
    454        instructions = 'set using `DefaultDict[..., Annotated[..., Field(default_factory=...)]]`' 
    455        if typing_objects.is_typevar(values_type): 
    456 
    457            def type_var_default_factory() -> None: 
    458                raise RuntimeError( 
    459                    'Generic defaultdict cannot be used without a concrete value type or an' 
    460                    ' explicit default factory, ' + instructions 
    461                ) 
    462 
    463            return type_var_default_factory 
    464        elif values_type not in allowed_default_types: 
    465            # a somewhat subjective set of types that have reasonable default values 
    466            allowed_msg = ', '.join([t.__name__ for t in set(allowed_default_types.values())]) 
    467            raise PydanticSchemaGenerationError( 
    468                f'Unable to infer a default factory for keys of type {values_source_type}.' 
    469                f' Only {allowed_msg} are supported, other types require an explicit default factory' 
    470                ' ' + instructions 
    471            ) 
    472        return allowed_default_types[values_type] 
    473 
    474    # Assume Annotated[..., Field(...)] 
    475    if typing_objects.is_annotated(values_type_origin): 
    476        field_info = next((v for v in get_args(values_source_type) if isinstance(v, FieldInfo)), None) 
    477    else: 
    478        field_info = None 
    479    if field_info and field_info.default_factory: 
    480        # Assume the default factory does not take any argument: 
    481        default_default_factory = cast(Callable[[], Any], field_info.default_factory) 
    482    else: 
    483        default_default_factory = infer_default() 
    484    return default_default_factory 
    485 
    486 
    487def validate_str_is_valid_iana_tz(value: Any, /) -> ZoneInfo: 
    488    if isinstance(value, ZoneInfo): 
    489        return value 
    490    try: 
    491        return ZoneInfo(value) 
    492    except (ZoneInfoNotFoundError, ValueError, TypeError): 
    493        raise PydanticCustomError('zoneinfo_str', 'invalid timezone: {value}', {'value': value}) 
    494 
    495 
    496NUMERIC_VALIDATOR_LOOKUP: dict[str, Callable] = { 
    497    'gt': greater_than_validator, 
    498    'ge': greater_than_or_equal_validator, 
    499    'lt': less_than_validator, 
    500    'le': less_than_or_equal_validator, 
    501    'multiple_of': multiple_of_validator, 
    502    'min_length': min_length_validator, 
    503    'max_length': max_length_validator, 
    504    'max_digits': max_digits_validator, 
    505    'decimal_places': decimal_places_validator, 
    506} 
    507 
    508IpType = Union[IPv4Address, IPv6Address, IPv4Network, IPv6Network, IPv4Interface, IPv6Interface] 
    509 
    510IP_VALIDATOR_LOOKUP: dict[type[IpType], Callable] = { 
    511    IPv4Address: ip_v4_address_validator, 
    512    IPv6Address: ip_v6_address_validator, 
    513    IPv4Network: ip_v4_network_validator, 
    514    IPv6Network: ip_v6_network_validator, 
    515    IPv4Interface: ip_v4_interface_validator, 
    516    IPv6Interface: ip_v6_interface_validator, 
    517} 
    518 
    519MAPPING_ORIGIN_MAP: dict[Any, Any] = { 
    520    typing.DefaultDict: collections.defaultdict,  # noqa: UP006 
    521    collections.defaultdict: collections.defaultdict, 
    522    typing.OrderedDict: collections.OrderedDict,  # noqa: UP006 
    523    collections.OrderedDict: collections.OrderedDict, 
    524    typing_extensions.OrderedDict: collections.OrderedDict, 
    525    typing.Counter: collections.Counter, 
    526    collections.Counter: collections.Counter, 
    527    # this doesn't handle subclasses of these 
    528    typing.Mapping: dict, 
    529    typing.MutableMapping: dict, 
    530    # parametrized typing.{Mutable}Mapping creates one of these 
    531    collections.abc.Mapping: dict, 
    532    collections.abc.MutableMapping: dict, 
    533}