1"""Validator functions for standard library types.
2
3Import of this module is deferred since it contains imports of many standard library modules.
4"""
5
6from __future__ import annotations as _annotations
7
8import collections.abc
9import math
10import re
11import typing
12from collections.abc import Sequence
13from decimal import Decimal
14from fractions import Fraction
15from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network
16from typing import Any, Callable, TypeVar, Union, cast
17from zoneinfo import ZoneInfo, ZoneInfoNotFoundError
18
19import typing_extensions
20from pydantic_core import PydanticCustomError, PydanticKnownError, core_schema
21from typing_extensions import get_args, get_origin
22from typing_inspection import typing_objects
23
24from pydantic._internal._import_utils import import_cached_field_info
25from pydantic.errors import PydanticSchemaGenerationError
26
27
28def sequence_validator(
29 input_value: Sequence[Any],
30 /,
31 validator: core_schema.ValidatorFunctionWrapHandler,
32) -> Sequence[Any]:
33 """Validator for `Sequence` types, isinstance(v, Sequence) has already been called."""
34 value_type = type(input_value)
35
36 # We don't accept any plain string as a sequence
37 # Relevant issue: https://github.com/pydantic/pydantic/issues/5595
38 if issubclass(value_type, (str, bytes)):
39 raise PydanticCustomError(
40 'sequence_str',
41 "'{type_name}' instances are not allowed as a Sequence value",
42 {'type_name': value_type.__name__},
43 )
44
45 # TODO: refactor sequence validation to validate with either a list or a tuple
46 # schema, depending on the type of the value.
47 # Additionally, we should be able to remove one of either this validator or the
48 # SequenceValidator in _std_types_schema.py (preferably this one, while porting over some logic).
49 # Effectively, a refactor for sequence validation is needed.
50 if value_type is tuple:
51 input_value = list(input_value)
52
53 v_list = validator(input_value)
54
55 # the rest of the logic is just re-creating the original type from `v_list`
56 if value_type is list:
57 return v_list
58 elif issubclass(value_type, range):
59 # return the list as we probably can't re-create the range
60 return v_list
61 elif value_type is tuple:
62 return tuple(v_list)
63 else:
64 # best guess at how to re-create the original type, more custom construction logic might be required
65 return value_type(v_list) # type: ignore[call-arg]
66
67
68def import_string(value: Any) -> Any:
69 if isinstance(value, str):
70 try:
71 return _import_string_logic(value)
72 except ImportError as e:
73 raise PydanticCustomError('import_error', 'Invalid python path: {error}', {'error': str(e)}) from e
74 else:
75 # otherwise we just return the value and let the next validator do the rest of the work
76 return value
77
78
79def _import_string_logic(dotted_path: str) -> Any:
80 """Inspired by uvicorn — dotted paths should include a colon before the final item if that item is not a module.
81 (This is necessary to distinguish between a submodule and an attribute when there is a conflict.).
82
83 If the dotted path does not include a colon and the final item is not a valid module, importing as an attribute
84 rather than a submodule will be attempted automatically.
85
86 So, for example, the following values of `dotted_path` result in the following returned values:
87 * 'collections': <module 'collections'>
88 * 'collections.abc': <module 'collections.abc'>
89 * 'collections.abc:Mapping': <class 'collections.abc.Mapping'>
90 * `collections.abc.Mapping`: <class 'collections.abc.Mapping'> (though this is a bit slower than the previous line)
91
92 An error will be raised under any of the following scenarios:
93 * `dotted_path` contains more than one colon (e.g., 'collections:abc:Mapping')
94 * the substring of `dotted_path` before the colon is not a valid module in the environment (e.g., '123:Mapping')
95 * the substring of `dotted_path` after the colon is not an attribute of the module (e.g., 'collections:abc123')
96 """
97 from importlib import import_module
98
99 components = dotted_path.strip().split(':')
100 if len(components) > 2:
101 raise ImportError(f"Import strings should have at most one ':'; received {dotted_path!r}")
102 attribute = None
103 if len(components) == 2:
104 attribute = components[1]
105 module_path = components[0]
106 if not module_path:
107 raise ImportError(f'Import strings should have a nonempty module name; received {dotted_path!r}')
108
109 try:
110 module = import_module(module_path)
111 except ModuleNotFoundError:
112 if attribute is None and '.' in module_path:
113 # Try interpreting the final dotted segment as an attribute, not a submodule
114 maybe_module_path, maybe_attribute = module_path.rsplit('.', 1)
115
116 try:
117 return _import_string_logic(f'{maybe_module_path}:{maybe_attribute}')
118 except ImportError:
119 pass
120 raise
121
122 if attribute is not None:
123 try:
124 return getattr(module, attribute)
125 except AttributeError as e:
126 raise ImportError(f'cannot import name {attribute!r} from {module_path!r}') from e
127 else:
128 return module
129
130
131def pattern_either_validator(input_value: Any, /) -> re.Pattern[Any]:
132 if isinstance(input_value, re.Pattern):
133 return input_value
134 elif isinstance(input_value, (str, bytes)):
135 # todo strict mode
136 return compile_pattern(input_value) # type: ignore
137 else:
138 raise PydanticCustomError('pattern_type', 'Input should be a valid pattern')
139
140
141def pattern_str_validator(input_value: Any, /) -> re.Pattern[str]:
142 if isinstance(input_value, re.Pattern):
143 if isinstance(input_value.pattern, str):
144 return input_value
145 else:
146 raise PydanticCustomError('pattern_str_type', 'Input should be a string pattern')
147 elif isinstance(input_value, str):
148 return compile_pattern(input_value)
149 elif isinstance(input_value, bytes):
150 raise PydanticCustomError('pattern_str_type', 'Input should be a string pattern')
151 else:
152 raise PydanticCustomError('pattern_type', 'Input should be a valid pattern')
153
154
155def pattern_bytes_validator(input_value: Any, /) -> re.Pattern[bytes]:
156 if isinstance(input_value, re.Pattern):
157 if isinstance(input_value.pattern, bytes):
158 return input_value
159 else:
160 raise PydanticCustomError('pattern_bytes_type', 'Input should be a bytes pattern')
161 elif isinstance(input_value, bytes):
162 return compile_pattern(input_value)
163 elif isinstance(input_value, str):
164 raise PydanticCustomError('pattern_bytes_type', 'Input should be a bytes pattern')
165 else:
166 raise PydanticCustomError('pattern_type', 'Input should be a valid pattern')
167
168
169PatternType = TypeVar('PatternType', str, bytes)
170
171
172def compile_pattern(pattern: PatternType) -> re.Pattern[PatternType]:
173 try:
174 return re.compile(pattern)
175 except re.error:
176 raise PydanticCustomError('pattern_regex', 'Input should be a valid regular expression')
177
178
179def ip_v4_address_validator(input_value: Any, /) -> IPv4Address:
180 if isinstance(input_value, IPv4Address):
181 return input_value
182
183 try:
184 return IPv4Address(input_value)
185 except ValueError:
186 raise PydanticCustomError('ip_v4_address', 'Input is not a valid IPv4 address')
187
188
189def ip_v6_address_validator(input_value: Any, /) -> IPv6Address:
190 if isinstance(input_value, IPv6Address):
191 return input_value
192
193 try:
194 return IPv6Address(input_value)
195 except ValueError:
196 raise PydanticCustomError('ip_v6_address', 'Input is not a valid IPv6 address')
197
198
199def ip_v4_network_validator(input_value: Any, /) -> IPv4Network:
200 """Assume IPv4Network initialised with a default `strict` argument.
201
202 See more:
203 https://docs.python.org/library/ipaddress.html#ipaddress.IPv4Network
204 """
205 if isinstance(input_value, IPv4Network):
206 return input_value
207
208 try:
209 return IPv4Network(input_value)
210 except ValueError:
211 raise PydanticCustomError('ip_v4_network', 'Input is not a valid IPv4 network')
212
213
214def ip_v6_network_validator(input_value: Any, /) -> IPv6Network:
215 """Assume IPv6Network initialised with a default `strict` argument.
216
217 See more:
218 https://docs.python.org/library/ipaddress.html#ipaddress.IPv6Network
219 """
220 if isinstance(input_value, IPv6Network):
221 return input_value
222
223 try:
224 return IPv6Network(input_value)
225 except ValueError:
226 raise PydanticCustomError('ip_v6_network', 'Input is not a valid IPv6 network')
227
228
229def ip_v4_interface_validator(input_value: Any, /) -> IPv4Interface:
230 if isinstance(input_value, IPv4Interface):
231 return input_value
232
233 try:
234 return IPv4Interface(input_value)
235 except ValueError:
236 raise PydanticCustomError('ip_v4_interface', 'Input is not a valid IPv4 interface')
237
238
239def ip_v6_interface_validator(input_value: Any, /) -> IPv6Interface:
240 if isinstance(input_value, IPv6Interface):
241 return input_value
242
243 try:
244 return IPv6Interface(input_value)
245 except ValueError:
246 raise PydanticCustomError('ip_v6_interface', 'Input is not a valid IPv6 interface')
247
248
249def fraction_validator(input_value: Any, /) -> Fraction:
250 if isinstance(input_value, Fraction):
251 return input_value
252
253 try:
254 return Fraction(input_value)
255 except ValueError:
256 raise PydanticCustomError('fraction_parsing', 'Input is not a valid fraction')
257
258
259def forbid_inf_nan_check(x: Any) -> Any:
260 if not math.isfinite(x):
261 raise PydanticKnownError('finite_number')
262 return x
263
264
265def _safe_repr(v: Any) -> int | float | str:
266 """The context argument for `PydanticKnownError` requires a number or str type, so we do a simple repr() coercion for types like timedelta.
267
268 See tests/test_types.py::test_annotated_metadata_any_order for some context.
269 """
270 if isinstance(v, (int, float, str)):
271 return v
272 return repr(v)
273
274
275def greater_than_validator(x: Any, gt: Any) -> Any:
276 try:
277 if not (x > gt):
278 raise PydanticKnownError('greater_than', {'gt': _safe_repr(gt)})
279 return x
280 except TypeError:
281 raise TypeError(f"Unable to apply constraint 'gt' to supplied value {x}")
282
283
284def greater_than_or_equal_validator(x: Any, ge: Any) -> Any:
285 try:
286 if not (x >= ge):
287 raise PydanticKnownError('greater_than_equal', {'ge': _safe_repr(ge)})
288 return x
289 except TypeError:
290 raise TypeError(f"Unable to apply constraint 'ge' to supplied value {x}")
291
292
293def less_than_validator(x: Any, lt: Any) -> Any:
294 try:
295 if not (x < lt):
296 raise PydanticKnownError('less_than', {'lt': _safe_repr(lt)})
297 return x
298 except TypeError:
299 raise TypeError(f"Unable to apply constraint 'lt' to supplied value {x}")
300
301
302def less_than_or_equal_validator(x: Any, le: Any) -> Any:
303 try:
304 if not (x <= le):
305 raise PydanticKnownError('less_than_equal', {'le': _safe_repr(le)})
306 return x
307 except TypeError:
308 raise TypeError(f"Unable to apply constraint 'le' to supplied value {x}")
309
310
311def multiple_of_validator(x: Any, multiple_of: Any) -> Any:
312 try:
313 if x % multiple_of:
314 raise PydanticKnownError('multiple_of', {'multiple_of': _safe_repr(multiple_of)})
315 return x
316 except TypeError:
317 raise TypeError(f"Unable to apply constraint 'multiple_of' to supplied value {x}")
318
319
320def min_length_validator(x: Any, min_length: Any) -> Any:
321 try:
322 if not (len(x) >= min_length):
323 raise PydanticKnownError(
324 'too_short', {'field_type': 'Value', 'min_length': min_length, 'actual_length': len(x)}
325 )
326 return x
327 except TypeError:
328 raise TypeError(f"Unable to apply constraint 'min_length' to supplied value {x}")
329
330
331def max_length_validator(x: Any, max_length: Any) -> Any:
332 try:
333 if len(x) > max_length:
334 raise PydanticKnownError(
335 'too_long',
336 {'field_type': 'Value', 'max_length': max_length, 'actual_length': len(x)},
337 )
338 return x
339 except TypeError:
340 raise TypeError(f"Unable to apply constraint 'max_length' to supplied value {x}")
341
342
343def _extract_decimal_digits_info(decimal: Decimal) -> tuple[int, int]:
344 """Compute the total number of digits and decimal places for a given [`Decimal`][decimal.Decimal] instance.
345
346 This function handles both normalized and non-normalized Decimal instances.
347 Example: Decimal('1.230') -> 4 digits, 3 decimal places
348
349 Args:
350 decimal (Decimal): The decimal number to analyze.
351
352 Returns:
353 tuple[int, int]: A tuple containing the number of decimal places and total digits.
354
355 Though this could be divided into two separate functions, the logic is easier to follow if we couple the computation
356 of the number of decimals and digits together.
357 """
358 try:
359 decimal_tuple = decimal.as_tuple()
360
361 assert isinstance(decimal_tuple.exponent, int)
362
363 exponent = decimal_tuple.exponent
364 num_digits = len(decimal_tuple.digits)
365
366 if exponent >= 0:
367 # A positive exponent adds that many trailing zeros
368 # Ex: digit_tuple=(1, 2, 3), exponent=2 -> 12300 -> 0 decimal places, 5 digits
369 num_digits += exponent
370 decimal_places = 0
371 else:
372 # If the absolute value of the negative exponent is larger than the
373 # number of digits, then it's the same as the number of digits,
374 # because it'll consume all the digits in digit_tuple and then
375 # add abs(exponent) - len(digit_tuple) leading zeros after the decimal point.
376 # Ex: digit_tuple=(1, 2, 3), exponent=-2 -> 1.23 -> 2 decimal places, 3 digits
377 # Ex: digit_tuple=(1, 2, 3), exponent=-4 -> 0.0123 -> 4 decimal places, 4 digits
378 decimal_places = abs(exponent)
379 num_digits = max(num_digits, decimal_places)
380
381 return decimal_places, num_digits
382 except (AssertionError, AttributeError):
383 raise TypeError(f'Unable to extract decimal digits info from supplied value {decimal}')
384
385
386def max_digits_validator(x: Any, max_digits: Any) -> Any:
387 try:
388 _, num_digits = _extract_decimal_digits_info(x)
389 _, normalized_num_digits = _extract_decimal_digits_info(x.normalize())
390 if (num_digits > max_digits) and (normalized_num_digits > max_digits):
391 raise PydanticKnownError(
392 'decimal_max_digits',
393 {'max_digits': max_digits},
394 )
395 return x
396 except TypeError:
397 raise TypeError(f"Unable to apply constraint 'max_digits' to supplied value {x}")
398
399
400def decimal_places_validator(x: Any, decimal_places: Any) -> Any:
401 try:
402 decimal_places_, _ = _extract_decimal_digits_info(x)
403 if decimal_places_ > decimal_places:
404 normalized_decimal_places, _ = _extract_decimal_digits_info(x.normalize())
405 if normalized_decimal_places > decimal_places:
406 raise PydanticKnownError(
407 'decimal_max_places',
408 {'decimal_places': decimal_places},
409 )
410 return x
411 except TypeError:
412 raise TypeError(f"Unable to apply constraint 'decimal_places' to supplied value {x}")
413
414
415def deque_validator(input_value: Any, handler: core_schema.ValidatorFunctionWrapHandler) -> collections.deque[Any]:
416 return collections.deque(handler(input_value), maxlen=getattr(input_value, 'maxlen', None))
417
418
419def defaultdict_validator(
420 input_value: Any, handler: core_schema.ValidatorFunctionWrapHandler, default_default_factory: Callable[[], Any]
421) -> collections.defaultdict[Any, Any]:
422 if isinstance(input_value, collections.defaultdict):
423 default_factory = input_value.default_factory
424 return collections.defaultdict(default_factory, handler(input_value))
425 else:
426 return collections.defaultdict(default_default_factory, handler(input_value))
427
428
429def get_defaultdict_default_default_factory(values_source_type: Any) -> Callable[[], Any]:
430 FieldInfo = import_cached_field_info()
431
432 values_type_origin = get_origin(values_source_type)
433
434 def infer_default() -> Callable[[], Any]:
435 allowed_default_types: dict[Any, Any] = {
436 tuple: tuple,
437 collections.abc.Sequence: tuple,
438 collections.abc.MutableSequence: list,
439 list: list,
440 typing.Sequence: list,
441 set: set,
442 typing.MutableSet: set,
443 collections.abc.MutableSet: set,
444 collections.abc.Set: frozenset,
445 typing.MutableMapping: dict,
446 typing.Mapping: dict,
447 collections.abc.Mapping: dict,
448 collections.abc.MutableMapping: dict,
449 float: float,
450 int: int,
451 str: str,
452 bool: bool,
453 }
454 values_type = values_type_origin or values_source_type
455 instructions = 'set using `DefaultDict[..., Annotated[..., Field(default_factory=...)]]`'
456 if typing_objects.is_typevar(values_type):
457
458 def type_var_default_factory() -> None:
459 raise RuntimeError(
460 'Generic defaultdict cannot be used without a concrete value type or an'
461 ' explicit default factory, ' + instructions
462 )
463
464 return type_var_default_factory
465 elif values_type not in allowed_default_types:
466 # a somewhat subjective set of types that have reasonable default values
467 allowed_msg = ', '.join([t.__name__ for t in set(allowed_default_types.values())])
468 raise PydanticSchemaGenerationError(
469 f'Unable to infer a default factory for keys of type {values_source_type}.'
470 f' Only {allowed_msg} are supported, other types require an explicit default factory'
471 ' ' + instructions
472 )
473 return allowed_default_types[values_type]
474
475 # Assume Annotated[..., Field(...)]
476 if typing_objects.is_annotated(values_type_origin):
477 field_info = next((v for v in get_args(values_source_type) if isinstance(v, FieldInfo)), None)
478 else:
479 field_info = None
480 if field_info and field_info.default_factory:
481 # Assume the default factory does not take any argument:
482 default_default_factory = cast(Callable[[], Any], field_info.default_factory)
483 else:
484 default_default_factory = infer_default()
485 return default_default_factory
486
487
488def validate_str_is_valid_iana_tz(value: Any, /) -> ZoneInfo:
489 if isinstance(value, ZoneInfo):
490 return value
491 try:
492 return ZoneInfo(value)
493 except (ZoneInfoNotFoundError, ValueError, TypeError):
494 raise PydanticCustomError('zoneinfo_str', 'invalid timezone: {value}', {'value': value})
495
496
497NUMERIC_VALIDATOR_LOOKUP: dict[str, Callable] = {
498 'gt': greater_than_validator,
499 'ge': greater_than_or_equal_validator,
500 'lt': less_than_validator,
501 'le': less_than_or_equal_validator,
502 'multiple_of': multiple_of_validator,
503 'min_length': min_length_validator,
504 'max_length': max_length_validator,
505 'max_digits': max_digits_validator,
506 'decimal_places': decimal_places_validator,
507}
508
509IpType = Union[IPv4Address, IPv6Address, IPv4Network, IPv6Network, IPv4Interface, IPv6Interface]
510
511IP_VALIDATOR_LOOKUP: dict[type[IpType], Callable] = {
512 IPv4Address: ip_v4_address_validator,
513 IPv6Address: ip_v6_address_validator,
514 IPv4Network: ip_v4_network_validator,
515 IPv6Network: ip_v6_network_validator,
516 IPv4Interface: ip_v4_interface_validator,
517 IPv6Interface: ip_v6_interface_validator,
518}
519
520MAPPING_ORIGIN_MAP: dict[Any, Any] = {
521 typing.DefaultDict: collections.defaultdict, # noqa: UP006
522 collections.defaultdict: collections.defaultdict,
523 typing.OrderedDict: collections.OrderedDict, # noqa: UP006
524 collections.OrderedDict: collections.OrderedDict,
525 typing_extensions.OrderedDict: collections.OrderedDict,
526 typing.Counter: collections.Counter,
527 collections.Counter: collections.Counter,
528 # this doesn't handle subclasses of these
529 typing.Mapping: dict,
530 typing.MutableMapping: dict,
531 # parametrized typing.{Mutable}Mapping creates one of these
532 collections.abc.Mapping: dict,
533 collections.abc.MutableMapping: dict,
534}