1# util/typing.py
2# Copyright (C) 2022-2025 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: https://www.opensource.org/licenses/mit-license.php
7# mypy: allow-untyped-defs, allow-untyped-calls
8
9from __future__ import annotations
10
11import builtins
12from collections import deque
13import collections.abc as collections_abc
14import re
15import sys
16import typing
17from typing import Any
18from typing import Callable
19from typing import Dict
20from typing import ForwardRef
21from typing import Generic
22from typing import Iterable
23from typing import Mapping
24from typing import NewType
25from typing import NoReturn
26from typing import Optional
27from typing import overload
28from typing import Set
29from typing import Tuple
30from typing import Type
31from typing import TYPE_CHECKING
32from typing import TypeVar
33from typing import Union
34
35import typing_extensions
36
37from . import compat
38
39if True: # zimports removes the tailing comments
40 from typing_extensions import Annotated as Annotated # 3.8
41 from typing_extensions import Concatenate as Concatenate # 3.10
42 from typing_extensions import (
43 dataclass_transform as dataclass_transform, # 3.11,
44 )
45 from typing_extensions import Final as Final # 3.8
46 from typing_extensions import final as final # 3.8
47 from typing_extensions import get_args as get_args # 3.10
48 from typing_extensions import get_origin as get_origin # 3.10
49 from typing_extensions import Literal as Literal # 3.8
50 from typing_extensions import NotRequired as NotRequired # 3.11
51 from typing_extensions import ParamSpec as ParamSpec # 3.10
52 from typing_extensions import Protocol as Protocol # 3.8
53 from typing_extensions import SupportsIndex as SupportsIndex # 3.8
54 from typing_extensions import TypeAlias as TypeAlias # 3.10
55 from typing_extensions import TypedDict as TypedDict # 3.8
56 from typing_extensions import TypeGuard as TypeGuard # 3.10
57 from typing_extensions import Self as Self # 3.11
58 from typing_extensions import TypeAliasType as TypeAliasType # 3.12
59 from typing_extensions import Never as Never # 3.11
60 from typing_extensions import LiteralString as LiteralString # 3.11
61
# Module-level type variables.  _T, _KT and _VT_co are used by the generic
# protocols in this module; the others are presumably re-used by modules
# importing from here -- verify before removing any.
_T = TypeVar("_T", bound=Any)
_KT = TypeVar("_KT")
_KT_co = TypeVar("_KT_co", covariant=True)
_KT_contra = TypeVar("_KT_contra", contravariant=True)
_VT = TypeVar("_VT")
_VT_co = TypeVar("_VT_co", covariant=True)

# ``NoneType`` is importable from ``types`` on py310+; older versions build
# it from ``type(None)``.
if compat.py310:
    # why they took until py310 to put this in stdlib is beyond me,
    # I've been wanting it since py27
    from types import NoneType as NoneType
else:
    NoneType = type(None)  # type: ignore
75
76
def is_fwd_none(typ: Any) -> bool:
    """Return True if ``typ`` is the forward reference ``ForwardRef("None")``.

    Plain strings and the real ``None`` / ``NoneType`` objects are not
    matched by this check.
    """
    if not isinstance(typ, ForwardRef):
        return False
    return typ.__forward_arg__ == "None"
79
80
# Everything the annotation-inspection helpers in this module accept: real
# classes, unevaluated strings, ForwardRef wrappers, NewType objects, PEP 695
# TypeAliasType objects and parametrized generics.  GenericProtocol is
# quoted because it is defined further down in this module.
_AnnotationScanType = Union[
    Type[Any], str, ForwardRef, NewType, TypeAliasType, "GenericProtocol[Any]"
]
84
85
class ArgsTypeProtocol(Protocol):
    """protocol for types that have ``__args__``

    there's no public interface for this AFAIK

    This is the narrowed type returned by the ``is_union()`` TypeGuard in
    this module.

    """

    # the type arguments of the object, e.g. the members of a union
    __args__: Tuple[_AnnotationScanType, ...]
94
95
class GenericProtocol(Protocol[_T]):
    """protocol for generic types.

    this is needed since Python's ``typing._GenericAlias`` is private.
    Describes any object exposing both ``__args__`` and ``__origin__``,
    which is exactly what ``is_generic()`` tests for.

    """

    # the type arguments the generic was parametrized with
    __args__: Tuple[_AnnotationScanType, ...]
    # the un-parametrized origin, e.g. ``list`` for ``list[int]``
    __origin__: Type[_T]

    # Python's builtin _GenericAlias has this method, however builtins like
    # list, dict, etc. do not, even though they have ``__origin__`` and
    # ``__args__``
    #
    # def copy_with(self, params: Tuple[_AnnotationScanType, ...]) -> Type[_T]:
    #     ...
112
113
# copied from TypeShed, required in order to implement
# MutableMapping.update()
class SupportsKeysAndGetItem(Protocol[_KT, _VT_co]):
    """Structural type for mapping-like objects offering ``keys()`` and
    ``__getitem__``."""

    def keys(self) -> Iterable[_KT]: ...

    # the leading double underscore marks ``__k`` as positional-only for
    # type checkers (PEP 484 convention)
    def __getitem__(self, __k: _KT) -> _VT_co: ...
120
121
# work around https://github.com/microsoft/pyright/issues/3025
# (a named alias for ``Literal["*"]``; presumably annotations elsewhere use
# this alias instead of spelling the Literal inline -- see the linked issue)
_LiteralStar = Literal["*"]
124
125
def de_stringify_annotation(
    cls: Type[Any],
    annotation: _AnnotationScanType,
    originating_module: str,
    locals_: Mapping[str, Any],
    *,
    str_cleanup_fn: Optional[Callable[[str, str], str]] = None,
    include_generic: bool = False,
    _already_seen: Optional[Set[Any]] = None,
) -> Type[Any]:
    """Resolve annotations that may be string based into real objects.

    This is particularly important if a module defines "from __future__ import
    annotations", as everything inside of __annotations__ is a string. We want
    to at least have generic containers like ``Mapped``, ``Union``, ``List``,
    etc.

    :param cls: class whose namespace is made available (with lower
     precedence than module globals) when evaluating a string annotation.
    :param annotation: the annotation to resolve; a ``ForwardRef`` is first
     unwrapped to its string.
    :param originating_module: name of the module, looked up in
     ``sys.modules``, whose globals are used for evaluation.
    :param locals_: mapping passed to ``eval()`` as locals.
    :param str_cleanup_fn: optional hook called as
     ``str_cleanup_fn(annotation_string, originating_module)`` that may
     rewrite the string before it is evaluated.
    :param include_generic: when True, the arguments of a resolved generic
     (other than ``Literal``) are resolved recursively and a copy of the
     generic with the resolved arguments is returned.
    :param _already_seen: internal; generics already visited by the
     recursive walk.
    :return: the resolved annotation.
    :raises NameError: if a string annotation can't be evaluated.

    """
    # looked at typing.get_type_hints(), looked at pydantic. We need much
    # less here, and we here try to not use any private typing internals
    # or construct ForwardRef objects which is documented as something
    # that should be avoided.

    # kept so that a revisited generic can be returned in its original form
    original_annotation = annotation

    # a ForwardRef is handled exactly like its string contents
    if is_fwd_ref(annotation):
        annotation = annotation.__forward_arg__

    if isinstance(annotation, str):
        if str_cleanup_fn:
            annotation = str_cleanup_fn(annotation, originating_module)

        annotation = eval_expression(
            annotation, originating_module, locals_=locals_, in_class=cls
        )

    # optionally descend into generic arguments; Literal[] is excluded as its
    # arguments are literal values rather than annotations
    if (
        include_generic
        and is_generic(annotation)
        and not is_literal(annotation)
    ):
        if _already_seen is None:
            _already_seen = set()

        if annotation in _already_seen:
            # only occurs recursively. outermost return type
            # will always be Type.
            # the element here will be either ForwardRef or
            # Optional[ForwardRef]
            # NOTE(review): entries are never removed from _already_seen, so
            # an identical generic appearing twice as *siblings* (not only
            # via true recursion) is also returned unresolved on its second
            # occurrence -- confirm this is intended.
            return original_annotation  # type: ignore
        else:
            _already_seen.add(annotation)

        elements = tuple(
            de_stringify_annotation(
                cls,
                elem,
                originating_module,
                locals_,
                str_cleanup_fn=str_cleanup_fn,
                include_generic=include_generic,
                _already_seen=_already_seen,
            )
            for elem in annotation.__args__
        )

        return _copy_generic_annotation_with(annotation, elements)

    return annotation  # type: ignore
195
196
def fixup_container_fwd_refs(
    type_: _AnnotationScanType,
) -> _AnnotationScanType:
    """Correct dict['x', 'y'] into dict[ForwardRef('x'), ForwardRef('y')]
    and similar for list, set

    Only parametrizations whose origin is one of ``dict``, ``set``,
    ``list`` or the listed ``collections.abc`` containers are rewritten.
    ``typing.Dict`` / ``typing.List`` style aliases (recognized by their
    ``repr()``) are returned unchanged, as is anything else.

    :param type_: annotation to inspect.
    :return: a re-parametrized generic in which every ``str`` argument is
     wrapped in ``ForwardRef``, or ``type_`` itself.

    """
    if not is_generic(type_):
        return type_

    origin = get_origin(type_)
    if origin not in (
        dict,
        set,
        list,
        collections_abc.MutableSet,
        collections_abc.MutableMapping,
        collections_abc.MutableSequence,
        collections_abc.Mapping,
        collections_abc.Sequence,
    ):
        return type_

    # fight, kick and scream to struggle to tell the difference between
    # dict[] and typing.Dict[] which DO NOT compare the same and DO NOT
    # behave the same yet there is NO WAY to distinguish between which type
    # it is using public attributes.  The "." is escaped so only a literal
    # "typing." prefix matches.
    if re.match(
        r"typing\.(?:Dict|List|Set|.*Mapping|.*Sequence|.*Set)", repr(type_)
    ):
        return type_

    # compat with py3.10 and earlier
    return origin.__class_getitem__(  # type: ignore
        tuple(
            ForwardRef(elem) if isinstance(elem, str) else elem
            for elem in get_args(type_)
        )
    )
236
237
def _copy_generic_annotation_with(
    annotation: GenericProtocol[_T], elements: Tuple[_AnnotationScanType, ...]
) -> Type[_T]:
    """Return ``annotation`` re-parametrized with ``elements``.

    ``typing`` aliases (``List``, ``Dict``, ...) offer ``copy_with()``;
    builtin parametrizations such as ``list[int]`` don't, so for those the
    ``__origin__`` is subscripted directly.
    """
    if not hasattr(annotation, "copy_with"):
        # Python builtins list, dict, etc.
        origin = annotation.__origin__
        return origin[elements]  # type: ignore

    # List, Dict, etc. real generics
    return annotation.copy_with(elements)  # type: ignore
247
248
def eval_expression(
    expression: str,
    module_name: str,
    *,
    locals_: Optional[Mapping[str, Any]] = None,
    in_class: Optional[Type[Any]] = None,
) -> Any:
    """Evaluate ``expression`` against the globals of module ``module_name``.

    :param expression: Python expression text (typically an annotation
     string).  It is passed to ``eval()``, so only trusted source text
     should reach here.
    :param module_name: key into ``sys.modules`` whose ``__dict__`` is used
     as the global namespace.
    :param locals_: optional mapping handed to ``eval()`` as locals.
    :param in_class: when given, a copy of the class ``__dict__`` (plus the
     class under its own name) is also visible; module globals override
     entries from the class namespace.
    :return: the evaluated object.
    :raises NameError: if the module isn't in ``sys.modules``, or if
     evaluation raises any exception (chained as ``__cause__``).
    """
    try:
        module_globals: Dict[str, Any] = sys.modules[module_name].__dict__
    except KeyError as missing:
        raise NameError(
            f"Module {module_name} isn't present in sys.modules; can't "
            f"evaluate expression {expression}"
        ) from missing

    try:
        if in_class is None:
            namespace = module_globals
        else:
            # work on a copy so the class itself is not modified
            namespace = dict(in_class.__dict__)
            namespace.setdefault(in_class.__name__, in_class)

            # see #10899. We want the locals/globals to take precedence
            # over the class namespace in this context, even though this
            # is not the usual way variables would resolve.
            namespace.update(module_globals)

        return eval(expression, namespace, locals_)
    except Exception as err:
        raise NameError(
            f"Could not de-stringify annotation {expression!r}"
        ) from err
283
284
def eval_name_only(
    name: str,
    module_name: str,
    *,
    locals_: Optional[Mapping[str, Any]] = None,
) -> Any:
    """Resolve a single name against a module's globals, then builtins.

    A name containing ``.`` is delegated to ``eval_expression()``; that is
    the only path on which ``locals_`` is consulted.  A plain name is looked
    up directly in the module ``__dict__`` and then in ``builtins``, avoiding
    ``eval()`` for speed.

    :raises NameError: if the module isn't in ``sys.modules`` or the name
     can't be found.
    """
    if "." in name:
        return eval_expression(name, module_name, locals_=locals_)

    try:
        module_ns: Dict[str, Any] = sys.modules[module_name].__dict__
    except KeyError as no_module:
        raise NameError(
            f"Module {module_name} isn't present in sys.modules; can't "
            f"resolve name {name}"
        ) from no_module

    # name only, just look in globals. eval() works perfectly fine here,
    # however we are seeking to have this be faster, as this occurs for
    # every Mapper[] keyword, etc. depending on configuration
    try:
        return module_ns[name]
    except KeyError as not_global:
        # check in builtins as well to handle `list`, `set` or `dict`, etc.
        builtin_ns = builtins.__dict__
        if name in builtin_ns:
            return builtin_ns[name]
        raise NameError(
            f"Could not locate name {name} in module {module_name}"
        ) from not_global
317
318
def resolve_name_to_real_class_name(name: str, module_name: str) -> str:
    """Return the ``__name__`` of the object ``name`` refers to in module
    ``module_name``.

    ``name`` itself is returned when it can't be resolved (``NameError``)
    or when the resolved object has no ``__name__``.
    """
    try:
        target = eval_name_only(name, module_name)
    except NameError:
        return name
    return getattr(target, "__name__", name)
326
327
def is_pep593(type_: Optional[Any]) -> bool:
    """Return True if ``type_`` is a PEP 593 ``Annotated[...]`` form from
    either ``typing`` or ``typing_extensions``; ``None`` gives False."""
    if type_ is None:
        return False
    return get_origin(type_) in _type_tuples.Annotated
330
331
def is_non_string_iterable(obj: Any) -> TypeGuard[Iterable[Any]]:
    """Return True for iterables other than ``str`` and ``bytes``.

    Text and byte strings are iterable, but are deliberately treated as
    scalar values here.
    """
    if isinstance(obj, (str, bytes)):
        return False
    return isinstance(obj, collections_abc.Iterable)
336
337
def is_literal(type_: Any) -> bool:
    """Return True if ``type_`` is a ``Literal[...]`` form from either
    ``typing`` or ``typing_extensions``."""
    origin = get_origin(type_)
    return origin in _type_tuples.Literal
340
341
def is_newtype(type_: Optional[_AnnotationScanType]) -> TypeGuard[NewType]:
    """Return True if ``type_`` looks like a ``typing.NewType``.

    Detection is by the presence of a ``__supertype__`` attribute.  An
    ``isinstance(type_, type_instances.NewType)`` check doesn't work on
    3.8 / 3.7, because there ``NewType()`` returns a closure rather than an
    object instance.
    """
    missing = object()
    return getattr(type_, "__supertype__", missing) is not missing
348
349
def is_generic(type_: _AnnotationScanType) -> TypeGuard[GenericProtocol[Any]]:
    """Return True if ``type_`` carries both ``__args__`` and ``__origin__``.

    This covers ``typing`` aliases such as ``List[int]`` as well as builtin
    parametrizations such as ``list[int]``.
    """
    return all(
        hasattr(type_, attr_name) for attr_name in ("__args__", "__origin__")
    )
352
353
def is_pep695(type_: _AnnotationScanType) -> TypeGuard[TypeAliasType]:
    """Return True if ``type_`` is a PEP 695 ``TypeAliasType``.

    A parametrized generic also qualifies when its ``__origin__`` (checked
    recursively) is a ``TypeAliasType``.
    """
    if not is_generic(type_):
        return isinstance(type_, _type_instances.TypeAliasType)

    # NOTE: a generic TAT does not instance check as TypeAliasType outside of
    # python 3.10. For sqlalchemy use cases it's fine to consider it a TAT
    # though.
    # NOTE: things seem to work also without this additional check
    return is_pep695(type_.__origin__)
362
363
def flatten_newtype(type_: NewType) -> Type[Any]:
    """Follow a chain of ``NewType`` definitions down to the underlying type.

    e.g. ``NewType("B", NewType("A", int))`` flattens to ``int``.
    """
    base = type_.__supertype__
    # keep unwrapping while the current base is itself a NewType (the same
    # ``__supertype__`` test that is_newtype() performs)
    while hasattr(base, "__supertype__"):
        base = base.__supertype__
    return base  # type: ignore[return-value]
369
370
def pep695_values(type_: _AnnotationScanType) -> Set[Any]:
    """Extracts the value from a TypeAliasType, recursively exploring unions
    and inner TypeAliasType to flatten them into a single set.

    Forward references are not evaluated, so no recursive exploration happens
    into them.

    Details of the result:

    * if ``type_`` is not a TypeAliasType, the result is ``{type_}``.
    * if the alias value is not a union, the result is ``{value}``; the
      value is not unwrapped further, even if it is itself a TypeAliasType.
    * if the alias value is a union, its members are flattened into the set
      (members that are TypeAliasTypes are expanded the same way), and
      ``NoneType`` / ``ForwardRef("None")`` members are stored as ``None``.
    """
    # every type visited so far, shared by all branches of the walk; a type
    # seen a second time is returned as-is without being expanded again
    _seen = set()

    def recursive_value(inner_type):
        # returns a single type, or a (possibly nested) list of types when a
        # union value was expanded
        if inner_type in _seen:
            # recursion are not supported (at least it's flagged as
            # an error by pyright). Just avoid infinite loop
            return inner_type
        _seen.add(inner_type)
        if not is_pep695(inner_type):
            return inner_type
        value = inner_type.__value__
        if not is_union(value):
            return value
        return [recursive_value(t) for t in value.__args__]

    res = recursive_value(type_)
    if isinstance(res, list):
        # flatten the nested lists, normalizing every None spelling to None
        types = set()
        stack = deque(res)
        while stack:
            t = stack.popleft()
            if isinstance(t, list):
                stack.extend(t)
            else:
                types.add(None if t is NoneType or is_fwd_none(t) else t)
        return types
    else:
        return {res}
406
407
def is_fwd_ref(
    type_: _AnnotationScanType,
    check_generic: bool = False,
    check_for_plain_string: bool = False,
) -> TypeGuard[ForwardRef]:
    """Return True if ``type_`` is a forward reference.

    :param check_generic: also return True for a generic whose arguments
     (searched recursively) contain a forward reference.
    :param check_for_plain_string: also treat a plain ``str`` as a forward
     reference, including inside generics when ``check_generic`` is set.
    """
    if check_for_plain_string and isinstance(type_, str):
        return True
    if isinstance(type_, _type_instances.ForwardRef):
        return True
    if not check_generic or not is_generic(type_):
        return False
    for arg in type_.__args__:
        if is_fwd_ref(
            arg, True, check_for_plain_string=check_for_plain_string
        ):
            return True
    return False
426
427
@overload
def de_optionalize_union_types(type_: str) -> str: ...


@overload
def de_optionalize_union_types(type_: Type[Any]) -> Type[Any]: ...


@overload
def de_optionalize_union_types(
    type_: _AnnotationScanType,
) -> _AnnotationScanType: ...


def de_optionalize_union_types(
    type_: _AnnotationScanType,
) -> _AnnotationScanType:
    """Given a type, filter out ``Union`` types that include ``NoneType``
    to not include the ``NoneType``.

    Contains extra logic to work on non-flattened unions, unions that contain
    ``None`` (seen in py38, 37)

    ``ForwardRef`` objects are handled textually by
    ``_de_optionalize_fwd_ref_union_types()``.  For real unions, ``None``,
    ``NoneType`` and ``ForwardRef("None")`` members are removed and the
    remaining members keep their original order.  If nothing but ``None``
    members remain, ``Never`` is returned.  Anything that is not a union
    including ``None`` is returned unchanged.

    """

    if is_fwd_ref(type_):
        return _de_optionalize_fwd_ref_union_types(type_, False)

    if not (is_union(type_) and includes_none(type_)):
        return type_

    if compat.py39:
        members = list(type_.__args__)
    else:
        # py38, 37 - unions are not automatically flattened, can contain
        # None rather than NoneType
        members = []
        stack_of_unions = deque([type_])
        while stack_of_unions:
            u_typ = stack_of_unions.popleft()
            for elem in u_typ.__args__:
                if is_union(elem):
                    stack_of_unions.append(elem)
                else:
                    members.append(elem)

    # dict.fromkeys() removes duplicates while keeping declaration order, so
    # the resulting Union is built deterministically (a set was not ordered)
    non_none = [
        t
        for t in dict.fromkeys(members)
        if t is not None and t is not NoneType and not is_fwd_none(t)
    ]

    if not non_none:
        # every member was a spelling of None; Union[()] is rejected by
        # typing, so mirror _de_optionalize_fwd_ref_union_types() instead
        return Never  # type: ignore[return-value]

    return make_union_type(*non_none)
480
481
def _split_top_level(text: str, sep: str) -> list[str]:
    """Split ``text`` on ``sep`` characters that are not nested inside
    ``[...]``; pieces are whitespace-stripped and empty pieces dropped."""
    pieces: list[str] = []
    current: list[str] = []
    depth = 0
    for char in text:
        if char == "[":
            depth += 1
        elif char == "]":
            depth -= 1
        elif depth == 0 and char == sep:
            pieces.append("".join(current).strip())
            current.clear()
            continue
        current.append(char)
    # the final piece has no trailing separator; it must not be lost
    pieces.append("".join(current).strip())
    return [piece for piece in pieces if piece]


@overload
def _de_optionalize_fwd_ref_union_types(
    type_: ForwardRef, return_has_none: Literal[True]
) -> bool: ...


@overload
def _de_optionalize_fwd_ref_union_types(
    type_: ForwardRef, return_has_none: Literal[False]
) -> _AnnotationScanType: ...


def _de_optionalize_fwd_ref_union_types(
    type_: ForwardRef, return_has_none: bool
) -> Union[_AnnotationScanType, bool]:
    """return the non-optional type for Optional[], Union[None, ...], x|None,
    etc. without de-stringifying forward refs.

    unfortunately this seems to require lots of hardcoded heuristics

    :param type_: the forward reference to inspect textually.
    :param return_has_none: when True, return only whether the annotation
     includes ``None`` at its top level.
    :return: a bool when ``return_has_none`` is set; otherwise a
     ``ForwardRef`` / ``Union`` of ``ForwardRef`` without the ``None``
     member, ``Never`` if only ``None`` was present, or ``type_`` unchanged
     when it isn't optional.

    """

    annotation = type_.__forward_arg__

    # split on "|" only outside of brackets, so "Dict[str, int | None]" is a
    # single token while "None | List[int]" is two
    pipe_tokens = _split_top_level(annotation, "|")

    if len(pipe_tokens) == 1:
        mm = re.match(r"^(.+?)\[(.+)\]$", pipe_tokens[0])
        if mm:
            g1 = mm.group(1).split(".")[-1].strip()
            if g1 == "Optional":
                if return_has_none:
                    return True
                return ForwardRef(mm.group(2).strip())
            elif g1 == "Union":
                # cases like "Union[Dict[str, int], int, None]": commas
                # nested inside brackets don't separate union members
                elements = _split_top_level(mm.group(2), ",")
                parts = [
                    ForwardRef(elem) for elem in elements if elem != "None"
                ]
                if return_has_none:
                    return len(elements) != len(parts)
                return make_union_type(*parts) if parts else Never  # type: ignore[return-value] # noqa: E501
            else:
                return False if return_has_none else type_

    has_none = "None" in pipe_tokens
    if return_has_none:
        return has_none
    if has_none:
        anno_str = "|".join(p for p in pipe_tokens if p != "None")
        return ForwardRef(anno_str) if anno_str else Never  # type: ignore[return-value] # noqa: E501

    return type_
546
547
def make_union_type(*types: _AnnotationScanType) -> Type[Any]:
    """Build ``Union[t1, t2, ...]`` from the given members.

    Any simplification of the members (e.g. a single member) is left to
    ``typing.Union`` itself.
    """
    members = tuple(types)
    union_type = Union[members]
    return union_type  # type: ignore
552
553
def includes_none(type_: Any) -> bool:
    """Returns if the type annotation ``type_`` allows ``None``.

    This function supports:
    * forward refs
    * unions
    * pep593 - Annotated
    * pep695 - TypeAliasType (does not support looking into
    fw reference of other pep695)
    * NewType
    * plain types like ``int``, ``None``, etc
    """
    if is_fwd_ref(type_):
        return _de_optionalize_fwd_ref_union_types(type_, True)
    if is_union(type_):
        return any(map(includes_none, get_args(type_)))
    if is_pep593(type_):
        # only the annotated type counts, not the metadata after it
        annotated_type = get_args(type_)[0]
        return includes_none(annotated_type)
    if is_pep695(type_):
        return any(map(includes_none, pep695_values(type_)))
    if is_newtype(type_):
        return includes_none(type_.__supertype__)

    try:
        return type_ in (NoneType, None) or is_fwd_none(type_)
    except TypeError:
        # if type_ is Column, mapped_column(), etc. the use of "in"
        # resolves to ``__eq__()`` which then gives us an expression object
        # that can't resolve to boolean. just catch it all via exception
        return False
583
584
def is_a_type(type_: Any) -> bool:
    """Return True if ``type_`` looks like a type or a typing construct.

    Accepted are classes, anything with an ``__origin__`` (parametrized
    generics), and objects that are defined in, or whose class is defined
    in, ``typing`` / ``typing_extensions``.

    Plain values such as ``5``, ``"x"`` or ``None`` give False instead of
    raising ``AttributeError``.
    """
    typing_modules = ("typing", "typing_extensions")
    return (
        isinstance(type_, type)
        or hasattr(type_, "__origin__")
        # instances of builtin classes (int, str, NoneType, ...) have no
        # __module__ attribute at all, so don't access it directly
        or getattr(type_, "__module__", None) in typing_modules
        # type(x).__mro__[0] is always type(x) itself
        or type(type_).__module__ in typing_modules
    )
592
593
def is_union(type_: Any) -> TypeGuard[ArgsTypeProtocol]:
    """Return True if the origin of ``type_`` is named ``Union``
    (``typing.Union[...]``) or ``UnionType`` (``X | Y`` syntax)."""
    union_origin_names = ("Union", "UnionType")
    return is_origin_of(type_, *union_origin_names)
596
597
def is_origin_of_cls(
    type_: Any, class_obj: Union[Tuple[Type[Any], ...], Type[Any]]
) -> bool:
    """return True if the given type has an __origin__ that shares a base
    with the given class

    The origin must be an actual class and a subclass of ``class_obj`` (or of
    any class in it, when a tuple is given); types without an origin give
    False.
    """
    origin = get_origin(type_)
    if origin is not None and isinstance(origin, type):
        return issubclass(origin, class_obj)
    return False
609
610
def is_origin_of(
    type_: Any, *names: str, module: Optional[str] = None
) -> bool:
    """return True if the given type has an __origin__ with the given name
    and optional module.

    :param names: acceptable names for the origin.
    :param module: if given, the origin's ``__module__`` must start with it.
    """
    origin = get_origin(type_)
    if origin is None:
        return False
    if _get_type_name(origin) not in names:
        return False
    return module is None or origin.__module__.startswith(module)
624
625
def _get_type_name(type_: Type[Any]) -> str:
    """Return the name of ``type_``.

    On py310+ ``__name__`` is read directly.  On older versions
    ``__name__`` is tried first, then the ``_name`` attribute (presumably
    for typing objects that lack ``__name__`` there), and ``None`` results
    if neither is present.
    """
    if not compat.py310:
        name = getattr(type_, "__name__", None)
        if name is None:
            name = getattr(type_, "_name", None)
        return name  # type: ignore
    return type_.__name__
635
636
class DescriptorProto(Protocol):
    """Structural type for an object defining ``__get__``, ``__set__`` and
    ``__delete__``; used as the bound of the descriptor TypeVars in this
    module."""

    def __get__(self, instance: object, owner: Any) -> Any: ...

    def __set__(self, instance: Any, value: Any) -> None: ...

    def __delete__(self, instance: Any) -> None: ...
643
644
# the descriptor type referred to by a DescriptorReference
_DESC = TypeVar("_DESC", bound=DescriptorProto)


class DescriptorReference(Generic[_DESC]):
    """a descriptor that refers to a descriptor.

    used for cases where we need to have an instance variable referring to an
    object that is itself a descriptor, which typically confuses typing tools
    as they don't know when they should use ``__get__`` or not when referring
    to the descriptor assignment as an instance variable. See
    sqlalchemy.orm.interfaces.PropComparator.prop

    """

    # the descriptor methods below exist only for static type checkers; at
    # runtime (TYPE_CHECKING is False) this class defines none of them
    if TYPE_CHECKING:

        def __get__(self, instance: object, owner: Any) -> _DESC: ...

        def __set__(self, instance: Any, value: _DESC) -> None: ...

        def __delete__(self, instance: Any) -> None: ...
666
667
# covariant so a subclass may narrow the referenced descriptor type
_DESC_co = TypeVar("_DESC_co", bound=DescriptorProto, covariant=True)


class RODescriptorReference(Generic[_DESC_co]):
    """a descriptor that refers to a descriptor.

    same as :class:`.DescriptorReference` but is read-only, so that subclasses
    can define a subtype as the generically contained element

    """

    # typing-only stubs; assignment and deletion are annotated as NoReturn
    # to mark them as disallowed for type checkers
    if TYPE_CHECKING:

        def __get__(self, instance: object, owner: Any) -> _DESC_co: ...

        def __set__(self, instance: Any, value: Any) -> NoReturn: ...

        def __delete__(self, instance: Any) -> NoReturn: ...
686
687
# the (optional) callable type held by a CallableReference
_FN = TypeVar("_FN", bound=Optional[Callable[..., Any]])


class CallableReference(Generic[_FN]):
    """a descriptor that refers to a callable.

    works around mypy's limitation of not allowing callables assigned
    as instance variables

    """

    # typing-only stubs; at runtime the class defines no descriptor methods
    if TYPE_CHECKING:

        def __get__(self, instance: object, owner: Any) -> _FN: ...

        def __set__(self, instance: Any, value: _FN) -> None: ...

        def __delete__(self, instance: Any) -> None: ...
707
708
class _TypingInstances:
    """Lazily build tuples of the same-named objects from ``typing`` and
    ``typing_extensions``.

    ``instance.Literal`` returns a tuple holding ``typing.Literal`` and
    ``typing_extensions.Literal`` (a single entry when both modules expose
    the very same object), suitable for ``in`` and ``isinstance()`` checks.
    The tuple is cached in the instance ``__dict__``, so ``__getattr__``
    runs at most once per name.  Unknown names raise ``AttributeError``.
    """

    def __getattr__(self, key: str) -> tuple[type, ...]:
        # dict.fromkeys() drops duplicates while keeping a stable order
        # (typing first, then typing_extensions); a set gave an arbitrary,
        # run-dependent order
        types = tuple(
            dict.fromkeys(
                t
                for t in (
                    getattr(typing, key, None),
                    getattr(typing_extensions, key, None),
                )
                if t is not None
            )
        )
        if not types:
            raise AttributeError(key)
        self.__dict__[key] = types
        return types
725
726
# runtime lookup of same-named objects in typing / typing_extensions,
# e.g. ``_type_tuples.Literal`` is a tuple usable with ``in``
_type_tuples = _TypingInstances()
if TYPE_CHECKING:
    # for type checkers, expose the real typing_extensions objects;
    # presumably so ``isinstance(x, _type_instances.X)`` narrows to the real
    # class -- the runtime value below is the tuple-based lookup instead
    _type_instances = typing_extensions
else:
    _type_instances = _type_tuples

# tuple of the ``Literal`` special form(s) from typing / typing_extensions
LITERAL_TYPES = _type_tuples.Literal