Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/pydantic/v1/utils.py: 4%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

364 statements  

1import keyword 

2import warnings 

3import weakref 

4from collections import OrderedDict, defaultdict, deque 

5from copy import deepcopy 

6from itertools import islice, zip_longest 

7from types import BuiltinFunctionType, CodeType, FunctionType, GeneratorType, LambdaType, ModuleType 

8from typing import ( 

9 TYPE_CHECKING, 

10 AbstractSet, 

11 Any, 

12 Callable, 

13 Collection, 

14 Dict, 

15 Generator, 

16 Iterable, 

17 Iterator, 

18 List, 

19 Mapping, 

20 NoReturn, 

21 Optional, 

22 Set, 

23 Tuple, 

24 Type, 

25 TypeVar, 

26 Union, 

27) 

28 

29from typing_extensions import Annotated 

30 

31from pydantic.v1.errors import ConfigError 

32from pydantic.v1.typing import ( 

33 NoneType, 

34 WithArgsTypes, 

35 all_literal_values, 

36 display_as_type, 

37 get_args, 

38 get_origin, 

39 is_literal_type, 

40 is_union, 

41) 

42from pydantic.v1.version import version_info 

43 

44if TYPE_CHECKING: 

45 from inspect import Signature 

46 from pathlib import Path 

47 

48 from pydantic.v1.config import BaseConfig 

49 from pydantic.v1.dataclasses import Dataclass 

50 from pydantic.v1.fields import ModelField 

51 from pydantic.v1.main import BaseModel 

52 from pydantic.v1.typing import AbstractSetIntStr, DictIntStrAny, IntStr, MappingIntStrAny, ReprArgs 

53 

54 RichReprResult = Iterable[Union[Any, Tuple[Any], Tuple[str, Any], Tuple[str, Any, Any]]] 

55 

# Names exported by a star-import of this module.
# Helpers defined below but not listed here (e.g. `truncate`, `unique_list`, `all_identical`)
# are still importable by name.
__all__ = (
    'import_string',
    'sequence_like',
    'validate_field_name',
    'lenient_isinstance',
    'lenient_issubclass',
    'in_ipython',
    'is_valid_identifier',
    'deep_update',
    'update_not_none',
    'almost_equal_floats',
    'get_model',
    'to_camel',
    'to_lower_camel',
    'is_valid_field',
    'smart_deepcopy',
    'PyObjectStr',
    'Representation',
    'GetterDict',
    'ValueItems',
    'version_info',  # required here to match behaviour in v1.3
    'ClassAttribute',
    'path_type',
    'ROOT_KEY',
    'get_unique_discriminator_alias',
    'get_discriminator_alias_and_values',
    'DUNDER_ATTRIBUTES',
)

84 

# Field name looked up in `__fields__` for custom-root models (see `get_discriminator_alias_and_values`);
# it is also the only underscore-prefixed name that `is_valid_field` accepts.
ROOT_KEY = '__root__'
# these are types that are returned unchanged by deepcopy
# (`smart_deepcopy` returns instances of exactly these classes as-is, without copying)
IMMUTABLE_NON_COLLECTIONS_TYPES: Set[Type[Any]] = {
    int,
    float,
    complex,
    str,
    bool,
    bytes,
    type,
    NoneType,
    FunctionType,
    BuiltinFunctionType,
    LambdaType,
    weakref.ref,
    CodeType,
    # note: including ModuleType will differ from behaviour of deepcopy by not producing error.
    # It might be not a good idea in general, but considering that this function used only internally
    # against default values of fields, this will allow to actually have a field with module as default value
    ModuleType,
    NotImplemented.__class__,
    Ellipsis.__class__,
}

108 

# these are types that if empty, might be copied with simple copy() instead of deepcopy()
# (`smart_deepcopy` checks the exact class against this set, so subclasses are not matched)
BUILTIN_COLLECTIONS: Set[Type[Any]] = {
    list,
    set,
    tuple,
    frozenset,
    dict,
    OrderedDict,
    defaultdict,
    deque,
}

120 

121 

def import_string(dotted_path: str) -> Any:
    """
    Resolve a dotted path such as ``'package.module.attr'`` to the object it names.

    Everything before the last dot is imported as a module; the part after the last dot is
    looked up as an attribute of that module. Leading and trailing spaces are ignored.

    :param dotted_path: path of the form ``'<module path>.<attribute>'``
    :return: the attribute found on the imported module
    :raises ImportError: if the path contains no dot, or the module has no such attribute
    """
    from importlib import import_module

    cleaned_path = dotted_path.strip(' ')
    try:
        # unpacking fails with ValueError when there is no dot to split on
        module_name, attribute_name = cleaned_path.rsplit('.', 1)
    except ValueError as e:
        raise ImportError(f'"{dotted_path}" doesn\'t look like a module path') from e

    imported = import_module(module_name)
    try:
        target = getattr(imported, attribute_name)
    except AttributeError as e:
        raise ImportError(f'Module "{module_name}" does not define a "{attribute_name}" attribute') from e
    return target

139 

140 

def truncate(v: Union[str], *, max_len: int = 80) -> str:
    """
    Return the repr of ``v``, shortened with a unicode ellipsis when it would be too long.

    Deprecated: every call emits a ``DeprecationWarning``.

    :param v: value to represent; strings are cut before quoting, anything else after
    :param max_len: maximum length of the returned text
    :return: the (possibly shortened) representation
    """
    warnings.warn('`truncate` is no-longer used by pydantic and is deprecated', DeprecationWarning)
    if isinstance(v, str) and len(v) > (max_len - 2):
        # -3 so quote + string + … + quote has correct length
        shortened = v[: (max_len - 3)] + '…'
        return shortened.__repr__()

    try:
        text = v.__repr__()
    except TypeError:
        # in case v is a type: calling __repr__ on the class itself has no instance to bind to
        text = v.__class__.__repr__(v)

    if len(text) <= max_len:
        return text
    return text[: max_len - 1] + '…'

156 

157 

def sequence_like(v: Any) -> bool:
    """
    Tell whether ``v`` is a list, tuple, set, frozenset, generator or deque (subclasses included).
    """
    accepted_types = (list, tuple, set, frozenset, GeneratorType, deque)
    return isinstance(v, accepted_types)

160 

161 

def validate_field_name(bases: Iterable[Type[Any]], field_name: str) -> None:
    """
    Reject a field name that collides with an attribute already present on one of ``bases``.

    Only attributes whose value is truthy count as a collision; a missing attribute, or one
    holding a falsy value such as ``None``, is ignored.

    :param bases: classes to inspect
    :param field_name: candidate field name
    :raises NameError: if any base has a truthy attribute called ``field_name``
    """
    shadows_existing = any(getattr(base, field_name, None) for base in bases)
    if shadows_existing:
        raise NameError(
            f'Field name "{field_name}" shadows a BaseModel attribute; '
            f'use a different field name with "alias=\'{field_name}\'".'
        )

172 

173 

def lenient_isinstance(o: Any, class_or_tuple: Union[Type[Any], Tuple[Type[Any], ...], None]) -> bool:
    """
    ``isinstance`` that answers ``False`` instead of raising ``TypeError``
    when ``class_or_tuple`` is not something ``isinstance`` accepts (e.g. ``None``).
    """
    try:
        outcome = isinstance(o, class_or_tuple)  # type: ignore[arg-type]
    except TypeError:
        outcome = False
    return outcome

179 

180 

def lenient_issubclass(cls: Any, class_or_tuple: Union[Type[Any], Tuple[Type[Any], ...], None]) -> bool:
    """
    ``issubclass`` that tolerates a first argument which is not a class.

    Anything that is not an instance of ``type`` yields ``False`` without calling ``issubclass``.
    A ``TypeError`` from ``issubclass`` is swallowed (``False``) only when ``cls`` is one of
    ``WithArgsTypes``; otherwise it propagates.
    """
    try:
        if not isinstance(cls, type):
            return False
        return issubclass(cls, class_or_tuple)  # type: ignore[arg-type]
    except TypeError:
        if not isinstance(cls, WithArgsTypes):
            raise  # pragma: no cover
        return False

188 

189 

def in_ipython() -> bool:
    """
    Report whether the name ``__IPYTHON__`` can be resolved, i.e. whether we are running
    in an ipython environment (jupyter notebooks included).
    """
    try:
        eval('__IPYTHON__')
    except NameError:
        return False
    return True  # pragma: no cover

200 

201 

def is_valid_identifier(identifier: str) -> bool:
    """
    Tell whether a string can be used as a Python name: it must be an identifier and must
    not be a reserved keyword.

    :param identifier: The identifier to test.
    :return: True if the identifier is valid.
    """
    if not identifier.isidentifier():
        return False
    return not keyword.iskeyword(identifier)

209 

210 

KeyType = TypeVar('KeyType')


def deep_update(mapping: Dict[KeyType, Any], *updating_mappings: Dict[KeyType, Any]) -> Dict[KeyType, Any]:
    """
    Return a copy of ``mapping`` with each of ``updating_mappings`` applied in order.

    When a key holds a ``dict`` on both sides the two dicts are merged recursively;
    in every other case the updating value replaces the existing one. ``mapping`` itself is
    only shallow-copied, so values that are not touched remain shared with the input.
    """
    merged = mapping.copy()
    for overrides in updating_mappings:
        for key, new_value in overrides.items():
            both_are_dicts = key in merged and isinstance(merged[key], dict) and isinstance(new_value, dict)
            merged[key] = deep_update(merged[key], new_value) if both_are_dicts else new_value
    return merged

223 

224 

def update_not_none(mapping: Dict[Any, Any], **update: Any) -> None:
    """
    Update ``mapping`` in place with the keyword arguments whose value is not ``None``.
    """
    present: Dict[Any, Any] = {}
    for key, value in update.items():
        if value is not None:
            present[key] = value
    mapping.update(present)

227 

228 

def almost_equal_floats(value_1: float, value_2: float, *, delta: float = 1e-8) -> bool:
    """
    Return True if the absolute difference between the two floats is at most ``delta``.
    """
    difference = abs(value_1 - value_2)
    return difference <= delta

234 

235 

def generate_model_signature(
    init: Callable[..., None], fields: Dict[str, 'ModelField'], config: Type['BaseConfig']
) -> 'Signature':
    """
    Generate signature for model based on its fields

    Parameters explicitly declared on ``init`` (other than the first one and any ``**kwargs``)
    are kept as they are. If ``init`` accepts ``**kwargs``, every field not already covered is
    added as a keyword-only parameter named after its alias, or after the field name when the
    alias is not a valid identifier and ``config.allow_population_by_field_name`` is set.

    :param init: the ``__init__`` callable whose parameters seed the signature
    :param fields: mapping of field name to field; ``alias``, ``default``, ``required`` and
        ``annotation`` are read from each field
    :param config: config class; ``allow_population_by_field_name`` and ``extra`` are read
    :return: a ``Signature`` with ``return_annotation=None``
    """
    from inspect import Parameter, Signature, signature

    from pydantic.v1.config import Extra

    present_params = signature(init).parameters.values()
    merged_params: Dict[str, Parameter] = {}
    var_kw = None  # the `**kwargs` parameter of `init`, if it has one
    use_var_kw = False  # whether the final signature must still end with a `**kwargs` parameter

    for param in islice(present_params, 1, None):  # skip self arg
        if param.kind is param.VAR_KEYWORD:
            var_kw = param
            continue
        merged_params[param.name] = param

    if var_kw:  # if custom init has no var_kw, fields which are not declared in it cannot be passed through
        allow_names = config.allow_population_by_field_name
        for field_name, field in fields.items():
            param_name = field.alias
            if field_name in merged_params or param_name in merged_params:
                # already declared explicitly on `init`
                continue
            elif not is_valid_identifier(param_name):
                if allow_names and is_valid_identifier(field_name):
                    param_name = field_name
                else:
                    # the field cannot be spelled as a parameter; it has to travel through **kwargs
                    use_var_kw = True
                    continue

            # TODO: replace annotation with actual expected types once #1055 solved
            kwargs = {'default': field.default} if not field.required else {}
            merged_params[param_name] = Parameter(
                param_name, Parameter.KEYWORD_ONLY, annotation=field.annotation, **kwargs
            )

    if config.extra is Extra.allow:
        use_var_kw = True

    if var_kw and use_var_kw:
        # Make sure the parameter for extra kwargs
        # does not have the same name as a field
        default_model_signature = [
            ('__pydantic_self__', Parameter.POSITIONAL_OR_KEYWORD),
            ('data', Parameter.VAR_KEYWORD),
        ]
        if [(p.name, p.kind) for p in present_params] == default_model_signature:
            # if this is the standard model signature, use extra_data as the extra args name
            var_kw_name = 'extra_data'
        else:
            # else start from var_kw
            var_kw_name = var_kw.name

        # generate a name that's definitely unique
        while var_kw_name in fields:
            var_kw_name += '_'
        merged_params[var_kw_name] = var_kw.replace(name=var_kw_name)

    return Signature(parameters=list(merged_params.values()), return_annotation=None)

299 

300 

def get_model(obj: Union[Type['BaseModel'], Type['Dataclass']]) -> Type['BaseModel']:
    """
    Return the ``BaseModel`` class behind ``obj``.

    If ``obj`` has a ``__pydantic_model__`` attribute that attribute is used, otherwise ``obj`` itself.

    :raises TypeError: if the resulting class is not a subclass of ``BaseModel``
    """
    from pydantic.v1.main import BaseModel

    candidate = getattr(obj, '__pydantic_model__', obj)  # type: ignore
    if issubclass(candidate, BaseModel):
        return candidate
    raise TypeError('Unsupported type, must be either BaseModel or dataclass')

312 

313 

def to_camel(string: str) -> str:
    """
    Convert a ``snake_case`` string to ``PascalCase``.

    The string is split on underscores and each piece is capitalised (first character upper,
    the rest lower) before the pieces are joined back together.
    """
    return ''.join(word.capitalize() for word in string.split('_'))


def to_lower_camel(string: str) -> str:
    """
    Convert a ``snake_case`` string to ``camelCase`` (like `to_camel`, with a lower-case first character).

    Strings that produce no characters at all (``''``, ``'_'``, ``'__'`` ...) yield ``''``.
    """
    pascal_string = to_camel(string)
    # Slicing instead of indexing: `pascal_string` is empty when `string` consists only of
    # underscores, and `pascal_string[0]` would raise IndexError in that case.
    return pascal_string[:1].lower() + pascal_string[1:]

323 

324 

T = TypeVar('T')


def unique_list(
    input_list: Union[List[T], Tuple[T, ...]],
    *,
    name_factory: Callable[[T], str] = str,
) -> List[T]:
    """
    Make a list unique while maintaining order.
    We update the list if another one with the same name is set
    (e.g. root validator overridden in subclass)

    :param input_list: items to de-duplicate
    :param name_factory: maps an item to the name that identifies it; names must be hashable
    :return: one item per distinct name, positioned where that name first appeared and holding
        the last item seen with that name
    """
    # A dict keeps the position of a key's first insertion and lets later assignments replace
    # the value in place, which is exactly the behaviour wanted here, without the linear
    # `in` / `.index()` scans over a list of names.
    latest_by_name: Dict[str, T] = {}
    for v in input_list:
        latest_by_name[name_factory(v)] = v
    return list(latest_by_name.values())

349 

350 

class PyObjectStr(str):
    """
    A ``str`` whose repr is its plain text, with no surrounding quotes. Handy together with
    Representation when the value shown should read like (pseudo-)valid python rather than
    like a quoted string.
    """

    def __repr__(self) -> str:
        plain_text = str(self)
        return plain_text

359 

360 

class Representation:
    """
    Mixin supplying __str__, __repr__, __pretty__ and __rich_repr__, all driven by
    ``__repr_args__`` and ``__repr_name__``. See #884 for more details.

    __pretty__ is used by [devtools](https://python-devtools.helpmanual.io/) to provide human readable representations
    of objects.
    """

    __slots__: Tuple[str, ...] = ()

    def __repr_args__(self) -> 'ReprArgs':
        """
        Returns the attributes to show in __str__, __repr__, and __pretty__ this is generally overridden.

        Can either return:
        * name - value pairs, e.g.: `[('foo_name', 'foo'), ('bar_name', ['b', 'a', 'r'])]`
        * or, just values, e.g.: `[(None, 'foo'), (None, ['b', 'a', 'r'])]`

        The default implementation reads every name in ``__slots__`` and leaves out those
        whose value is ``None``.
        """
        shown = []
        for slot_name in self.__slots__:
            slot_value = getattr(self, slot_name)
            if slot_value is not None:
                shown.append((slot_name, slot_value))
        return shown

    def __repr_name__(self) -> str:
        """
        Name of the instance's class, used in __repr__.
        """
        return type(self).__name__

    def __repr_str__(self, join_str: str) -> str:
        """
        Render the repr args as ``name=value`` (or just ``value`` for unnamed args) joined by ``join_str``.
        """
        rendered = []
        for arg_name, arg_value in self.__repr_args__():
            if arg_name is None:
                rendered.append(repr(arg_value))
            else:
                rendered.append(f'{arg_name}={arg_value!r}')
        return join_str.join(rendered)

    def __pretty__(self, fmt: Callable[[Any], Any], **kwargs: Any) -> Generator[Any, None, None]:
        """
        Used by devtools (https://python-devtools.helpmanual.io/) to provide a human readable representations of objects

        Yields strings interleaved with the integers 1, 0 and -1, in the order shown below.
        """
        yield self.__repr_name__() + '('
        yield 1
        for arg_name, arg_value in self.__repr_args__():
            if arg_name is not None:
                yield arg_name + '='
            yield fmt(arg_value)
            yield ','
            yield 0
        yield -1
        yield ')'

    def __str__(self) -> str:
        return self.__repr_str__(' ')

    def __repr__(self) -> str:
        class_label = self.__repr_name__()
        arguments = self.__repr_str__(', ')
        return f'{class_label}({arguments})'

    def __rich_repr__(self) -> 'RichReprResult':
        """Get fields for Rich library"""
        for arg_name, arg_value in self.__repr_args__():
            if arg_name is not None:
                yield arg_name, arg_value
            else:
                yield arg_value

419 

420 

class GetterDict(Representation):
    """
    Hack to make object's smell just enough like dicts for validate_model.

    Keys are the public attribute names of the wrapped object (names reported by ``dir()`` that
    do not start with an underscore) and values are read with ``getattr``.

    We can't inherit from Mapping[str, Any] because it upsets cython so we have to implement all methods ourselves.
    """

    __slots__ = ('_obj',)

    def __init__(self, obj: Any):
        self._obj = obj

    def __getitem__(self, key: str) -> Any:
        """Attribute lookup with dict semantics: a missing attribute raises ``KeyError``."""
        try:
            found = getattr(self._obj, key)
        except AttributeError as e:
            raise KeyError(key) from e
        return found

    def get(self, key: Any, default: Any = None) -> Any:
        """Return attribute ``key`` of the wrapped object, or ``default`` when it is missing."""
        return getattr(self._obj, key, default)

    def extra_keys(self) -> Set[Any]:
        """
        We don't want to get any other attributes of obj if the model didn't explicitly ask for them
        """
        return set()

    def keys(self) -> List[Any]:
        """
        Keys of the pseudo dictionary, uses a list not set so order information can be maintained like python
        dictionaries.
        """
        return list(self)

    def values(self) -> List[Any]:
        """Values for every key, in key order, looked up through ``__getitem__``."""
        return [self[name] for name in self]

    def items(self) -> Iterator[Tuple[str, Any]]:
        """Lazily yield ``(key, value)`` pairs; values are looked up through ``get``."""
        for name in self:
            yield name, self.get(name)

    def __iter__(self) -> Iterator[str]:
        """Iterate over the names from ``dir()`` of the wrapped object that have no leading underscore."""
        yield from (name for name in dir(self._obj) if not name.startswith('_'))

    def __len__(self) -> int:
        # counted by iterating: `list(self)` would consult `__len__` as a size hint and recurse
        count = 0
        for _ in self:
            count += 1
        return count

    def __contains__(self, item: Any) -> bool:
        return item in self.keys()

    def __eq__(self, other: Any) -> bool:
        # `other` must provide `.items()`
        return dict(self) == dict(other.items())

    def __repr_args__(self) -> 'ReprArgs':
        return [(None, dict(self))]

    def __repr_name__(self) -> str:
        return f'GetterDict[{display_as_type(self._obj)}]'

481 

482 

class ValueItems(Representation):
    """
    Class for more convenient calculation of excluded or included fields on values.

    The specification is held as a mapping from key (or sequence index) to either ``True`` / ``...``
    (meaning "the whole element") or a nested specification for that element.
    """

    __slots__ = ('_items', '_type')

    def __init__(self, value: Any, items: Union['AbstractSetIntStr', 'MappingIntStrAny']) -> None:
        """
        :param value: the value the specification applies to; only its type and length are inspected
        :param items: a set of keys, or a mapping of key to ``True`` / ``...`` / nested specification
        """
        items = self._coerce_items(items)

        if isinstance(value, (list, tuple)):
            # sequences are addressed by index: resolve negative indexes and expand '__all__'
            items = self._normalize_indexes(items, len(value))

        self._items: 'MappingIntStrAny' = items

    def is_excluded(self, item: Any) -> bool:
        """
        Check if item is fully excluded.

        :param item: key or index of a value
        """
        return self.is_true(self._items.get(item))

    def is_included(self, item: Any) -> bool:
        """
        Check if value is contained in self._items

        :param item: key or index of value
        """
        return item in self._items

    def for_element(self, e: 'IntStr') -> Optional[Union['AbstractSetIntStr', 'MappingIntStrAny']]:
        """
        :param e: key or index of element on value
        :return: raw values for element if self._items is dict and contain needed element;
            ``None`` when the element is absent or marked as a whole (``True`` / ``...``)
        """

        item = self._items.get(e)
        return item if not self.is_true(item) else None

    def _normalize_indexes(self, items: 'MappingIntStrAny', v_length: int) -> 'DictIntStrAny':
        """
        :param items: dict or set of indexes which will be normalized
        :param v_length: length of the sequence whose indexes are being normalized
        :return: a new dict keyed by non-negative indexes, with any ``'__all__'`` entry applied
            to every index
        :raises TypeError: if a value is neither a mapping, a set, ``True`` nor ``...``, or if a
            key is neither an ``int`` nor ``'__all__'``

        >>> self._normalize_indexes({0: True, -2: True, -1: True}, 4)
        {0: True, 2: True, 3: True}
        >>> self._normalize_indexes({'__all__': True}, 4)
        {0: Ellipsis, 1: Ellipsis, 2: Ellipsis, 3: Ellipsis}
        """

        normalized_items: 'DictIntStrAny' = {}
        all_items = None
        for i, v in items.items():
            if not (isinstance(v, Mapping) or isinstance(v, AbstractSet) or self.is_true(v)):
                raise TypeError(f'Unexpected type of exclude value for index "{i}" {v.__class__}')
            if i == '__all__':
                all_items = self._coerce_value(v)
                continue
            if not isinstance(i, int):
                raise TypeError(
                    'Excluding fields from a sequence of sub-models or dicts must be performed index-wise: '
                    'expected integer keys or keyword "__all__"'
                )
            normalized_i = v_length + i if i < 0 else i
            # two spellings of the same position (e.g. 3 and -1 for length 4) are merged
            normalized_items[normalized_i] = self.merge(v, normalized_items.get(normalized_i))

        if not all_items:
            return normalized_items
        if self.is_true(all_items):
            # '__all__' selects whole elements: mark every index not mentioned explicitly
            for i in range(v_length):
                normalized_items.setdefault(i, ...)
            return normalized_items
        for i in range(v_length):
            normalized_item = normalized_items.setdefault(i, {})
            if not self.is_true(normalized_item):
                # an index already marked as a whole stays that way; otherwise fold '__all__' in
                normalized_items[i] = self.merge(all_items, normalized_item)
        return normalized_items

    @classmethod
    def merge(cls, base: Any, override: Any, intersect: bool = False) -> Any:
        """
        Merge a ``base`` item with an ``override`` item.

        Both ``base`` and ``override`` are converted to dictionaries if possible.
        Sets are converted to dictionaries with the sets entries as keys and
        Ellipsis as values.

        Each key-value pair existing in ``base`` is merged with ``override``,
        while the rest of the key-value pairs are updated recursively with this function.

        Merging takes place based on the "union" of keys if ``intersect`` is
        set to ``False`` (default) and on the intersection of keys if
        ``intersect`` is set to ``True``.
        """
        override = cls._coerce_value(override)
        base = cls._coerce_value(base)
        if override is None:
            return base
        if cls.is_true(base) or base is None:
            return override
        if cls.is_true(override):
            return base if intersect else override

        # intersection or union of keys while preserving ordering:
        if intersect:
            # keys present on both sides, in the order they appear in `base`; listing them once is
            # enough (the result is a dict, so repeating them in `override` order changes nothing)
            merge_keys = [k for k in base if k in override]
        else:
            merge_keys = list(base) + [k for k in override if k not in base]

        merged: 'DictIntStrAny' = {}
        for k in merge_keys:
            merged_item = cls.merge(base.get(k), override.get(k), intersect=intersect)
            if merged_item is not None:
                merged[k] = merged_item

        return merged

    @staticmethod
    def _coerce_items(items: Union['AbstractSetIntStr', 'MappingIntStrAny']) -> 'MappingIntStrAny':
        """
        Return ``items`` as a mapping: mappings pass through unchanged, sets become a dict whose
        values are all ``...``; anything else is rejected through ``assert_never``.
        """
        if isinstance(items, Mapping):
            pass
        elif isinstance(items, AbstractSet):
            items = dict.fromkeys(items, ...)
        else:
            class_name = getattr(items, '__class__', '???')
            assert_never(
                items,
                f'Unexpected type of exclude value {class_name}',
            )
        return items

    @classmethod
    def _coerce_value(cls, value: Any) -> Any:
        """Like `_coerce_items`, but lets ``None``, ``True`` and ``...`` through untouched."""
        if value is None or cls.is_true(value):
            return value
        return cls._coerce_items(value)

    @staticmethod
    def is_true(v: Any) -> bool:
        """Tell whether ``v`` is one of the two "whole element" markers, ``True`` or ``...``."""
        return v is True or v is ...

    def __repr_args__(self) -> 'ReprArgs':
        return [(None, self._items)]

627 

628 

class ClassAttribute:
    """
    Hide class attribute from its instances

    A descriptor: reading the attribute on the owning class returns ``value``, reading it on an
    instance raises ``AttributeError``.
    """

    __slots__ = (
        'name',
        'value',
    )

    def __init__(self, name: str, value: Any) -> None:
        """
        :param name: attribute name, used only in the error message
        :param value: object returned for class-level access
        """
        self.name = name
        self.value = value

    def __get__(self, instance: Any, owner: Type[Any]) -> Any:
        # annotated `Any` because the stored value is returned for class-level access
        if instance is None:
            return self.value
        raise AttributeError(f'{self.name!r} attribute of {owner.__name__!r} is class-only')

647 

648 

# `Path` predicate method name -> human readable description.
# `path_type` tries them in this order and reports the first one that returns a truthy value.
path_types = {
    'is_dir': 'directory',
    'is_file': 'file',
    'is_mount': 'mount point',
    'is_symlink': 'symlink',
    'is_block_device': 'block device',
    'is_char_device': 'char device',
    'is_fifo': 'FIFO',
    'is_socket': 'socket',
}


def path_type(p: 'Path') -> str:
    """
    Find out what sort of thing a path is.

    :param p: path to inspect; it must exist
    :return: the description of the first matching predicate in ``path_types``, or ``'unknown'``
    :raises AssertionError: if the path does not exist
    """
    if not p.exists():
        # Raised explicitly rather than through an `assert` statement so that the check is
        # still performed when Python runs with -O; the exception type is unchanged.
        raise AssertionError('path does not exist')
    for method, name in path_types.items():
        if getattr(p, method)():
            return name

    return 'unknown'

671 

672 

Obj = TypeVar('Obj')


def smart_deepcopy(obj: Obj) -> Obj:
    """
    Copy ``obj`` as cheaply as its type allows:

    * instances of the types in ``IMMUTABLE_NON_COLLECTIONS_TYPES`` are returned as is
    * empty instances of the types in ``BUILTIN_COLLECTIONS`` are copied with ``obj.copy()``
      (an empty tuple is returned as is)
    * everything else goes through ``copy.deepcopy()``

    Only the exact class of ``obj`` is compared, so subclasses always take the deepcopy path.
    """
    cls = obj.__class__
    if cls in IMMUTABLE_NON_COLLECTIONS_TYPES:
        # fastest case: obj is immutable and not collection therefore will not be copied anyway
        return obj

    try:
        is_empty_builtin = not obj and cls in BUILTIN_COLLECTIONS
        if is_empty_builtin:
            # faster way for empty collections, no need to copy its members
            if cls is tuple:
                return obj  # tuple doesn't have copy method
            return obj.copy()  # type: ignore
    except (TypeError, ValueError, RuntimeError):
        # do we really dare to catch ALL errors? Seems a bit risky
        pass

    # slowest way when we actually might need a deepcopy
    return deepcopy(obj)

695 

696 

def is_valid_field(name: str) -> bool:
    """
    Tell whether ``name`` may be used as a field name: it must not start with an underscore,
    the only exception being ``ROOT_KEY``.
    """
    return not name.startswith('_') or name == ROOT_KEY

701 

702 

# Dunder names that `is_valid_private_name` never accepts as private attribute names.
DUNDER_ATTRIBUTES = {
    '__annotations__',
    '__classcell__',
    '__doc__',
    '__module__',
    '__orig_bases__',
    '__orig_class__',
    '__qualname__',
    '__firstlineno__',
    '__static_attributes__',
}

714 

715 

def is_valid_private_name(name: str) -> bool:
    """
    Tell whether ``name`` may be used as a private attribute name: it must not be a valid
    field name and must not be one of ``DUNDER_ATTRIBUTES``.
    """
    if is_valid_field(name):
        return False
    return name not in DUNDER_ATTRIBUTES

718 

719 

# Private filler for the shorter iterable in `all_identical`; being a fresh object it is never
# identical to a real item, so iterables of different length compare as not identical.
_EMPTY = object()


def all_identical(left: Iterable[Any], right: Iterable[Any]) -> bool:
    """
    Check that the items of `left` are the same objects as those in `right`.

    Items are compared pairwise with ``is``; iterables of different length are never identical.

    >>> a, b = object(), object()
    >>> all_identical([a, b, a], [a, b, a])
    True
    >>> all_identical([a, b, [a]], [a, b, [a]])  # new list object, while "equal" is not "identical"
    False
    """
    pairs = zip_longest(left, right, fillvalue=_EMPTY)
    return all(left_item is right_item for left_item, right_item in pairs)

737 

738 

def assert_never(obj: NoReturn, msg: str) -> NoReturn:
    """
    Always raise ``TypeError(msg)``; call it from a branch that should be unreachable once
    every possible type has been handled.

    This is mostly useful for ``mypy``, docs:
    https://mypy.readthedocs.io/en/latest/literal_types.html#exhaustive-checks
    """
    error = TypeError(msg)
    raise error

747 

748 

def get_unique_discriminator_alias(all_aliases: Collection[str], discriminator_key: str) -> str:
    """
    Return the single alias shared by all entries of ``all_aliases``.

    :raises ConfigError: if ``all_aliases`` contains more than one distinct alias
    """
    distinct_aliases = set(all_aliases)
    if len(distinct_aliases) <= 1:
        return distinct_aliases.pop()
    raise ConfigError(
        f'Aliases for discriminator {discriminator_key!r} must be the same (got {", ".join(sorted(all_aliases))})'
    )

757 

758 

def get_discriminator_alias_and_values(tp: Any, discriminator_key: str) -> Tuple[str, Tuple[str, ...]]:
    """
    Get alias and all valid values in the `Literal` type of the discriminator field
    `tp` can be a `BaseModel` class or directly an `Annotated` `Union` of many.

    :param tp: type to inspect; an `Annotated` wrapper and a `__pydantic_model__` attribute are unwrapped first
    :param discriminator_key: name of the discriminator field
    :return: the alias of the discriminator field and the literal values it accepts
    :raises TypeError: if `tp` has no `__fields__`
    :raises ConfigError: if the discriminator field is missing, is not a `Literal`, or (for a
        custom-root model) does not carry the same values on every member of the root union
    """
    # read before `tp` is unwrapped below
    is_root_model = getattr(tp, '__custom_root_type__', False)

    if get_origin(tp) is Annotated:
        tp = get_args(tp)[0]

    if hasattr(tp, '__pydantic_model__'):
        tp = tp.__pydantic_model__

    if is_union(get_origin(tp)):
        alias, all_values = _get_union_alias_and_all_values(tp, discriminator_key)
        # flatten the per-member value tuples into a single tuple
        return alias, tuple(v for values in all_values for v in values)
    elif is_root_model:
        union_type = tp.__fields__[ROOT_KEY].type_
        alias, all_values = _get_union_alias_and_all_values(union_type, discriminator_key)

        if len(set(all_values)) > 1:
            raise ConfigError(
                f'Field {discriminator_key!r} is not the same for all submodels of {display_as_type(tp)!r}'
            )

        return alias, all_values[0]

    else:
        try:
            t_discriminator_type = tp.__fields__[discriminator_key].type_
        except AttributeError as e:
            raise TypeError(f'Type {tp.__name__!r} is not a valid `BaseModel` or `dataclass`') from e
        except KeyError as e:
            raise ConfigError(f'Model {tp.__name__!r} needs a discriminator field for key {discriminator_key!r}') from e

        if not is_literal_type(t_discriminator_type):
            raise ConfigError(f'Field {discriminator_key!r} of model {tp.__name__!r} needs to be a `Literal`')

        return tp.__fields__[discriminator_key].alias, all_literal_values(t_discriminator_type)

798 

799 

def _get_union_alias_and_all_values(
    union_type: Type[Any], discriminator_key: str
) -> Tuple[str, Tuple[Tuple[str, ...], ...]]:
    """
    Collect the discriminator alias and values of every member of ``union_type``.

    :return: the alias shared by all members, and one tuple of values per member (in member order)
    """
    per_member = [
        get_discriminator_alias_and_values(member, discriminator_key) for member in get_args(union_type)
    ]
    # unzip: [('alias_a',('v1', 'v2)), ('alias_b', ('v3',))] => [('alias_a', 'alias_b'), (('v1', 'v2'), ('v3',))]
    member_aliases, member_values = zip(*per_member)
    shared_alias = get_unique_discriminator_alias(member_aliases, discriminator_key)
    return shared_alias, member_values