import keyword
import warnings
import weakref
from collections import OrderedDict, defaultdict, deque
from copy import deepcopy
from itertools import islice, zip_longest
from types import BuiltinFunctionType, CodeType, FunctionType, GeneratorType, LambdaType, ModuleType
from typing import (
    TYPE_CHECKING,
    AbstractSet,
    Any,
    Callable,
    Collection,
    Dict,
    Generator,
    Iterable,
    Iterator,
    List,
    Mapping,
    NoReturn,
    Optional,
    Set,
    Tuple,
    Type,
    TypeVar,
    Union,
)

from typing_extensions import Annotated

from pydantic.v1.errors import ConfigError
from pydantic.v1.typing import (
    NoneType,
    WithArgsTypes,
    all_literal_values,
    display_as_type,
    get_args,
    get_origin,
    is_literal_type,
    is_union,
)
from pydantic.v1.version import version_info

if TYPE_CHECKING:
    from inspect import Signature
    from pathlib import Path

    from pydantic.v1.config import BaseConfig
    from pydantic.v1.dataclasses import Dataclass
    from pydantic.v1.fields import ModelField
    from pydantic.v1.main import BaseModel
    from pydantic.v1.typing import AbstractSetIntStr, DictIntStrAny, IntStr, MappingIntStrAny, ReprArgs

    RichReprResult = Iterable[Union[Any, Tuple[Any], Tuple[str, Any], Tuple[str, Any, Any]]]

__all__ = (
    'import_string',
    'sequence_like',
    'validate_field_name',
    'lenient_isinstance',
    'lenient_issubclass',
    'in_ipython',
    'is_valid_identifier',
    'deep_update',
    'update_not_none',
    'almost_equal_floats',
    'get_model',
    'to_camel',
    'to_lower_camel',
    'is_valid_field',
    'smart_deepcopy',
    'PyObjectStr',
    'Representation',
    'GetterDict',
    'ValueItems',
    'version_info',  # required here to match behaviour in v1.3
    'ClassAttribute',
    'path_type',
    'ROOT_KEY',
    'get_unique_discriminator_alias',
    'get_discriminator_alias_and_values',
    'DUNDER_ATTRIBUTES',
)

ROOT_KEY = '__root__'
# these are types that are returned unchanged by deepcopy
IMMUTABLE_NON_COLLECTIONS_TYPES: Set[Type[Any]] = {
    int,
    float,
    complex,
    str,
    bool,
    bytes,
    type,
    NoneType,
    FunctionType,
    BuiltinFunctionType,
    LambdaType,
    weakref.ref,
    CodeType,
    # note: including ModuleType here differs from the behaviour of deepcopy, which would raise an error.
    # That might not be a good idea in general, but since this set is only used internally against default
    # values of fields, it allows a field to actually have a module as its default value.
    ModuleType,
    NotImplemented.__class__,
    Ellipsis.__class__,
}

# these are types that, if empty, might be copied with a simple copy() instead of a deepcopy()
BUILTIN_COLLECTIONS: Set[Type[Any]] = {
    list,
    set,
    tuple,
    frozenset,
    dict,
    OrderedDict,
    defaultdict,
    deque,
}


def import_string(dotted_path: str) -> Any:
    """
    Stolen approximately from django. Import a dotted module path and return the attribute/class designated by the
    last name in the path. Raise ImportError if the import fails.
    """
    from importlib import import_module

    try:
        module_path, class_name = dotted_path.strip(' ').rsplit('.', 1)
    except ValueError as e:
        raise ImportError(f'"{dotted_path}" doesn\'t look like a module path') from e

    module = import_module(module_path)
    try:
        return getattr(module, class_name)
    except AttributeError as e:
        raise ImportError(f'Module "{module_path}" does not define a "{class_name}" attribute') from e
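
# Illustrative usage (comment only, not part of the original module; examples assume the stdlib `math` module):
#     import_string('math.pi')       -> 3.141592653589793
#     import_string('math.missing')  -> ImportError: Module "math" does not define a "missing" attribute
#     import_string('no_dots_here')  -> ImportError: "no_dots_here" doesn't look like a module path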


def truncate(v: Union[str], *, max_len: int = 80) -> str:
    """
    Truncate a value and add a unicode ellipsis (three dots) to the end if it was too long
    """
    warnings.warn('`truncate` is no-longer used by pydantic and is deprecated', DeprecationWarning)
    if isinstance(v, str) and len(v) > (max_len - 2):
        # -3 so quote + string + … + quote has correct length
        return (v[: (max_len - 3)] + '…').__repr__()
    try:
        v = v.__repr__()
    except TypeError:
        v = v.__class__.__repr__(v)  # in case v is a type
    if len(v) > max_len:
        v = v[: max_len - 1] + '…'
    return v


def sequence_like(v: Any) -> bool:
    return isinstance(v, (list, tuple, set, frozenset, GeneratorType, deque))


def validate_field_name(bases: Iterable[Type[Any]], field_name: str) -> None:
    """
    Ensure that the field's name does not shadow an existing attribute of the model.
    """
    for base in bases:
        if getattr(base, field_name, None):
            raise NameError(
                f'Field name "{field_name}" shadows a BaseModel attribute; '
                f'use a different field name with "alias=\'{field_name}\'".'
            )


def lenient_isinstance(o: Any, class_or_tuple: Union[Type[Any], Tuple[Type[Any], ...], None]) -> bool:
    try:
        return isinstance(o, class_or_tuple)  # type: ignore[arg-type]
    except TypeError:
        return False


def lenient_issubclass(cls: Any, class_or_tuple: Union[Type[Any], Tuple[Type[Any], ...], None]) -> bool:
    try:
        return isinstance(cls, type) and issubclass(cls, class_or_tuple)  # type: ignore[arg-type]
    except TypeError:
        if isinstance(cls, WithArgsTypes):
            return False
        raise  # pragma: no cover


def in_ipython() -> bool:
    """
    Check whether we're in an ipython environment, including jupyter notebooks.
    """
    try:
        eval('__IPYTHON__')
    except NameError:
        return False
    else:  # pragma: no cover
        return True


def is_valid_identifier(identifier: str) -> bool:
    """
    Checks that a string is a valid identifier and not a Python keyword.
    :param identifier: The identifier to test.
    :return: True if the identifier is valid.
    """
    return identifier.isidentifier() and not keyword.iskeyword(identifier)
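
# Illustrative usage (comment only, not part of the original module):
#     is_valid_identifier('user_name')  -> True
#     is_valid_identifier('class')      -> False  (Python keyword)
#     is_valid_identifier('2fast')      -> False  (not an identifier)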


KeyType = TypeVar('KeyType')


def deep_update(mapping: Dict[KeyType, Any], *updating_mappings: Dict[KeyType, Any]) -> Dict[KeyType, Any]:
    updated_mapping = mapping.copy()
    for updating_mapping in updating_mappings:
        for k, v in updating_mapping.items():
            if k in updated_mapping and isinstance(updated_mapping[k], dict) and isinstance(v, dict):
                updated_mapping[k] = deep_update(updated_mapping[k], v)
            else:
                updated_mapping[k] = v
    return updated_mapping
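
# Illustrative usage (comment only, not part of the original module): nested dicts are merged
# recursively, while any other value from a later mapping simply overwrites the earlier one:
#     deep_update({'a': {'x': 1}, 'b': 1}, {'a': {'y': 2}, 'b': 2})
#     -> {'a': {'x': 1, 'y': 2}, 'b': 2}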


def update_not_none(mapping: Dict[Any, Any], **update: Any) -> None:
    mapping.update({k: v for k, v in update.items() if v is not None})


def almost_equal_floats(value_1: float, value_2: float, *, delta: float = 1e-8) -> bool:
    """
    Return True if two floats are almost equal
    """
    return abs(value_1 - value_2) <= delta
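
# Illustrative usage (comment only, not part of the original module): this is a plain
# absolute-difference check against `delta`, so ordinary binary-float noise is tolerated:
#     almost_equal_floats(0.1 + 0.2, 0.3)  -> True   (difference ~5.6e-17 <= 1e-8)
#     almost_equal_floats(1.0, 1.00001)    -> False  (difference 1e-5 > 1e-8)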


def generate_model_signature(
    init: Callable[..., None], fields: Dict[str, 'ModelField'], config: Type['BaseConfig']
) -> 'Signature':
    """
    Generate signature for model based on its fields
    """
    from inspect import Parameter, Signature, signature

    from pydantic.v1.config import Extra

    present_params = signature(init).parameters.values()
    merged_params: Dict[str, Parameter] = {}
    var_kw = None
    use_var_kw = False

    for param in islice(present_params, 1, None):  # skip self arg
        if param.kind is param.VAR_KEYWORD:
            var_kw = param
            continue
        merged_params[param.name] = param

    if var_kw:  # if custom init has no var_kw, fields which are not declared in it cannot be passed through
        allow_names = config.allow_population_by_field_name
        for field_name, field in fields.items():
            param_name = field.alias
            if field_name in merged_params or param_name in merged_params:
                continue
            elif not is_valid_identifier(param_name):
                if allow_names and is_valid_identifier(field_name):
                    param_name = field_name
                else:
                    use_var_kw = True
                    continue

            # TODO: replace annotation with actual expected types once #1055 solved
            kwargs = {'default': field.default} if not field.required else {}
            merged_params[param_name] = Parameter(
                param_name, Parameter.KEYWORD_ONLY, annotation=field.annotation, **kwargs
            )

    if config.extra is Extra.allow:
        use_var_kw = True

    if var_kw and use_var_kw:
        # Make sure the parameter for extra kwargs
        # does not have the same name as a field
        default_model_signature = [
            ('__pydantic_self__', Parameter.POSITIONAL_OR_KEYWORD),
            ('data', Parameter.VAR_KEYWORD),
        ]
        if [(p.name, p.kind) for p in present_params] == default_model_signature:
            # if this is the standard model signature, use extra_data as the extra args name
            var_kw_name = 'extra_data'
        else:
            # else start from var_kw
            var_kw_name = var_kw.name

        # generate a name that's definitely unique
        while var_kw_name in fields:
            var_kw_name += '_'
        merged_params[var_kw_name] = var_kw.replace(name=var_kw_name)

    return Signature(parameters=list(merged_params.values()), return_annotation=None)
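
# Illustrative result (comment only, not part of the original module; `User` is a hypothetical model):
#     class User(BaseModel):
#         id: int
#         name: str = 'Jane'
# the generated signature renders roughly as `(*, id: int, name: str = 'Jane') -> None`, i.e. each field
# becomes a keyword-only parameter and optional fields carry their defaults.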


def get_model(obj: Union[Type['BaseModel'], Type['Dataclass']]) -> Type['BaseModel']:
    from pydantic.v1.main import BaseModel

    try:
        model_cls = obj.__pydantic_model__  # type: ignore
    except AttributeError:
        model_cls = obj

    if not issubclass(model_cls, BaseModel):
        raise TypeError('Unsupported type, must be either BaseModel or dataclass')
    return model_cls


def to_camel(string: str) -> str:
    return ''.join(word.capitalize() for word in string.split('_'))


def to_lower_camel(string: str) -> str:
    if len(string) >= 1:
        pascal_string = to_camel(string)
        return pascal_string[0].lower() + pascal_string[1:]
    return string.lower()
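
# Illustrative usage (comment only, not part of the original module):
#     to_camel('http_error_code')        -> 'HttpErrorCode'
#     to_lower_camel('http_error_code')  -> 'httpErrorCode'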


T = TypeVar('T')


def unique_list(
    input_list: Union[List[T], Tuple[T, ...]],
    *,
    name_factory: Callable[[T], str] = str,
) -> List[T]:
    """
    Make a list unique while maintaining order.
    If another element with the same name is encountered, it replaces the existing entry in place
    (e.g. a root validator overridden in a subclass).
    """
    result: List[T] = []
    result_names: List[str] = []
    for v in input_list:
        v_name = name_factory(v)
        if v_name not in result_names:
            result_names.append(v_name)
            result.append(v)
        else:
            result[result_names.index(v_name)] = v

    return result
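
# Illustrative usage (comment only, not part of the original module): with the default
# `name_factory=str`, a later duplicate replaces the earlier entry at its original position:
#     unique_list(['check_a', 'check_b', 'check_a'])  -> ['check_a', 'check_b']
#     (the surviving 'check_a' is the last one seen, kept at index 0)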


class PyObjectStr(str):
    """
    String class where repr doesn't include quotes. Useful with Representation when you want to return a string
    representation of something that is valid (or pseudo-valid) python.
    """

    def __repr__(self) -> str:
        return str(self)


class Representation:
    """
    Mixin to provide __str__, __repr__, and __pretty__ methods. See #884 for more details.

    __pretty__ is used by [devtools](https://python-devtools.helpmanual.io/) to provide human readable representations
    of objects.
    """

    __slots__: Tuple[str, ...] = tuple()

    def __repr_args__(self) -> 'ReprArgs':
        """
        Returns the attributes to show in __str__, __repr__, and __pretty__; this is generally overridden.

        Can either return:
        * name - value pairs, e.g.: `[('foo_name', 'foo'), ('bar_name', ['b', 'a', 'r'])]`
        * or, just values, e.g.: `[(None, 'foo'), (None, ['b', 'a', 'r'])]`
        """
        attrs = ((s, getattr(self, s)) for s in self.__slots__)
        return [(a, v) for a, v in attrs if v is not None]

    def __repr_name__(self) -> str:
        """
        Name of the instance's class, used in __repr__.
        """
        return self.__class__.__name__

    def __repr_str__(self, join_str: str) -> str:
        return join_str.join(repr(v) if a is None else f'{a}={v!r}' for a, v in self.__repr_args__())

    def __pretty__(self, fmt: Callable[[Any], Any], **kwargs: Any) -> Generator[Any, None, None]:
        """
        Used by devtools (https://python-devtools.helpmanual.io/) to provide a human readable representation of objects
        """
        yield self.__repr_name__() + '('
        yield 1
        for name, value in self.__repr_args__():
            if name is not None:
                yield name + '='
            yield fmt(value)
            yield ','
        yield 0
        yield -1
        yield ')'

    def __str__(self) -> str:
        return self.__repr_str__(' ')

    def __repr__(self) -> str:
        return f'{self.__repr_name__()}({self.__repr_str__(", ")})'

    def __rich_repr__(self) -> 'RichReprResult':
        """Get fields for Rich library"""
        for name, field_repr in self.__repr_args__():
            if name is None:
                yield field_repr
            else:
                yield name, field_repr


class GetterDict(Representation):
    """
    Hack to make objects smell just enough like dicts for validate_model.

    We can't inherit from Mapping[str, Any] because it upsets cython so we have to implement all methods ourselves.
    """

    __slots__ = ('_obj',)

    def __init__(self, obj: Any):
        self._obj = obj

    def __getitem__(self, key: str) -> Any:
        try:
            return getattr(self._obj, key)
        except AttributeError as e:
            raise KeyError(key) from e

    def get(self, key: Any, default: Any = None) -> Any:
        return getattr(self._obj, key, default)

    def extra_keys(self) -> Set[Any]:
        """
        We don't want to get any other attributes of obj if the model didn't explicitly ask for them
        """
        return set()

    def keys(self) -> List[Any]:
        """
        Keys of the pseudo dictionary, uses a list not set so order information can be maintained like python
        dictionaries.
        """
        return list(self)

    def values(self) -> List[Any]:
        return [self[k] for k in self]

    def items(self) -> Iterator[Tuple[str, Any]]:
        for k in self:
            yield k, self.get(k)

    def __iter__(self) -> Iterator[str]:
        for name in dir(self._obj):
            if not name.startswith('_'):
                yield name

    def __len__(self) -> int:
        return sum(1 for _ in self)

    def __contains__(self, item: Any) -> bool:
        return item in self.keys()

    def __eq__(self, other: Any) -> bool:
        return dict(self) == dict(other.items())

    def __repr_args__(self) -> 'ReprArgs':
        return [(None, dict(self))]

    def __repr_name__(self) -> str:
        return f'GetterDict[{display_as_type(self._obj)}]'
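
# Illustrative usage (comment only, not part of the original module; `Point` is a hypothetical class):
#     class Point:
#         def __init__(self, x: int, y: int) -> None:
#             self.x, self.y = x, y
#
#     gd = GetterDict(Point(1, 2))
#     gd['x']          -> 1
#     gd.get('z', 0)   -> 0
#     list(gd)         -> ['x', 'y']   (every non-underscore name reported by dir())
#     gd['missing']    -> KeyError('missing')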


class ValueItems(Representation):
    """
    Class for more convenient calculation of excluded or included fields on values.
    """

    __slots__ = ('_items', '_type')

    def __init__(self, value: Any, items: Union['AbstractSetIntStr', 'MappingIntStrAny']) -> None:
        items = self._coerce_items(items)

        if isinstance(value, (list, tuple)):
            items = self._normalize_indexes(items, len(value))

        self._items: 'MappingIntStrAny' = items

    def is_excluded(self, item: Any) -> bool:
        """
        Check if item is fully excluded.

        :param item: key or index of a value
        """
        return self.is_true(self._items.get(item))

    def is_included(self, item: Any) -> bool:
        """
        Check if value is contained in self._items

        :param item: key or index of value
        """
        return item in self._items

    def for_element(self, e: 'IntStr') -> Optional[Union['AbstractSetIntStr', 'MappingIntStrAny']]:
        """
        :param e: key or index of element on value
        :return: raw values for element if self._items is a dict and contains the needed element
        """

        item = self._items.get(e)
        return item if not self.is_true(item) else None

    def _normalize_indexes(self, items: 'MappingIntStrAny', v_length: int) -> 'DictIntStrAny':
        """
        :param items: dict or set of indexes which will be normalized
        :param v_length: length of the sequence whose indexes will be normalized

        >>> self._normalize_indexes({0: True, -2: True, -1: True}, 4)
        {0: True, 2: True, 3: True}
        >>> self._normalize_indexes({'__all__': True}, 4)
        {0: True, 1: True, 2: True, 3: True}
        """

        normalized_items: 'DictIntStrAny' = {}
        all_items = None
        for i, v in items.items():
            if not (isinstance(v, Mapping) or isinstance(v, AbstractSet) or self.is_true(v)):
                raise TypeError(f'Unexpected type of exclude value for index "{i}" {v.__class__}')
            if i == '__all__':
                all_items = self._coerce_value(v)
                continue
            if not isinstance(i, int):
                raise TypeError(
                    'Excluding fields from a sequence of sub-models or dicts must be performed index-wise: '
                    'expected integer keys or keyword "__all__"'
                )
            normalized_i = v_length + i if i < 0 else i
            normalized_items[normalized_i] = self.merge(v, normalized_items.get(normalized_i))

        if not all_items:
            return normalized_items
        if self.is_true(all_items):
            for i in range(v_length):
                normalized_items.setdefault(i, ...)
            return normalized_items
        for i in range(v_length):
            normalized_item = normalized_items.setdefault(i, {})
            if not self.is_true(normalized_item):
                normalized_items[i] = self.merge(all_items, normalized_item)
        return normalized_items

    @classmethod
    def merge(cls, base: Any, override: Any, intersect: bool = False) -> Any:
        """
        Merge a ``base`` item with an ``override`` item.

        Both ``base`` and ``override`` are converted to dictionaries if possible.
        Sets are converted to dictionaries with the set's entries as keys and
        Ellipsis as values.

        Each key-value pair existing in ``base`` is merged with ``override``,
        while the rest of the key-value pairs are updated recursively with this function.

        Merging takes place based on the "union" of keys if ``intersect`` is
        set to ``False`` (default) and on the intersection of keys if
        ``intersect`` is set to ``True``.
        """
        override = cls._coerce_value(override)
        base = cls._coerce_value(base)
        if override is None:
            return base
        if cls.is_true(base) or base is None:
            return override
        if cls.is_true(override):
            return base if intersect else override

        # intersection or union of keys while preserving ordering:
        if intersect:
            merge_keys = [k for k in base if k in override] + [k for k in override if k in base]
        else:
            merge_keys = list(base) + [k for k in override if k not in base]

        merged: 'DictIntStrAny' = {}
        for k in merge_keys:
            merged_item = cls.merge(base.get(k), override.get(k), intersect=intersect)
            if merged_item is not None:
                merged[k] = merged_item

        return merged
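
    # Illustrative usage (comment only, not part of the original module): sets are coerced to
    # {key: ...} dicts before merging, and `intersect` switches between union and intersection of keys:
    #     ValueItems.merge({'a'}, {'b'})                                  -> {'a': ..., 'b': ...}
    #     ValueItems.merge({'a': ..., 'b': ...}, {'b'}, intersect=True)   -> {'b': ...}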

    @staticmethod
    def _coerce_items(items: Union['AbstractSetIntStr', 'MappingIntStrAny']) -> 'MappingIntStrAny':
        if isinstance(items, Mapping):
            pass
        elif isinstance(items, AbstractSet):
            items = dict.fromkeys(items, ...)
        else:
            class_name = getattr(items, '__class__', '???')
            assert_never(
                items,
                f'Unexpected type of exclude value {class_name}',
            )
        return items

    @classmethod
    def _coerce_value(cls, value: Any) -> Any:
        if value is None or cls.is_true(value):
            return value
        return cls._coerce_items(value)

    @staticmethod
    def is_true(v: Any) -> bool:
        return v is True or v is ...

    def __repr_args__(self) -> 'ReprArgs':
        return [(None, self._items)]


class ClassAttribute:
    """
    Hide class attribute from its instances
    """

    __slots__ = (
        'name',
        'value',
    )

    def __init__(self, name: str, value: Any) -> None:
        self.name = name
        self.value = value

    def __get__(self, instance: Any, owner: Type[Any]) -> None:
        if instance is None:
            return self.value
        raise AttributeError(f'{self.name!r} attribute of {owner.__name__!r} is class-only')
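
# Illustrative usage (comment only, not part of the original module; `Example` is a hypothetical class):
#     class Example:
#         kind = ClassAttribute('kind', 'demo')
#
#     Example.kind    -> 'demo'
#     Example().kind  -> AttributeError: 'kind' attribute of 'Example' is class-only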


path_types = {
    'is_dir': 'directory',
    'is_file': 'file',
    'is_mount': 'mount point',
    'is_symlink': 'symlink',
    'is_block_device': 'block device',
    'is_char_device': 'char device',
    'is_fifo': 'FIFO',
    'is_socket': 'socket',
}


def path_type(p: 'Path') -> str:
    """
    Find out what sort of thing a path is.
    """
    assert p.exists(), 'path does not exist'
    for method, name in path_types.items():
        if getattr(p, method)():
            return name

    return 'unknown'


Obj = TypeVar('Obj')


def smart_deepcopy(obj: Obj) -> Obj:
    """
    Return type as is for immutable built-in types
    Use obj.copy() for built-in empty collections
    Use copy.deepcopy() for non-empty collections and unknown objects
    """

    obj_type = obj.__class__
    if obj_type in IMMUTABLE_NON_COLLECTIONS_TYPES:
        return obj  # fastest case: obj is immutable and not collection therefore will not be copied anyway
    try:
        if not obj and obj_type in BUILTIN_COLLECTIONS:
            # faster way for empty collections, no need to copy its members
            return obj if obj_type is tuple else obj.copy()  # type: ignore  # tuple doesn't have copy method
    except (TypeError, ValueError, RuntimeError):
        # do we really dare to catch ALL errors? Seems a bit risky
        pass

    return deepcopy(obj)  # slowest way when we actually might need a deepcopy
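
# Illustrative behaviour (comment only, not part of the original module):
#     smart_deepcopy(42)          -> the same int object (immutable, returned as is)
#     smart_deepcopy([])          -> a new empty list via list.copy()
#     smart_deepcopy([{'a': 1}])  -> a full copy.deepcopy(), since the collection is non-empty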


def is_valid_field(name: str) -> bool:
    if not name.startswith('_'):
        return True
    return ROOT_KEY == name


DUNDER_ATTRIBUTES = {
    '__annotations__',
    '__classcell__',
    '__doc__',
    '__module__',
    '__orig_bases__',
    '__orig_class__',
    '__qualname__',
    '__firstlineno__',
    '__static_attributes__',
}


def is_valid_private_name(name: str) -> bool:
    return not is_valid_field(name) and name not in DUNDER_ATTRIBUTES


_EMPTY = object()


def all_identical(left: Iterable[Any], right: Iterable[Any]) -> bool:
    """
    Check that the items of `left` are the same objects as those in `right`.

    >>> a, b = object(), object()
    >>> all_identical([a, b, a], [a, b, a])
    True
    >>> all_identical([a, b, [a]], [a, b, [a]])  # new list object, while "equal" is not "identical"
    False
    """
    for left_item, right_item in zip_longest(left, right, fillvalue=_EMPTY):
        if left_item is not right_item:
            return False
    return True


def assert_never(obj: NoReturn, msg: str) -> NoReturn:
    """
    Helper to make sure that we have covered all possible types.

    This is mostly useful for ``mypy``, docs:
    https://mypy.readthedocs.io/en/latest/literal_types.html#exhaustive-checks
    """
    raise TypeError(msg)


def get_unique_discriminator_alias(all_aliases: Collection[str], discriminator_key: str) -> str:
    """Validate that all aliases are the same and if that's the case return the alias"""
    unique_aliases = set(all_aliases)
    if len(unique_aliases) > 1:
        raise ConfigError(
            f'Aliases for discriminator {discriminator_key!r} must be the same (got {", ".join(sorted(all_aliases))})'
        )
    return unique_aliases.pop()
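
# Illustrative usage (comment only, not part of the original module):
#     get_unique_discriminator_alias(['pet_type', 'pet_type'], 'pet_type')  -> 'pet_type'
#     get_unique_discriminator_alias(['pet_type', 'petType'], 'pet_type')   -> ConfigError (aliases differ)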


def get_discriminator_alias_and_values(tp: Any, discriminator_key: str) -> Tuple[str, Tuple[str, ...]]:
    """
    Get alias and all valid values in the `Literal` type of the discriminator field
    `tp` can be a `BaseModel` class or directly an `Annotated` `Union` of many.
    """
    is_root_model = getattr(tp, '__custom_root_type__', False)

    if get_origin(tp) is Annotated:
        tp = get_args(tp)[0]

    if hasattr(tp, '__pydantic_model__'):
        tp = tp.__pydantic_model__

    if is_union(get_origin(tp)):
        alias, all_values = _get_union_alias_and_all_values(tp, discriminator_key)
        return alias, tuple(v for values in all_values for v in values)
    elif is_root_model:
        union_type = tp.__fields__[ROOT_KEY].type_
        alias, all_values = _get_union_alias_and_all_values(union_type, discriminator_key)

        if len(set(all_values)) > 1:
            raise ConfigError(
                f'Field {discriminator_key!r} is not the same for all submodels of {display_as_type(tp)!r}'
            )

        return alias, all_values[0]

    else:
        try:
            t_discriminator_type = tp.__fields__[discriminator_key].type_
        except AttributeError as e:
            raise TypeError(f'Type {tp.__name__!r} is not a valid `BaseModel` or `dataclass`') from e
        except KeyError as e:
            raise ConfigError(f'Model {tp.__name__!r} needs a discriminator field for key {discriminator_key!r}') from e

        if not is_literal_type(t_discriminator_type):
            raise ConfigError(f'Field {discriminator_key!r} of model {tp.__name__!r} needs to be a `Literal`')

        return tp.__fields__[discriminator_key].alias, all_literal_values(t_discriminator_type)


def _get_union_alias_and_all_values(
    union_type: Type[Any], discriminator_key: str
) -> Tuple[str, Tuple[Tuple[str, ...], ...]]:
    zipped_aliases_values = [get_discriminator_alias_and_values(t, discriminator_key) for t in get_args(union_type)]
    # unzip: [('alias_a', ('v1', 'v2')), ('alias_b', ('v3',))] => [('alias_a', 'alias_b'), (('v1', 'v2'), ('v3',))]
    all_aliases, all_values = zip(*zipped_aliases_values)
    return get_unique_discriminator_alias(all_aliases, discriminator_key), all_values