1# util/typing.py
2# Copyright (C) 2022-2025 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: https://www.opensource.org/licenses/mit-license.php
7# mypy: allow-untyped-defs, allow-untyped-calls
8
9from __future__ import annotations
10
11import builtins
12from collections import deque
13import collections.abc as collections_abc
14import re
15import sys
16import typing
17from typing import Any
18from typing import Callable
19from typing import Dict
20from typing import ForwardRef
21from typing import Generic
22from typing import Iterable
23from typing import Mapping
24from typing import NewType
25from typing import NoReturn
26from typing import Optional
27from typing import overload
28from typing import Protocol
29from typing import Set
30from typing import Tuple
31from typing import Type
32from typing import TYPE_CHECKING
33from typing import TypeVar
34from typing import Union
35
36import typing_extensions
37
38from . import compat
39
40if True: # zimports removes the tailing comments
41 from typing_extensions import Annotated as Annotated # 3.9
42 from typing_extensions import Concatenate as Concatenate # 3.10
43 from typing_extensions import (
44 dataclass_transform as dataclass_transform, # 3.11,
45 )
46 from typing_extensions import get_args as get_args # 3.10
47 from typing_extensions import get_origin as get_origin # 3.10
48 from typing_extensions import (
49 Literal as Literal,
50 ) # 3.8 but has bugs before 3.10
51 from typing_extensions import NotRequired as NotRequired # 3.11
52 from typing_extensions import ParamSpec as ParamSpec # 3.10
53 from typing_extensions import TypeAlias as TypeAlias # 3.10
54 from typing_extensions import TypeGuard as TypeGuard # 3.10
55 from typing_extensions import TypeVarTuple as TypeVarTuple # 3.11
56 from typing_extensions import Self as Self # 3.11
57 from typing_extensions import TypeAliasType as TypeAliasType # 3.12
58 from typing_extensions import Unpack as Unpack # 3.11
59 from typing_extensions import Never as Never # 3.11
60 from typing_extensions import LiteralString as LiteralString # 3.11
61
62
# general purpose type variables; the ``_co`` / ``_contra`` suffixes mark the
# covariant and contravariant variants
_T = TypeVar("_T", bound=Any)
_KT = TypeVar("_KT")
_KT_co = TypeVar("_KT_co", covariant=True)
_KT_contra = TypeVar("_KT_contra", contravariant=True)
_VT = TypeVar("_VT")
_VT_co = TypeVar("_VT_co", covariant=True)

# a tuple of any length holding elements of any type
TupleAny = Tuple[Any, ...]
71
72
# make ``NoneType`` available under one name: imported from ``types`` when
# ``compat.py310`` is set, otherwise derived from ``type(None)``
if compat.py310:
    # why they took until py310 to put this in stdlib is beyond me,
    # I've been wanting it since py27
    from types import NoneType as NoneType
else:
    NoneType = type(None)  # type: ignore
79
80
def is_fwd_none(typ: Any) -> bool:
    """Return True if ``typ`` is a ``ForwardRef`` wrapping ``"None"``."""
    if not isinstance(typ, ForwardRef):
        return False
    return typ.__forward_arg__ == "None"
83
84
# the kinds of objects accepted by the annotation-scanning helpers in this
# module.  ``GenericProtocol`` is spelled as a string because the class is
# declared further down in the file.
_AnnotationScanType = Union[
    Type[Any], str, ForwardRef, NewType, TypeAliasType, "GenericProtocol[Any]"
]
88
89
class ArgsTypeProtocol(Protocol):
    """Protocol for types that have ``__args__``.

    There's no public interface for this AFAIK.

    """

    # the type arguments carried by the object
    __args__: Tuple[_AnnotationScanType, ...]
98
99
class GenericProtocol(Protocol[_T]):
    """Protocol for generic types.

    This exists because Python's ``typing._GenericAlias`` is private.

    """

    # the type arguments of the generic, e.g. ``(int,)`` for ``List[int]``
    __args__: Tuple[_AnnotationScanType, ...]
    # the unsubscripted origin of the generic
    __origin__: Type[_T]

    # Python's builtin _GenericAlias has this method, however builtins like
    # list, dict, etc. do not, even though they have ``__origin__`` and
    # ``__args__``
    #
    # def copy_with(self, params: Tuple[_AnnotationScanType, ...]) -> Type[_T]:
    #     ...
116
117
# copied from TypeShed, required in order to implement
# MutableMapping.update()
class SupportsKeysAndGetItem(Protocol[_KT, _VT_co]):
    """Protocol for objects offering ``keys()`` and ``__getitem__()``."""

    def keys(self) -> Iterable[_KT]: ...

    def __getitem__(self, __k: _KT) -> _VT_co: ...
124
125
# work around https://github.com/microsoft/pyright/issues/3025
# (named alias for the literal type of the string "*")
_LiteralStar = Literal["*"]
128
129
def de_stringify_annotation(
    cls: Type[Any],
    annotation: _AnnotationScanType,
    originating_module: str,
    locals_: Mapping[str, Any],
    *,
    str_cleanup_fn: Optional[Callable[[str, str], str]] = None,
    include_generic: bool = False,
    _already_seen: Optional[Set[Any]] = None,
) -> Type[Any]:
    """Resolve annotations that may be string based into real objects.

    This matters most when a module uses "from __future__ import
    annotations", since everything inside of __annotations__ is then a
    string.  We want to at least have generic containers like ``Mapped``,
    ``Union``, ``List``, etc.

    :param cls: class whose namespace takes part in evaluating the string.
    :param annotation: the annotation; a ``ForwardRef`` is unwrapped to its
     string, and a string is evaluated via ``eval_expression()``.
    :param originating_module: name of the module whose globals are used
     for evaluation.
    :param locals_: additional names made available to the evaluation.
    :param str_cleanup_fn: optional callable invoked with the annotation
     string and ``originating_module`` before the string is evaluated.
    :param include_generic: when True, the ``__args__`` of a non-Literal
     generic are resolved recursively as well.
    :param _already_seen: internal; generics already visited by the
     recursion.
    :return: the resolved annotation.

    """
    # looked at typing.get_type_hints(), looked at pydantic.  We need much
    # less here, and we try not to use any private typing internals or to
    # construct ForwardRef objects, which is documented as something that
    # should be avoided.

    given_annotation = annotation

    if is_fwd_ref(annotation):
        annotation = annotation.__forward_arg__

    if isinstance(annotation, str):
        if str_cleanup_fn:
            annotation = str_cleanup_fn(annotation, originating_module)

        annotation = eval_expression(
            annotation, originating_module, locals_=locals_, in_class=cls
        )

    if not (
        include_generic
        and is_generic(annotation)
        and not is_literal(annotation)
    ):
        return annotation  # type: ignore

    if _already_seen is None:
        _already_seen = set()

    if annotation in _already_seen:
        # only occurs recursively.  the outermost return type will always
        # be Type; the element here will be either ForwardRef or
        # Optional[ForwardRef]
        return given_annotation  # type: ignore

    _already_seen.add(annotation)

    resolved_args = []
    for elem in annotation.__args__:
        resolved_args.append(
            de_stringify_annotation(
                cls,
                elem,
                originating_module,
                locals_,
                str_cleanup_fn=str_cleanup_fn,
                include_generic=include_generic,
                _already_seen=_already_seen,
            )
        )

    return _copy_generic_annotation_with(annotation, tuple(resolved_args))
199
200
def fixup_container_fwd_refs(
    type_: _AnnotationScanType,
) -> _AnnotationScanType:
    """Correct dict['x', 'y'] into dict[ForwardRef('x'), ForwardRef('y')]
    and similar for list, set

    Anything that is not a generic over one of the handled container
    origins, or whose repr identifies it as a ``typing.*`` alias, is
    returned unchanged.

    """

    if not is_generic(type_):
        return type_

    origin = get_origin(type_)
    if origin not in (
        dict,
        set,
        list,
        collections_abc.MutableSet,
        collections_abc.MutableMapping,
        collections_abc.MutableSequence,
        collections_abc.Mapping,
        collections_abc.Sequence,
    ):
        return type_

    # fight, kick and scream to struggle to tell the difference between
    # dict[] and typing.Dict[] which DO NOT compare the same and DO NOT
    # behave the same yet there is NO WAY to distinguish between which type
    # it is using public attributes
    if re.match(
        "typing.(?:Dict|List|Set|.*Mapping|.*Sequence|.*Set)", repr(type_)
    ):
        return type_

    fixed_args = tuple(
        ForwardRef(elem) if isinstance(elem, str) else elem
        for elem in get_args(type_)
    )

    # compat with py3.10 and earlier
    return origin.__class_getitem__(fixed_args)  # type: ignore
240
241
242def _copy_generic_annotation_with(
243 annotation: GenericProtocol[_T], elements: Tuple[_AnnotationScanType, ...]
244) -> Type[_T]:
245 if hasattr(annotation, "copy_with"):
246 # List, Dict, etc. real generics
247 return annotation.copy_with(elements) # type: ignore
248 else:
249 # Python builtins list, dict, etc.
250 return annotation.__origin__[elements] # type: ignore
251
252
def eval_expression(
    expression: str,
    module_name: str,
    *,
    locals_: Optional[Mapping[str, Any]] = None,
    in_class: Optional[Type[Any]] = None,
) -> Any:
    """Evaluate ``expression`` against the globals of a loaded module.

    :param expression: Python expression to evaluate.
    :param module_name: key into ``sys.modules`` of the module whose
     ``__dict__`` supplies the globals.
    :param locals_: optional mapping passed to ``eval()`` as locals.
    :param in_class: optional class; a copy of its namespace, plus the class
     under its own name, is layered beneath the module globals.
    :return: the result of the evaluation.
    :raises NameError: if the module is not present in ``sys.modules`` or if
     evaluation fails for any reason.

    """
    try:
        module = sys.modules[module_name]
    except KeyError as ke:
        raise NameError(
            f"Module {module_name} isn't present in sys.modules; can't "
            f"evaluate expression {expression}"
        ) from ke

    base_globals: Dict[str, Any] = module.__dict__

    try:
        if in_class is None:
            eval_globals = base_globals
        else:
            eval_globals = dict(in_class.__dict__)
            eval_globals.setdefault(in_class.__name__, in_class)

            # see #10899.  We want the locals/globals to take precedence
            # over the class namespace in this context, even though this
            # is not the usual way variables would resolve.
            eval_globals.update(base_globals)

        return eval(expression, eval_globals, locals_)
    except Exception as err:
        raise NameError(
            f"Could not de-stringify annotation {expression!r}"
        ) from err
287
288
def eval_name_only(
    name: str,
    module_name: str,
    *,
    locals_: Optional[Mapping[str, Any]] = None,
) -> Any:
    """Resolve ``name`` against a module's globals, then builtins.

    A dotted name is handed to ``eval_expression()``.  A plain name is
    looked up directly in the module's ``__dict__`` and, failing that, in
    ``builtins``.

    :raises NameError: if the module is not present in ``sys.modules`` or
     the name cannot be located.

    """
    if "." in name:
        return eval_expression(name, module_name, locals_=locals_)

    try:
        module_globals: Dict[str, Any] = sys.modules[module_name].__dict__
    except KeyError as ke:
        raise NameError(
            f"Module {module_name} isn't present in sys.modules; can't "
            f"resolve name {name}"
        ) from ke

    # name only, just look in globals.  eval() works perfectly fine here,
    # however we are seeking to have this be faster, as this occurs for
    # every Mapper[] keyword, etc. depending on configuration
    try:
        return module_globals[name]
    except KeyError as ke:
        # check in builtins as well to handle `list`, `set` or `dict`, etc.
        builtin_names = vars(builtins)
        if name in builtin_names:
            return builtin_names[name]

        raise NameError(
            f"Could not locate name {name} in module {module_name}"
        ) from ke
321
322
def resolve_name_to_real_class_name(name: str, module_name: str) -> str:
    """Return the ``__name__`` of the object that ``name`` resolves to.

    ``name`` itself is returned when it can't be resolved in the module, or
    when the resolved object has no ``__name__``.

    """
    try:
        resolved = eval_name_only(name, module_name)
    except NameError:
        return name

    return getattr(resolved, "__name__", name)
330
331
def is_pep593(type_: Optional[Any]) -> bool:
    """Return True if ``type_`` is an ``Annotated`` construct."""
    if type_ is None:
        return False
    return get_origin(type_) in _type_tuples.Annotated
334
335
def is_non_string_iterable(obj: Any) -> TypeGuard[Iterable[Any]]:
    """Return True if ``obj`` is iterable but is not ``str`` or ``bytes``."""
    if isinstance(obj, (str, bytes)):
        return False
    return isinstance(obj, collections_abc.Iterable)
340
341
def is_literal(type_: Any) -> bool:
    """Return True if the origin of ``type_`` is a ``Literal`` construct."""
    origin = get_origin(type_)
    return origin in _type_tuples.Literal
344
345
def is_newtype(type_: Optional[_AnnotationScanType]) -> TypeGuard[NewType]:
    """Return True if ``type_`` carries a ``__supertype__`` attribute."""
    # an isinstance() check against NewType doesn't work in 3.9, 3.8, 3.7
    # as NewType() passes back a closure there, not an object instance;
    # so detect the attribute instead
    supertype_attr = "__supertype__"
    return hasattr(type_, supertype_attr)
351
352
def is_generic(type_: _AnnotationScanType) -> TypeGuard[GenericProtocol[Any]]:
    """Return True if ``type_`` has both ``__args__`` and ``__origin__``."""
    required_attrs = ("__args__", "__origin__")
    return all(hasattr(type_, attr) for attr in required_attrs)
355
356
def is_pep695(type_: _AnnotationScanType) -> TypeGuard[TypeAliasType]:
    """Return True if ``type_`` is a ``TypeAliasType``, or a generic whose
    origin ultimately is one.

    """
    # NOTE: a generic TAT does not instance check as TypeAliasType outside of
    # python 3.10.  For sqlalchemy use cases it's fine to consider it a TAT
    # though.
    # NOTE: things seems to work also without this additional check
    while is_generic(type_):
        type_ = type_.__origin__
    return isinstance(type_, _type_instances.TypeAliasType)
365
366
def pep695_values(type_: _AnnotationScanType) -> Set[Any]:
    """Extracts the value from a TypeAliasType, recursively exploring unions
    and inner TypeAliasType to flatten them into a single set.

    Forward references are not evaluated, so no recursive exploration happens
    into them.  ``NoneType`` and ``ForwardRef("None")`` members of a union
    are represented as ``None`` in the returned set.
    """
    visited = set()

    def expand(candidate):
        if candidate in visited:
            # recursion are not supported (at least it's flagged as
            # an error by pyright).  Just avoid infinite loop
            return candidate
        visited.add(candidate)
        if not is_pep695(candidate):
            return candidate
        aliased = candidate.__value__
        if is_union(aliased):
            return [expand(member) for member in aliased.__args__]
        return aliased

    expanded = expand(type_)
    if not isinstance(expanded, list):
        return {expanded}

    # flatten the (possibly nested) lists produced for unions
    flattened = set()
    pending = deque(expanded)
    while pending:
        item = pending.popleft()
        if isinstance(item, list):
            pending.extend(item)
            continue
        if item is NoneType or is_fwd_none(item):
            flattened.add(None)
        else:
            flattened.add(item)
    return flattened
402
403
def is_fwd_ref(
    type_: _AnnotationScanType,
    check_generic: bool = False,
    check_for_plain_string: bool = False,
) -> TypeGuard[ForwardRef]:
    """Return True if ``type_`` is a forward reference.

    :param type_: the object to test.
    :param check_generic: when True, a generic also qualifies if any of its
     ``__args__`` qualifies, checked recursively.
    :param check_for_plain_string: when True, a plain ``str`` is treated as
     a forward reference as well.

    """
    if check_for_plain_string and isinstance(type_, str):
        return True

    if isinstance(type_, _type_instances.ForwardRef):
        return True

    if not (check_generic and is_generic(type_)):
        return False

    for arg in type_.__args__:
        if is_fwd_ref(
            arg, True, check_for_plain_string=check_for_plain_string
        ):
            return True
    return False
422
423
@overload
def de_optionalize_union_types(type_: str) -> str: ...


@overload
def de_optionalize_union_types(type_: Type[Any]) -> Type[Any]: ...


@overload
def de_optionalize_union_types(
    type_: _AnnotationScanType,
) -> _AnnotationScanType: ...


def de_optionalize_union_types(
    type_: _AnnotationScanType,
) -> _AnnotationScanType:
    """Given a type, filter out ``Union`` types that include ``NoneType``
    to not include the ``NoneType``.

    :param type_: the annotation to process.  A ``ForwardRef`` is handled
     textually by ``_de_optionalize_fwd_ref_union_types()``; a union that
     includes None is rebuilt without its None members; anything else is
     returned unchanged.
    :return: the de-optionalized type.  The remaining union members keep
     the order they had in ``type_``.  ``Never`` is returned when no member
     remains, matching the forward-reference code path.

    """

    if is_fwd_ref(type_):
        return _de_optionalize_fwd_ref_union_types(type_, False)

    if is_union(type_) and includes_none(type_):
        # use an ordered tuple rather than a set so that the member order of
        # the resulting Union doesn't depend on hash ordering; Union itself
        # discards duplicate members
        remaining = tuple(
            t
            for t in type_.__args__
            if t is not NoneType and not is_fwd_none(t)
        )
        if not remaining:
            # every member was None-like; a Union of no types can't be made
            return Never  # type: ignore[return-value]

        return make_union_type(*remaining)

    return type_
460
461
462@overload
463def _de_optionalize_fwd_ref_union_types(
464 type_: ForwardRef, return_has_none: Literal[True]
465) -> bool: ...
466
467
468@overload
469def _de_optionalize_fwd_ref_union_types(
470 type_: ForwardRef, return_has_none: Literal[False]
471) -> _AnnotationScanType: ...
472
473
474def _de_optionalize_fwd_ref_union_types(
475 type_: ForwardRef, return_has_none: bool
476) -> Union[_AnnotationScanType, bool]:
477 """return the non-optional type for Optional[], Union[None, ...], x|None,
478 etc. without de-stringifying forward refs.
479
480 unfortunately this seems to require lots of hardcoded heuristics
481
482 """
483
484 annotation = type_.__forward_arg__
485
486 mm = re.match(r"^(.+?)\[(.+)\]$", annotation)
487 if mm:
488 g1 = mm.group(1).split(".")[-1]
489 if g1 == "Optional":
490 return True if return_has_none else ForwardRef(mm.group(2))
491 elif g1 == "Union":
492 if "[" in mm.group(2):
493 # cases like "Union[Dict[str, int], int, None]"
494 elements: list[str] = []
495 current: list[str] = []
496 ignore_comma = 0
497 for char in mm.group(2):
498 if char == "[":
499 ignore_comma += 1
500 elif char == "]":
501 ignore_comma -= 1
502 elif ignore_comma == 0 and char == ",":
503 elements.append("".join(current).strip())
504 current.clear()
505 continue
506 current.append(char)
507 else:
508 elements = re.split(r",\s*", mm.group(2))
509 parts = [ForwardRef(elem) for elem in elements if elem != "None"]
510 if return_has_none:
511 return len(elements) != len(parts)
512 else:
513 return make_union_type(*parts) if parts else Never # type: ignore[return-value] # noqa: E501
514 else:
515 return False if return_has_none else type_
516
517 pipe_tokens = re.split(r"\s*\|\s*", annotation)
518 has_none = "None" in pipe_tokens
519 if return_has_none:
520 return has_none
521 if has_none:
522 anno_str = "|".join(p for p in pipe_tokens if p != "None")
523 return ForwardRef(anno_str) if anno_str else Never # type: ignore[return-value] # noqa: E501
524
525 return type_
526
527
def make_union_type(*types: _AnnotationScanType) -> Type[Any]:
    """Build a ``Union`` whose members are the given types."""

    members = types
    return Union[members]  # type: ignore
532
533
def includes_none(type_: Any) -> bool:
    """Returns if the type annotation ``type_`` allows ``None``.

    This function supports:
    * forward refs
    * unions
    * pep593 - Annotated
    * pep695 - TypeAliasType (does not support looking into
      fw reference of other pep695)
    * NewType
    * plain types like ``int``, ``None``, etc
    """
    if is_fwd_ref(type_):
        return _de_optionalize_fwd_ref_union_types(type_, True)

    if is_union(type_):
        return any(map(includes_none, get_args(type_)))

    if is_pep593(type_):
        # only the first argument of Annotated is the annotated type
        annotated_type = get_args(type_)[0]
        return includes_none(annotated_type)

    if is_pep695(type_):
        return any(map(includes_none, pep695_values(type_)))

    if is_newtype(type_):
        return includes_none(type_.__supertype__)

    try:
        if type_ in (NoneType, None):
            return True
        return is_fwd_none(type_)
    except TypeError:
        # if type_ is Column, mapped_column(), etc. the use of "in"
        # resolves to ``__eq__()`` which then gives us an expression object
        # that can't resolve to boolean.  just catch it all via exception
        return False
563
564
def is_a_type(type_: Any) -> bool:
    """Return True if ``type_`` looks like a type or a typing construct.

    An object qualifies when it is a class, when it has an ``__origin__``
    attribute, or when it or its class is defined in ``typing`` /
    ``typing_extensions``.

    Objects that have no ``__module__`` attribute, such as ``5``, ``"x"`` or
    ``None``, yield False rather than raising ``AttributeError``.

    """
    typing_modules = ("typing", "typing_extensions")
    return (
        isinstance(type_, type)
        or hasattr(type_, "__origin__")
        # plain instances of builtin types have no ``__module__``
        or getattr(type_, "__module__", None) in typing_modules
        or type(type_).__mro__[0].__module__ in typing_modules
    )
572
573
def is_union(type_: Any) -> TypeGuard[ArgsTypeProtocol]:
    """Return True if the origin of ``type_`` is named ``Union`` or
    ``UnionType``.

    """
    union_origin_names = ("Union", "UnionType")
    return is_origin_of(type_, *union_origin_names)
576
577
def is_origin_of_cls(
    type_: Any, class_obj: Union[Tuple[Type[Any], ...], Type[Any]]
) -> bool:
    """return True if the given type has an __origin__ that is a class and
    is a subclass of the given class, or of one of the given classes"""

    origin = get_origin(type_)
    if origin is None or not isinstance(origin, type):
        return False

    return issubclass(origin, class_obj)
589
590
def is_origin_of(
    type_: Any, *names: str, module: Optional[str] = None
) -> bool:
    """return True if the given type has an __origin__ whose name is one of
    the given names, and whose module starts with ``module`` when given."""

    origin = get_origin(type_)
    if origin is None:
        return False

    if _get_type_name(origin) not in names:
        return False

    return module is None or origin.__module__.startswith(module)
604
605
def _get_type_name(type_: Type[Any]) -> str:
    """Return the name of ``type_``.

    When ``compat.py310`` is set, ``__name__`` is read directly.  Otherwise
    ``__name__`` is tried first, then the private ``_name`` attribute.

    """
    if compat.py310:
        return type_.__name__

    for attr in ("__name__", "_name"):
        found = getattr(type_, attr, None)
        if found is not None:
            return found  # type: ignore

    return None  # type: ignore
615
616
class DescriptorProto(Protocol):
    """Protocol for objects implementing ``__get__``, ``__set__`` and
    ``__delete__``.

    """

    def __get__(self, instance: object, owner: Any) -> Any: ...

    def __set__(self, instance: Any, value: Any) -> None: ...

    def __delete__(self, instance: Any) -> None: ...
623
624
# type variable for the descriptor type held by DescriptorReference
_DESC = TypeVar("_DESC", bound=DescriptorProto)


class DescriptorReference(Generic[_DESC]):
    """a descriptor that refers to a descriptor.

    used for cases where we need to have an instance variable referring to an
    object that is itself a descriptor, which typically confuses typing tools
    as they don't know when they should use ``__get__`` or not when referring
    to the descriptor assignment as an instance variable. See
    sqlalchemy.orm.interfaces.PropComparator.prop

    """

    # the descriptor methods are declared for type checkers only; the class
    # body defines nothing when TYPE_CHECKING is false
    if TYPE_CHECKING:

        def __get__(self, instance: object, owner: Any) -> _DESC: ...

        def __set__(self, instance: Any, value: _DESC) -> None: ...

        def __delete__(self, instance: Any) -> None: ...
646
647
# covariant type variable for the descriptor type held by
# RODescriptorReference
_DESC_co = TypeVar("_DESC_co", bound=DescriptorProto, covariant=True)


class RODescriptorReference(Generic[_DESC_co]):
    """a descriptor that refers to a descriptor.

    same as :class:`.DescriptorReference` but is read-only, so that subclasses
    can define a subtype as the generically contained element

    """

    # the descriptor methods are declared for type checkers only; ``__set__``
    # and ``__delete__`` are typed as ``NoReturn``
    if TYPE_CHECKING:

        def __get__(self, instance: object, owner: Any) -> _DESC_co: ...

        def __set__(self, instance: Any, value: Any) -> NoReturn: ...

        def __delete__(self, instance: Any) -> NoReturn: ...
666
667
# type variable for the (optional) callable held by CallableReference
_FN = TypeVar("_FN", bound=Optional[Callable[..., Any]])


class CallableReference(Generic[_FN]):
    """a descriptor that refers to a callable.

    works around mypy's limitation of not allowing callables assigned
    as instance variables


    """

    # the descriptor methods are declared for type checkers only; the class
    # body defines nothing when TYPE_CHECKING is false
    if TYPE_CHECKING:

        def __get__(self, instance: object, owner: Any) -> _FN: ...

        def __set__(self, instance: Any, value: _FN) -> None: ...

        def __delete__(self, instance: Any) -> None: ...
687
688
class _TypingInstances:
    """Attribute-style lookup of typing constructs by name.

    Accessing an attribute gathers the object of that name from both
    ``typing`` and ``typing_extensions`` into a tuple of distinct objects.
    The tuple is stored in the instance ``__dict__``, so ``__getattr__`` is
    not invoked again for the same name.

    """

    def __getattr__(self, key: str) -> tuple[type, ...]:
        found = set()
        for module in (typing, typing_extensions):
            candidate = getattr(module, key, None)
            if candidate is not None:
                found.add(candidate)

        if not found:
            raise AttributeError(key)

        types = tuple(found)
        self.__dict__[key] = types
        return types
705
706
# lookup object giving, per name, a tuple of the matching objects found in
# ``typing`` and ``typing_extensions``
_type_tuples = _TypingInstances()

# for type checkers ``_type_instances`` reads as the ``typing_extensions``
# module; at runtime it is the same object as ``_type_tuples``
if TYPE_CHECKING:
    _type_instances = typing_extensions
else:
    _type_instances = _type_tuples

# the ``Literal`` objects found in ``typing`` / ``typing_extensions``
LITERAL_TYPES = _type_tuples.Literal