1"""Validator functions for standard library types.
2
3Import of this module is deferred since it contains imports of many standard library modules.
4"""
5
6from __future__ import annotations as _annotations
7
8import collections.abc
9import math
10import re
11import typing
12from collections.abc import Sequence
13from decimal import Decimal
14from fractions import Fraction
15from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network
16from typing import Any, Callable, TypeVar, Union, cast
17from zoneinfo import ZoneInfo, ZoneInfoNotFoundError
18
19import typing_extensions
20from pydantic_core import PydanticCustomError, PydanticKnownError, core_schema
21from typing_extensions import get_args, get_origin
22from typing_inspection import typing_objects
23
24from pydantic._internal._import_utils import import_cached_field_info
25from pydantic.errors import PydanticSchemaGenerationError
26
27
def sequence_validator(
    input_value: Sequence[Any],
    /,
    validator: core_schema.ValidatorFunctionWrapHandler,
) -> Sequence[Any]:
    """Validate a `Sequence` value and rebuild it as its original container type where possible.

    The caller has already checked `isinstance(input_value, Sequence)`.

    Args:
        input_value: The sequence to validate. `str` and `bytes` (and their subclasses) are rejected.
        validator: The wrapped inner validator. It receives a `list` copy when the input is exactly a
            `tuple`, and the input itself otherwise.

    Returns:
        The validated items: unchanged for `list` and `range` inputs, converted to a `tuple` for
        `tuple` inputs, and otherwise passed to `type(input_value)(...)`.

    Raises:
        PydanticCustomError: `sequence_str` if `input_value` is a `str` or `bytes` instance.
    """
    source_type = type(input_value)

    # Strings are technically sequences, but are never accepted as one.
    # Relevant issue: https://github.com/pydantic/pydantic/issues/5595
    if issubclass(source_type, (str, bytes)):
        raise PydanticCustomError(
            'sequence_str',
            "'{type_name}' instances are not allowed as a Sequence value",
            {'type_name': source_type.__name__},
        )

    # TODO: refactor sequence validation to validate with either a list or a tuple
    # schema, depending on the type of the value.
    # Additionally, we should be able to remove one of either this validator or the
    # SequenceValidator in _std_types_schema.py (preferably this one, while porting over some logic).
    # Effectively, a refactor for sequence validation is needed.
    is_exact_tuple = source_type is tuple
    validated_items = validator(list(input_value) if is_exact_tuple else input_value)

    # Everything below only re-creates the original container type from `validated_items`.
    if source_type is list or issubclass(source_type, range):
        # A `range` generally can't be rebuilt from arbitrary items, so the list is returned as-is.
        return validated_items
    if is_exact_tuple:
        return tuple(validated_items)
    # Best guess at reconstruction; some custom sequence types may need bespoke construction logic.
    return source_type(validated_items)  # type: ignore[call-arg]
66
67
def import_string(value: Any) -> Any:
    """Resolve a dotted import string to the object it names; pass any non-string value through.

    Raises:
        PydanticCustomError: `import_error` if a string value cannot be imported.
    """
    if not isinstance(value, str):
        # Not an import string: leave it for the next validator in the chain.
        return value
    try:
        return _import_string_logic(value)
    except ImportError as e:
        raise PydanticCustomError('import_error', 'Invalid python path: {error}', {'error': str(e)}) from e
77
78
def _import_string_logic(dotted_path: str) -> Any:
    """Import the module or attribute named by `dotted_path` (approach inspired by uvicorn).

    A colon separates the module path from a final attribute, which disambiguates a submodule
    from an attribute with the same name. Without a colon, if the whole path is not an importable
    module, the last dotted component is retried as an attribute of its parent module.

    Examples of `dotted_path` and the resulting value:
    * 'collections': <module 'collections'>
    * 'collections.abc': <module 'collections.abc'>
    * 'collections.abc:Mapping': <class 'collections.abc.Mapping'>
    * 'collections.abc.Mapping': <class 'collections.abc.Mapping'> (slower than the colon form)

    Raises:
        ImportError: if `dotted_path` contains more than one colon (e.g. 'collections:abc:Mapping'),
            has an empty module part, names a module that cannot be found (e.g. '123:Mapping'), or
            names an attribute the module does not have (e.g. 'collections:abc123').
    """
    from importlib import import_module

    cleaned = dotted_path.strip()
    parts = cleaned.split(':')
    if len(parts) > 2:
        raise ImportError(f"Import strings should have at most one ':'; received {dotted_path!r}")

    module_path, *attribute_part = parts
    if not module_path:
        raise ImportError(f'Import strings should have a nonempty module name; received {dotted_path!r}')

    try:
        module = import_module(module_path)
    except ModuleNotFoundError as exc:
        if '.' not in module_path:
            raise exc
        # Retry with the final dotted component treated as an attribute: `a.b.c` -> `a.b:c`.
        parent_path, leaf = cleaned.rsplit('.', 1)
        try:
            return _import_string_logic(f'{parent_path}:{leaf}')
        except ImportError:
            pass
        raise ImportError(f'No module named {module_path!r}') from exc

    if not attribute_part:
        return module

    attribute = attribute_part[0]
    try:
        return getattr(module, attribute)
    except AttributeError as exc:
        raise ImportError(f'cannot import name {attribute!r} from {module_path!r}') from exc
128
129
def pattern_either_validator(input_value: Any, /) -> re.Pattern[Any]:
    """Accept a compiled pattern as-is, or compile a `str`/`bytes` source into one.

    Raises:
        PydanticCustomError: `pattern_type` for any other input (compilation errors come from `compile_pattern`).
    """
    if isinstance(input_value, (str, bytes)):
        # todo strict mode
        return compile_pattern(input_value)  # type: ignore
    if isinstance(input_value, re.Pattern):
        return input_value
    raise PydanticCustomError('pattern_type', 'Input should be a valid pattern')
138
139
def pattern_str_validator(input_value: Any, /) -> re.Pattern[str]:
    """Return a `str`-based compiled pattern, compiling `str` input when necessary.

    Raises:
        PydanticCustomError: `pattern_str_type` for `bytes` input or a bytes-based pattern, and
            `pattern_type` for any other non-pattern input.
    """
    if isinstance(input_value, str):
        return compile_pattern(input_value)
    if isinstance(input_value, re.Pattern):
        if not isinstance(input_value.pattern, str):
            raise PydanticCustomError('pattern_str_type', 'Input should be a string pattern')
        return input_value
    if isinstance(input_value, bytes):
        raise PydanticCustomError('pattern_str_type', 'Input should be a string pattern')
    raise PydanticCustomError('pattern_type', 'Input should be a valid pattern')
152
153
def pattern_bytes_validator(input_value: Any, /) -> re.Pattern[bytes]:
    """Return a `bytes`-based compiled pattern, compiling `bytes` input when necessary.

    Raises:
        PydanticCustomError: `pattern_bytes_type` for `str` input or a str-based pattern, and
            `pattern_type` for any other non-pattern input.
    """
    if isinstance(input_value, bytes):
        return compile_pattern(input_value)
    if isinstance(input_value, re.Pattern):
        if not isinstance(input_value.pattern, bytes):
            raise PydanticCustomError('pattern_bytes_type', 'Input should be a bytes pattern')
        return input_value
    if isinstance(input_value, str):
        raise PydanticCustomError('pattern_bytes_type', 'Input should be a bytes pattern')
    raise PydanticCustomError('pattern_type', 'Input should be a valid pattern')
166
167
# Constrained type variable (str or bytes). Used in `compile_pattern`'s signature so the compiled
# pattern's type parameter matches the type of the source pattern.
PatternType = TypeVar('PatternType', str, bytes)
169
170
def compile_pattern(pattern: PatternType) -> re.Pattern[PatternType]:
    """Compile `pattern` with `re.compile`.

    Raises:
        PydanticCustomError: `pattern_regex` if `pattern` is not a valid regular expression.
    """
    try:
        compiled = re.compile(pattern)
    except re.error:
        raise PydanticCustomError('pattern_regex', 'Input should be a valid regular expression')
    return compiled
176
177
def ip_v4_address_validator(input_value: Any, /) -> IPv4Address:
    """Coerce `input_value` to an `IPv4Address`; existing instances are returned unchanged.

    Raises:
        PydanticCustomError: `ip_v4_address` if `IPv4Address(input_value)` raises `ValueError`.
    """
    if not isinstance(input_value, IPv4Address):
        try:
            input_value = IPv4Address(input_value)
        except ValueError:
            raise PydanticCustomError('ip_v4_address', 'Input is not a valid IPv4 address')
    return input_value
186
187
def ip_v6_address_validator(input_value: Any, /) -> IPv6Address:
    """Coerce `input_value` to an `IPv6Address`; existing instances are returned unchanged.

    Raises:
        PydanticCustomError: `ip_v6_address` if `IPv6Address(input_value)` raises `ValueError`.
    """
    if not isinstance(input_value, IPv6Address):
        try:
            input_value = IPv6Address(input_value)
        except ValueError:
            raise PydanticCustomError('ip_v6_address', 'Input is not a valid IPv6 address')
    return input_value
196
197
def ip_v4_network_validator(input_value: Any, /) -> IPv4Network:
    """Coerce `input_value` to an `IPv4Network`; existing instances are returned unchanged.

    `IPv4Network` is constructed with its default `strict` argument, so inputs with host bits
    set are rejected. See:
    https://docs.python.org/library/ipaddress.html#ipaddress.IPv4Network

    Raises:
        PydanticCustomError: `ip_v4_network` if `IPv4Network(input_value)` raises `ValueError`.
    """
    if not isinstance(input_value, IPv4Network):
        try:
            input_value = IPv4Network(input_value)
        except ValueError:
            raise PydanticCustomError('ip_v4_network', 'Input is not a valid IPv4 network')
    return input_value
211
212
def ip_v6_network_validator(input_value: Any, /) -> IPv6Network:
    """Coerce `input_value` to an `IPv6Network`; existing instances are returned unchanged.

    `IPv6Network` is constructed with its default `strict` argument, so inputs with host bits
    set are rejected. See:
    https://docs.python.org/library/ipaddress.html#ipaddress.IPv6Network

    Raises:
        PydanticCustomError: `ip_v6_network` if `IPv6Network(input_value)` raises `ValueError`.
    """
    if not isinstance(input_value, IPv6Network):
        try:
            input_value = IPv6Network(input_value)
        except ValueError:
            raise PydanticCustomError('ip_v6_network', 'Input is not a valid IPv6 network')
    return input_value
226
227
def ip_v4_interface_validator(input_value: Any, /) -> IPv4Interface:
    """Coerce `input_value` to an `IPv4Interface`; existing instances are returned unchanged.

    Raises:
        PydanticCustomError: `ip_v4_interface` if `IPv4Interface(input_value)` raises `ValueError`.
    """
    if not isinstance(input_value, IPv4Interface):
        try:
            input_value = IPv4Interface(input_value)
        except ValueError:
            raise PydanticCustomError('ip_v4_interface', 'Input is not a valid IPv4 interface')
    return input_value
236
237
def ip_v6_interface_validator(input_value: Any, /) -> IPv6Interface:
    """Coerce `input_value` to an `IPv6Interface`; existing instances are returned unchanged.

    Raises:
        PydanticCustomError: `ip_v6_interface` if `IPv6Interface(input_value)` raises `ValueError`.
    """
    if not isinstance(input_value, IPv6Interface):
        try:
            input_value = IPv6Interface(input_value)
        except ValueError:
            raise PydanticCustomError('ip_v6_interface', 'Input is not a valid IPv6 interface')
    return input_value
246
247
def fraction_validator(input_value: Any, /) -> Fraction:
    """Coerce `input_value` to a `Fraction`; existing `Fraction` instances are returned unchanged.

    Anything `fractions.Fraction` accepts is converted: strings such as '1/3' or '0.5', ints,
    floats, `Decimal`s and other rationals.

    Raises:
        PydanticCustomError: `fraction_parsing` if the value cannot be converted to a `Fraction`.
    """
    if isinstance(input_value, Fraction):
        return input_value

    # `Fraction()` reports bad input with several exception types, not only ValueError:
    # unsupported input types raise TypeError, a zero denominator (e.g. '1/0') raises
    # ZeroDivisionError, and infinite floats/Decimals raise OverflowError. All of them mean
    # "not a valid fraction", so they are reported as a validation error instead of leaking.
    try:
        return Fraction(input_value)
    except (ValueError, TypeError, ZeroDivisionError, OverflowError):
        raise PydanticCustomError('fraction_parsing', 'Input is not a valid fraction')
256
257
def forbid_inf_nan_check(x: Any) -> Any:
    """Return `x` unchanged if `math.isfinite(x)` is true.

    Raises:
        PydanticKnownError: `finite_number` if `x` is infinite or NaN.
    """
    if math.isfinite(x):
        return x
    raise PydanticKnownError('finite_number')
262
263
def _safe_repr(v: Any) -> int | float | str:
    """Make `v` usable as `PydanticKnownError` context, which only accepts numbers or strings.

    `int`, `float` and `str` values (including subclasses) are returned unchanged; anything else,
    such as a `timedelta`, is replaced by its `repr()`.

    See tests/test_types.py::test_annotated_metadata_any_order for some context.
    """
    return v if isinstance(v, (int, float, str)) else repr(v)
272
273
def greater_than_validator(x: Any, gt: Any) -> Any:
    """Return `x` unchanged if `x > gt` is truthy.

    Raises:
        PydanticKnownError: `greater_than` if the comparison is falsy.
        TypeError: if `x` and `gt` cannot be compared.
    """
    try:
        if x > gt:
            return x
        raise PydanticKnownError('greater_than', {'gt': _safe_repr(gt)})
    except TypeError:
        raise TypeError(f"Unable to apply constraint 'gt' to supplied value {x}")
281
282
def greater_than_or_equal_validator(x: Any, ge: Any) -> Any:
    """Return `x` unchanged if `x >= ge` is truthy.

    Raises:
        PydanticKnownError: `greater_than_equal` if the comparison is falsy.
        TypeError: if `x` and `ge` cannot be compared.
    """
    try:
        if x >= ge:
            return x
        raise PydanticKnownError('greater_than_equal', {'ge': _safe_repr(ge)})
    except TypeError:
        raise TypeError(f"Unable to apply constraint 'ge' to supplied value {x}")
290
291
def less_than_validator(x: Any, lt: Any) -> Any:
    """Return `x` unchanged if `x < lt` is truthy.

    Raises:
        PydanticKnownError: `less_than` if the comparison is falsy.
        TypeError: if `x` and `lt` cannot be compared.
    """
    try:
        if x < lt:
            return x
        raise PydanticKnownError('less_than', {'lt': _safe_repr(lt)})
    except TypeError:
        raise TypeError(f"Unable to apply constraint 'lt' to supplied value {x}")
299
300
def less_than_or_equal_validator(x: Any, le: Any) -> Any:
    """Return `x` unchanged if `x <= le` is truthy.

    Raises:
        PydanticKnownError: `less_than_equal` if the comparison is falsy.
        TypeError: if `x` and `le` cannot be compared.
    """
    try:
        if x <= le:
            return x
        raise PydanticKnownError('less_than_equal', {'le': _safe_repr(le)})
    except TypeError:
        raise TypeError(f"Unable to apply constraint 'le' to supplied value {x}")
308
309
def multiple_of_validator(x: Any, multiple_of: Any) -> Any:
    """Return `x` unchanged if `x % multiple_of` is falsy (i.e. the remainder is zero).

    Raises:
        PydanticKnownError: `multiple_of` if the remainder is non-zero.
        TypeError: if `%` is not supported between `x` and `multiple_of`.
    """
    try:
        remainder = x % multiple_of
        if not remainder:
            return x
        raise PydanticKnownError('multiple_of', {'multiple_of': _safe_repr(multiple_of)})
    except TypeError:
        raise TypeError(f"Unable to apply constraint 'multiple_of' to supplied value {x}")
317
318
def min_length_validator(x: Any, min_length: Any) -> Any:
    """Return `x` unchanged if `len(x) >= min_length` is truthy.

    Raises:
        PydanticKnownError: `too_short` (with the actual length) if `x` is too short.
        TypeError: if `x` has no length or the lengths cannot be compared.
    """
    try:
        actual_length = len(x)
        if actual_length >= min_length:
            return x
        raise PydanticKnownError(
            'too_short', {'field_type': 'Value', 'min_length': min_length, 'actual_length': actual_length}
        )
    except TypeError:
        raise TypeError(f"Unable to apply constraint 'min_length' to supplied value {x}")
328
329
def max_length_validator(x: Any, max_length: Any) -> Any:
    """Return `x` unchanged unless `len(x) > max_length`.

    Raises:
        PydanticKnownError: `too_long` (with the actual length) if `x` is too long.
        TypeError: if `x` has no length or the lengths cannot be compared.
    """
    try:
        actual_length = len(x)
        if not actual_length > max_length:
            return x
        raise PydanticKnownError(
            'too_long',
            {'field_type': 'Value', 'max_length': max_length, 'actual_length': actual_length},
        )
    except TypeError:
        raise TypeError(f"Unable to apply constraint 'max_length' to supplied value {x}")
340
341
def _extract_decimal_digits_info(decimal: Decimal) -> tuple[int, int]:
    """Compute the number of decimal places and total digits of a [`Decimal`][decimal.Decimal] instance.

    This function handles both normalized and non-normalized Decimal instances (trailing zeros count).
    Example: Decimal('1.230') -> 4 digits, 3 decimal places

    Args:
        decimal (Decimal): The decimal number to analyze.

    Returns:
        tuple[int, int]: A tuple containing the number of decimal places and total digits.

    Raises:
        TypeError: If `decimal` has no usable `as_tuple()`, or is a special value (NaN, sNaN,
            Infinity) whose exponent is not an integer.

    Though this could be divided into two separate functions, the logic is easier to follow if we couple the computation
    of the number of decimals and digits together.
    """
    try:
        decimal_tuple = decimal.as_tuple()
        exponent = decimal_tuple.exponent
        num_digits = len(decimal_tuple.digits)
    except AttributeError:
        raise TypeError(f'Unable to extract decimal digits info from supplied value {decimal}')

    # Special values (NaN, sNaN, Infinity) report a str exponent ('n', 'N', 'F'). This is checked
    # explicitly instead of with `assert`, because assertions are stripped under `python -O`.
    if not isinstance(exponent, int):
        raise TypeError(f'Unable to extract decimal digits info from supplied value {decimal}')

    if exponent >= 0:
        # A non-negative exponent appends that many trailing zeros to the integer part.
        # Ex: digit_tuple=(1, 2, 3), exponent=2 -> 12300 -> 0 decimal places, 5 digits
        num_digits += exponent
        decimal_places = 0
    else:
        # A negative exponent places abs(exponent) digits after the decimal point. If that is more
        # than the number of stored digits, leading zeros are added after the point, so the total
        # digit count becomes abs(exponent) (the zero before the point is not counted).
        # Ex: digit_tuple=(1, 2, 3), exponent=-2 -> 1.23 -> 2 decimal places, 3 digits
        # Ex: digit_tuple=(1, 2, 3), exponent=-4 -> 0.0123 -> 4 decimal places, 4 digits
        decimal_places = abs(exponent)
        num_digits = max(num_digits, decimal_places)

    return decimal_places, num_digits
383
384
def max_digits_validator(x: Any, max_digits: Any) -> Any:
    """Return the decimal `x` unchanged unless it has more than `max_digits` digits.

    `x` is only rejected when both its raw form and its normalized form (trailing zeros removed)
    exceed `max_digits`.

    Raises:
        PydanticKnownError: `decimal_max_digits` if both digit counts exceed `max_digits`.
        TypeError: if `x` is not decimal-like or the counts cannot be compared with `max_digits`.
    """
    try:
        raw_digit_count = _extract_decimal_digits_info(x)[1]
        normalized_digit_count = _extract_decimal_digits_info(x.normalize())[1]
        too_many_digits = (raw_digit_count > max_digits) and (normalized_digit_count > max_digits)
        if too_many_digits:
            raise PydanticKnownError('decimal_max_digits', {'max_digits': max_digits})
        return x
    except TypeError:
        raise TypeError(f"Unable to apply constraint 'max_digits' to supplied value {x}")
397
398
def decimal_places_validator(x: Any, decimal_places: Any) -> Any:
    """Return the decimal `x` unchanged unless it has more than `decimal_places` decimal places.

    The normalized form (trailing zeros removed) is only consulted when the raw form is over the
    limit; `x` is rejected only if both exceed it.

    Raises:
        PydanticKnownError: `decimal_max_places` if both place counts exceed `decimal_places`.
        TypeError: if `x` is not decimal-like or the counts cannot be compared with `decimal_places`.
    """
    try:
        raw_places = _extract_decimal_digits_info(x)[0]
        if not raw_places > decimal_places:
            return x
        normalized_places = _extract_decimal_digits_info(x.normalize())[0]
        if not normalized_places > decimal_places:
            return x
        raise PydanticKnownError('decimal_max_places', {'decimal_places': decimal_places})
    except TypeError:
        raise TypeError(f"Unable to apply constraint 'decimal_places' to supplied value {x}")
412
413
def deque_validator(input_value: Any, handler: core_schema.ValidatorFunctionWrapHandler) -> collections.deque[Any]:
    """Validate `input_value` with `handler` and wrap the result in a `collections.deque`.

    The input's `maxlen` attribute is carried over when present (e.g. an incoming deque);
    otherwise the resulting deque is unbounded.
    """
    validated_items = handler(input_value)
    bound = getattr(input_value, 'maxlen', None)
    return collections.deque(validated_items, maxlen=bound)
416
417
def defaultdict_validator(
    input_value: Any, handler: core_schema.ValidatorFunctionWrapHandler, default_default_factory: Callable[[], Any]
) -> collections.defaultdict[Any, Any]:
    """Validate `input_value` with `handler` and wrap the result in a `collections.defaultdict`.

    An incoming `defaultdict` keeps its own `default_factory`; any other input uses
    `default_default_factory`.
    """
    factory = (
        input_value.default_factory
        if isinstance(input_value, collections.defaultdict)
        else default_default_factory
    )
    return collections.defaultdict(factory, handler(input_value))
426
427
def get_defaultdict_default_default_factory(values_source_type: Any) -> Callable[[], Any]:
    """Determine the default factory for a `defaultdict` whose values are of type `values_source_type`.

    Resolution order:
    1. If `values_source_type` is `Annotated[..., Field(default_factory=...)]`, that factory is used
       (it is assumed to take no arguments).
    2. Otherwise the value type, with any `Annotated[...]` wrapper looked through, is inspected:
       a type variable yields a factory that raises `RuntimeError` when called, and a type from a
       fixed table of common types (e.g. `list` -> `list`, `Sequence` -> `tuple`, `int` -> `int`)
       yields the mapped factory.

    Args:
        values_source_type: The value type annotation of the defaultdict.

    Returns:
        A zero-argument callable to use as the defaultdict's `default_factory`.

    Raises:
        PydanticSchemaGenerationError: If no default factory can be inferred for the value type.
    """
    FieldInfo = import_cached_field_info()

    values_type_origin = get_origin(values_source_type)
    is_annotated = typing_objects.is_annotated(values_type_origin)

    def infer_default() -> Callable[[], Any]:
        allowed_default_types: dict[Any, Any] = {
            tuple: tuple,
            collections.abc.Sequence: tuple,
            collections.abc.MutableSequence: list,
            list: list,
            typing.Sequence: list,
            set: set,
            typing.MutableSet: set,
            collections.abc.MutableSet: set,
            collections.abc.Set: frozenset,
            typing.MutableMapping: dict,
            typing.Mapping: dict,
            collections.abc.Mapping: dict,
            collections.abc.MutableMapping: dict,
            float: float,
            int: int,
            str: str,
            bool: bool,
        }
        # Look through `Annotated[...]` so that metadata-only annotations such as
        # `Annotated[int, Field(gt=0)]` still infer a factory from the wrapped type (`int`),
        # instead of failing because `Annotated` itself is not in the table above.
        inner_type = get_args(values_source_type)[0] if is_annotated else values_source_type
        values_type = get_origin(inner_type) or inner_type
        instructions = 'set using `DefaultDict[..., Annotated[..., Field(default_factory=...)]]`'
        if typing_objects.is_typevar(values_type):

            def type_var_default_factory() -> typing.NoReturn:
                raise RuntimeError(
                    'Generic defaultdict cannot be used without a concrete value type or an'
                    ' explicit default factory, ' + instructions
                )

            return type_var_default_factory
        elif values_type not in allowed_default_types:
            # a somewhat subjective set of types that have reasonable default values;
            # the names are sorted so the message is the same on every run
            allowed_msg = ', '.join(sorted({t.__name__ for t in allowed_default_types.values()}))
            raise PydanticSchemaGenerationError(
                f'Unable to infer a default factory for values of type {values_source_type}.'
                f' Only {allowed_msg} are supported, other types require an explicit default factory'
                ' ' + instructions
            )
        return allowed_default_types[values_type]

    # Assume Annotated[..., Field(...)]
    if is_annotated:
        field_info = next((v for v in get_args(values_source_type) if isinstance(v, FieldInfo)), None)
    else:
        field_info = None
    if field_info and field_info.default_factory:
        # Assume the default factory does not take any argument:
        default_default_factory = cast(Callable[[], Any], field_info.default_factory)
    else:
        default_default_factory = infer_default()
    return default_default_factory
485
486
def validate_str_is_valid_iana_tz(value: Any, /) -> ZoneInfo:
    """Return `value` as a `ZoneInfo`, building one from an IANA key when needed.

    Existing `ZoneInfo` instances are returned unchanged; anything else is passed to `ZoneInfo(...)`.

    Raises:
        PydanticCustomError: `zoneinfo_str` if no time zone can be loaded for `value`.
    """
    if isinstance(value, ZoneInfo):
        return value
    # Besides ZoneInfoNotFoundError, ValueError (malformed key) and TypeError (non-str input),
    # `ZoneInfo` can surface raw OSErrors for some keys on some platforms/Python versions, e.g.
    # IsADirectoryError for a region name such as 'America' when loading from the `tzdata`
    # package, or "file name too long" for very long keys. Report those as invalid timezones too.
    try:
        return ZoneInfo(value)
    except (ZoneInfoNotFoundError, ValueError, TypeError, OSError):
        raise PydanticCustomError('zoneinfo_str', 'invalid timezone: {value}', {'value': value})
494
495
# Constraint name -> Python-side validator function enforcing it (insertion order preserved).
NUMERIC_VALIDATOR_LOOKUP: dict[str, Callable] = dict(
    gt=greater_than_validator,
    ge=greater_than_or_equal_validator,
    lt=less_than_validator,
    le=less_than_or_equal_validator,
    multiple_of=multiple_of_validator,
    min_length=min_length_validator,
    max_length=max_length_validator,
    max_digits=max_digits_validator,
    decimal_places=decimal_places_validator,
)
507
# Union of the six `ipaddress` types that have a dedicated validator in this module.
IpType = Union[IPv4Address, IPv6Address, IPv4Network, IPv6Network, IPv4Interface, IPv6Interface]

# Maps each concrete `ipaddress` class to the `ip_*_validator` function that accepts instances of
# that class and coerces other inputs to it.
IP_VALIDATOR_LOOKUP: dict[type[IpType], Callable] = {
    IPv4Address: ip_v4_address_validator,
    IPv6Address: ip_v6_address_validator,
    IPv4Network: ip_v4_network_validator,
    IPv6Network: ip_v6_network_validator,
    IPv4Interface: ip_v4_interface_validator,
    IPv6Interface: ip_v6_interface_validator,
}
518
# Maps mapping-like type aliases and classes (bare `typing` aliases as well as the
# `collections`/`collections.abc` classes that parametrized aliases resolve to) to the concrete
# runtime class associated with each.
MAPPING_ORIGIN_MAP: dict[Any, Any] = {
    typing.DefaultDict: collections.defaultdict,  # noqa: UP006
    collections.defaultdict: collections.defaultdict,
    typing.OrderedDict: collections.OrderedDict,  # noqa: UP006
    collections.OrderedDict: collections.OrderedDict,
    # NOTE(review): `typing_extensions.OrderedDict` may be the very same object as
    # `typing.OrderedDict`, in which case this entry is a harmless duplicate key.
    typing_extensions.OrderedDict: collections.OrderedDict,
    typing.Counter: collections.Counter,
    collections.Counter: collections.Counter,
    # this doesn't handle subclasses of these
    typing.Mapping: dict,
    typing.MutableMapping: dict,
    # parametrized typing.{Mutable}Mapping creates one of these
    collections.abc.Mapping: dict,
    collections.abc.MutableMapping: dict,
}