1# util/typing.py
2# Copyright (C) 2022-2025 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: https://www.opensource.org/licenses/mit-license.php
7# mypy: allow-untyped-defs, allow-untyped-calls
8
9from __future__ import annotations
10
11import builtins
12from collections import deque
13import collections.abc as collections_abc
14import re
15import sys
16import typing
17from typing import Any
18from typing import Callable
19from typing import cast
20from typing import Dict
21from typing import ForwardRef
22from typing import Generic
23from typing import Iterable
24from typing import Mapping
25from typing import NewType
26from typing import NoReturn
27from typing import Optional
28from typing import overload
29from typing import Set
30from typing import Tuple
31from typing import Type
32from typing import TYPE_CHECKING
33from typing import TypeVar
34from typing import Union
35
36from . import compat
37
38if True: # zimports removes the tailing comments
39 from typing_extensions import Annotated as Annotated # 3.8
40 from typing_extensions import Concatenate as Concatenate # 3.10
41 from typing_extensions import (
42 dataclass_transform as dataclass_transform, # 3.11,
43 )
44 from typing_extensions import Final as Final # 3.8
45 from typing_extensions import final as final # 3.8
46 from typing_extensions import get_args as get_args # 3.10
47 from typing_extensions import get_origin as get_origin # 3.10
48 from typing_extensions import Literal as Literal # 3.8
49 from typing_extensions import NotRequired as NotRequired # 3.11
50 from typing_extensions import ParamSpec as ParamSpec # 3.10
51 from typing_extensions import Protocol as Protocol # 3.8
52 from typing_extensions import SupportsIndex as SupportsIndex # 3.8
53 from typing_extensions import TypeAlias as TypeAlias # 3.10
54 from typing_extensions import TypedDict as TypedDict # 3.8
55 from typing_extensions import TypeGuard as TypeGuard # 3.10
56 from typing_extensions import Self as Self # 3.11
57 from typing_extensions import TypeAliasType as TypeAliasType # 3.12
58 from typing_extensions import Never as Never # 3.11
59
# Type variables for the generic protocols and helpers in this module;
# key (_KT*) and value (_VT*) variants are provided with each variance.
_T = TypeVar("_T", bound=Any)
_KT = TypeVar("_KT")
_KT_co = TypeVar("_KT_co", covariant=True)
_KT_contra = TypeVar("_KT_contra", contravariant=True)
_VT = TypeVar("_VT")
_VT_co = TypeVar("_VT_co", covariant=True)
66
# every object that can appear as the origin of a ``Literal[...]``; used by
# is_literal()
if compat.py38:
    # typing_extensions.Literal is different from typing.Literal until
    # Python 3.10.1, so both must be recognized
    LITERAL_TYPES = frozenset([typing.Literal, Literal])
else:
    LITERAL_TYPES = frozenset([Literal])


# types.NoneType only exists in the stdlib from Python 3.10 on; earlier
# versions derive it from the None singleton
if compat.py310:
    from types import NoneType as NoneType
else:
    NoneType = type(None)  # type: ignore

# a forward reference to the name ``None``, compared against union members
# and alias values alongside NoneType when looking for "None" in
# not-yet-evaluated annotations
NoneFwd = ForwardRef("None")
83
84
# everything the annotation-scanning helpers in this module accept as an
# annotation: real classes, plain strings, ForwardRef, NewType, PEP 695
# TypeAliasType and parametrized generics (see GenericProtocol)
_AnnotationScanType = Union[
    Type[Any], str, ForwardRef, NewType, TypeAliasType, "GenericProtocol[Any]"
]
88
89
class ArgsTypeProtocol(Protocol):
    """protocol for types that have ``__args__``

    there's no public interface for this AFAIK; this is the type that
    :func:`.is_union` narrows its argument to

    """

    # the type arguments, e.g. the members of a union
    __args__: Tuple[_AnnotationScanType, ...]
98
99
class GenericProtocol(Protocol[_T]):
    """protocol for parametrized generic types.

    this is needed because Python's ``typing._GenericAlias`` is private, so
    there is no public type to annotate generic aliases with

    """

    # the type arguments, e.g. ``(int,)`` for ``List[int]``
    __args__: Tuple[_AnnotationScanType, ...]
    # the unparametrized generic, e.g. ``list`` for ``List[int]``
    __origin__: Type[_T]

    # Python's builtin _GenericAlias has this method, however builtins like
    # list, dict, etc. do not, even though they have ``__origin__`` and
    # ``__args__``.  It is therefore left out of the protocol;
    # _copy_generic_annotation_with() checks for it with hasattr().
    #
    # def copy_with(self, params: Tuple[_AnnotationScanType, ...]) -> Type[_T]:
    #     ...
116
117
# copied from typeshed, required in order to implement
# MutableMapping.update()
class SupportsKeysAndGetItem(Protocol[_KT, _VT_co]):
    """structural type for objects providing ``keys()`` and
    ``__getitem__``"""

    def keys(self) -> Iterable[_KT]: ...

    def __getitem__(self, __k: _KT) -> _VT_co: ...
124
125
# work around https://github.com/microsoft/pyright/issues/3025 by giving
# ``Literal["*"]`` a module-level alias name
_LiteralStar = Literal["*"]
128
129
def de_stringify_annotation(
    cls: Type[Any],
    annotation: _AnnotationScanType,
    originating_module: str,
    locals_: Mapping[str, Any],
    *,
    str_cleanup_fn: Optional[Callable[[str, str], str]] = None,
    include_generic: bool = False,
    _already_seen: Optional[Set[Any]] = None,
) -> Type[Any]:
    """Resolve annotations that may be string based into real objects.

    This is particularly important if a module defines "from __future__ import
    annotations", as everything inside of __annotations__ is a string. We want
    to at least have generic containers like ``Mapped``, ``Union``, ``List``,
    etc.

    :param cls: class whose namespace is also consulted (with lower
     precedence than the module globals) when evaluating a string.
    :param annotation: the annotation to resolve.  A ``ForwardRef`` is
     unwrapped to its string first; non-string values are not evaluated.
    :param originating_module: name of the module, as found in
     ``sys.modules``, whose globals are used for evaluation.
    :param locals_: extra names made available as locals during evaluation.
    :param str_cleanup_fn: optional hook, called as
     ``str_cleanup_fn(annotation_string, originating_module)``; its return
     value replaces the string before it is evaluated.
    :param include_generic: when True, the arguments of a generic
     (non-``Literal``) annotation are resolved recursively and the generic
     is rebuilt from the resolved arguments.
    :param _already_seen: internal recursion state shared by all the
     recursive calls of one resolution; callers should not pass it.
    :return: the resolved annotation.
    :raises NameError: from :func:`.eval_expression`, when the module is not
     in ``sys.modules`` or the string fails to evaluate.

    """
    # looked at typing.get_type_hints(), looked at pydantic. We need much
    # less here, and we here try to not use any private typing internals
    # or construct ForwardRef objects which is documented as something
    # that should be avoided.

    original_annotation = annotation

    if is_fwd_ref(annotation):
        annotation = annotation.__forward_arg__

    if isinstance(annotation, str):
        if str_cleanup_fn:
            annotation = str_cleanup_fn(annotation, originating_module)

        annotation = eval_expression(
            annotation, originating_module, locals_=locals_, in_class=cls
        )

    if (
        include_generic
        and is_generic(annotation)
        and not is_literal(annotation)
    ):
        if _already_seen is None:
            _already_seen = set()

        if annotation in _already_seen:
            # only occurs recursively. outermost return type
            # will always be Type.
            # the element here will be either ForwardRef or
            # Optional[ForwardRef]
            # NOTE(review): the set is shared by the whole recursion and
            # entries are never removed, so a generic that repeats as a
            # sibling argument (not only a self-reference) is also returned
            # here in its original, unresolved form.
            return original_annotation  # type: ignore
        else:
            _already_seen.add(annotation)

        # resolve each type argument with the same settings and shared
        # recursion state
        elements = tuple(
            de_stringify_annotation(
                cls,
                elem,
                originating_module,
                locals_,
                str_cleanup_fn=str_cleanup_fn,
                include_generic=include_generic,
                _already_seen=_already_seen,
            )
            for elem in annotation.__args__
        )

        return _copy_generic_annotation_with(annotation, elements)

    return annotation  # type: ignore
199
200
def fixup_container_fwd_refs(
    type_: _AnnotationScanType,
) -> _AnnotationScanType:
    """Turn string arguments of builtin / ABC containers into ForwardRefs.

    For example ``dict['x', 'y']`` becomes
    ``dict[ForwardRef('x'), ForwardRef('y')]``; ``list``, ``set`` and the
    ``collections.abc`` mapping, sequence and mutable-set ABCs are treated
    the same way.  The ``typing.Dict[...]`` family is deliberately left
    untouched, as is anything that is not one of these containers.

    """
    if not is_generic(type_):
        return type_

    origin = get_origin(type_)
    if origin not in (
        dict,
        set,
        list,
        collections_abc.MutableSet,
        collections_abc.MutableMapping,
        collections_abc.MutableSequence,
        collections_abc.Mapping,
        collections_abc.Sequence,
    ):
        return type_

    # dict[] and typing.Dict[] DO NOT compare the same and DO NOT behave
    # the same, yet no public attribute tells them apart; the repr is the
    # only usable signal
    if re.match(
        "typing.(?:Dict|List|Set|.*Mapping|.*Sequence|.*Set)", repr(type_)
    ):
        return type_

    converted_args = tuple(
        ForwardRef(arg) if isinstance(arg, str) else arg
        for arg in get_args(type_)
    )
    # __class_getitem__ keeps this compatible with py3.10 and earlier
    return origin.__class_getitem__(converted_args)  # type: ignore
240
241
def _copy_generic_annotation_with(
    annotation: GenericProtocol[_T], elements: Tuple[_AnnotationScanType, ...]
) -> Type[_T]:
    """Rebuild the generic ``annotation`` with ``elements`` as its arguments.

    ``typing`` generics (``List[...]``, ``Dict[...]``) provide
    ``copy_with()``; builtin aliases such as ``list[...]`` do not, so for
    those the origin is subscripted again.
    """
    if not hasattr(annotation, "copy_with"):
        # Python builtins list, dict, etc.
        return annotation.__origin__[elements]  # type: ignore
    # List, Dict, etc. real generics
    return annotation.copy_with(elements)  # type: ignore
251
252
def eval_expression(
    expression: str,
    module_name: str,
    *,
    locals_: Optional[Mapping[str, Any]] = None,
    in_class: Optional[Type[Any]] = None,
) -> Any:
    """Evaluate ``expression`` in the namespace of module ``module_name``.

    :param expression: the source text to evaluate.  NOTE: this is passed to
     ``eval()``; callers are expected to supply annotation text from trusted
     source code, never external input.
    :param module_name: key into ``sys.modules`` whose globals are used.
    :param locals_: optional mapping used as the locals of the evaluation.
    :param in_class: optional class whose ``__dict__`` (plus its own name)
     is made available as well.  Module globals win over class attributes
     of the same name (see #10899).
    :raises NameError: if the module is not in ``sys.modules``, or if
     evaluation raises any exception (chained as the cause).
    """
    try:
        module_globals: Dict[str, Any] = sys.modules[module_name].__dict__
    except KeyError as missing_module:
        raise NameError(
            f"Module {module_name} isn't present in sys.modules; can't "
            f"evaluate expression {expression}"
        ) from missing_module

    try:
        if in_class is None:
            namespace = module_globals
        else:
            # start from the class body, make the class resolvable by its
            # own name, then let the module globals take precedence -- not
            # the usual resolution order, but the desired one here (#10899)
            namespace = dict(in_class.__dict__)
            namespace.setdefault(in_class.__name__, in_class)
            namespace.update(module_globals)

        result = eval(expression, namespace, locals_)
    except Exception as err:
        raise NameError(
            f"Could not de-stringify annotation {expression!r}"
        ) from err

    return result
287
288
def eval_name_only(
    name: str,
    module_name: str,
    *,
    locals_: Optional[Mapping[str, Any]] = None,
) -> Any:
    """Resolve ``name`` against the globals of ``module_name``, then builtins.

    Dotted names are delegated to :func:`.eval_expression` (the only path
    that uses ``locals_``).  A plain name is looked up directly, without
    ``eval()``, since this lookup is on a hot path (it runs for every
    ``Mapper[]`` keyword, etc., depending on configuration).

    :raises NameError: if the module is not in ``sys.modules``, or the name
     is found neither in the module globals nor in ``builtins``.
    """
    if "." in name:
        return eval_expression(name, module_name, locals_=locals_)

    try:
        module_globals: Dict[str, Any] = sys.modules[module_name].__dict__
    except KeyError as missing_module:
        raise NameError(
            f"Module {module_name} isn't present in sys.modules; can't "
            f"resolve name {name}"
        ) from missing_module

    try:
        return module_globals[name]
    except KeyError as missing_name:
        # builtins too, so that ``list``, ``set``, ``dict`` etc. resolve
        if name in builtins.__dict__:
            return builtins.__dict__[name]
        raise NameError(
            f"Could not locate name {name} in module {module_name}"
        ) from missing_name
321
322
def resolve_name_to_real_class_name(name: str, module_name: str) -> str:
    """Return the ``__name__`` of whatever ``name`` refers to in a module.

    ``name`` is resolved with :func:`.eval_name_only`.  If that raises
    ``NameError``, or the resolved object has no ``__name__``, ``name`` is
    returned unchanged.
    """
    try:
        target = eval_name_only(name, module_name)
    except NameError:
        return name
    return getattr(target, "__name__", name)
330
331
def is_pep593(type_: Optional[Any]) -> bool:
    """Return True if ``type_`` is a PEP 593 ``Annotated[...]`` form."""
    if type_ is None:
        return False
    return get_origin(type_) is Annotated
334
335
def is_non_string_iterable(obj: Any) -> TypeGuard[Iterable[Any]]:
    """Return True for iterables other than ``str`` and ``bytes``."""
    if isinstance(obj, (str, bytes)):
        return False
    return isinstance(obj, collections_abc.Iterable)
340
341
def is_literal(type_: Any) -> bool:
    """Return True if ``type_`` is a ``Literal[...]`` form.

    The origin is compared against ``LITERAL_TYPES``, which may contain both
    the ``typing`` and the ``typing_extensions`` flavor of ``Literal``.
    """
    return get_origin(type_) in LITERAL_TYPES
344
345
def is_newtype(type_: Optional[_AnnotationScanType]) -> TypeGuard[NewType]:
    """Return True if ``type_`` was created by ``NewType``.

    Detection is duck-typed on the ``__supertype__`` attribute.
    ``isinstance(type_, NewType)`` can't be used because on Python 3.8 / 3.7
    ``NewType`` produces a closure, not an object instance.
    """
    try:
        type_.__supertype__  # type: ignore[union-attr]  # noqa: B018
    except AttributeError:
        return False
    return True
352
353
def is_generic(type_: _AnnotationScanType) -> TypeGuard[GenericProtocol[Any]]:
    """Return True if ``type_`` looks like a parametrized generic.

    Duck-typed: anything with both ``__args__`` and ``__origin__`` qualifies,
    which includes ``typing`` generics as well as builtin ``list[...]``-style
    aliases.
    """
    return all(hasattr(type_, attr) for attr in ("__args__", "__origin__"))
356
357
def is_pep695(type_: _AnnotationScanType) -> TypeGuard[TypeAliasType]:
    """Return True if ``type_`` is a PEP 695 alias, i.e. an instance of
    ``TypeAliasType``."""
    return isinstance(type_, TypeAliasType)
360
361
def flatten_newtype(type_: NewType) -> Type[Any]:
    """Return the innermost supertype of a (possibly layered) ``NewType``.

    The ``__supertype__`` chain is followed until it reaches something that
    is not itself a ``NewType``.
    """
    supertype = type_.__supertype__
    if is_newtype(supertype):
        return flatten_newtype(supertype)
    return supertype  # type: ignore[return-value]
367
368
def pep695_values(type_: _AnnotationScanType) -> Set[Any]:
    """Extracts the value from a TypeAliasType, recursively exploring unions
    and inner TypeAliasType to flatten them into a single set.

    Forward references are not evaluated, so no recursive exploration happens
    into them.

    When the alias value is a union, its flattened members are returned, with
    ``NoneType`` and the ``None`` forward reference reported as ``None``.
    Otherwise the alias value (or ``type_`` itself, if it is not an alias)
    is returned unchanged as a one-element set.
    """
    seen: Set[Any] = set()

    def expand(candidate: Any) -> Any:
        if candidate in seen:
            # recursive aliases are not supported (pyright flags them as an
            # error); just avoid an infinite loop
            return candidate
        seen.add(candidate)
        if not is_pep695(candidate):
            return candidate
        aliased = candidate.__value__
        if is_union(aliased):
            # a (possibly nested) list marks a group of union members
            return [expand(member) for member in aliased.__args__]
        return aliased

    expanded = expand(type_)
    if not isinstance(expanded, list):
        return {expanded}

    flattened: Set[Any] = set()
    pending = deque(expanded)
    while pending:
        item = pending.popleft()
        if isinstance(item, list):
            pending.extend(item)
        elif item in {NoneType, NoneFwd}:
            flattened.add(None)
        else:
            flattened.add(item)
    return flattened
404
405
def is_fwd_ref(
    type_: _AnnotationScanType,
    check_generic: bool = False,
    check_for_plain_string: bool = False,
) -> TypeGuard[ForwardRef]:
    """Return True if ``type_`` is (or, optionally, contains) a forward ref.

    :param check_generic: also return True when any argument of a generic
     ``type_`` is a forward reference, searching nested generics too.
    :param check_for_plain_string: also treat a plain ``str`` as a forward
     reference, both at the top level and inside generics.
    """
    if isinstance(type_, ForwardRef):
        return True
    if check_for_plain_string and isinstance(type_, str):
        return True
    if not (check_generic and is_generic(type_)):
        return False
    return any(
        is_fwd_ref(
            member, True, check_for_plain_string=check_for_plain_string
        )
        for member in type_.__args__
    )
424
425
@overload
def de_optionalize_union_types(type_: str) -> str: ...


@overload
def de_optionalize_union_types(type_: Type[Any]) -> Type[Any]: ...


@overload
def de_optionalize_union_types(
    type_: _AnnotationScanType,
) -> _AnnotationScanType: ...


def de_optionalize_union_types(
    type_: _AnnotationScanType,
) -> _AnnotationScanType:
    """Given a type, filter out ``Union`` types that include ``NoneType``
    to not include the ``NoneType``.

    Contains extra logic to work on non-flattened unions, unions that contain
    ``None`` (seen in py38, 37)

    Forward references are handled at the string level, without being
    evaluated.  The members that remain keep the order in which they appear
    in the original union.  Anything that is not a union including ``None``
    is returned unchanged.

    """

    if is_fwd_ref(type_):
        return _de_optionalize_fwd_ref_union_types(type_, False)

    elif is_union(type_) and includes_none(type_):
        members: list[Any] = []
        if compat.py39:
            members.extend(type_.__args__)
        else:
            # py38, 37 - unions are not automatically flattened, can contain
            # None rather than NoneType
            stack_of_unions = deque([type_])
            while stack_of_unions:
                u_typ = stack_of_unions.popleft()
                for elem in u_typ.__args__:
                    if is_union(elem):
                        stack_of_unions.append(elem)
                    else:
                        members.append(elem)

        # dict.fromkeys() de-duplicates like a set but keeps first-seen
        # order; a set of classes iterates in hash (id) order, which differs
        # between processes and made the rebuilt Union's member order
        # unpredictable
        typ = dict.fromkeys(members)
        for none_like in (None, NoneType, NoneFwd):
            typ.pop(none_like, None)

        return make_union_type(*typ)

    else:
        return type_
479
480
def _split_union_args(args: str) -> list[str]:
    """Split the text inside ``Union[...]`` on its top-level commas.

    Commas nested inside square brackets (as in ``Dict[str, int]``) do not
    split.  Elements are whitespace-stripped, and empty elements (such as
    one produced by a trailing comma) are dropped.
    """
    elements: list[str] = []
    current: list[str] = []
    depth = 0
    for char in args:
        if char == "[":
            depth += 1
        elif char == "]":
            depth -= 1
        elif depth == 0 and char == ",":
            elements.append("".join(current).strip())
            current.clear()
            continue
        current.append(char)
    # the text after the last top-level comma is an element as well
    elements.append("".join(current).strip())
    return [elem for elem in elements if elem]


@overload
def _de_optionalize_fwd_ref_union_types(
    type_: ForwardRef, return_has_none: Literal[True]
) -> bool: ...


@overload
def _de_optionalize_fwd_ref_union_types(
    type_: ForwardRef, return_has_none: Literal[False]
) -> _AnnotationScanType: ...


def _de_optionalize_fwd_ref_union_types(
    type_: ForwardRef, return_has_none: bool
) -> Union[_AnnotationScanType, bool]:
    """return the non-optional type for Optional[], Union[None, ...], x|None,
    etc. without de-stringifying forward refs.

    unfortunately this seems to require lots of hardcoded heuristics

    :param type_: the forward reference whose string is inspected.
    :param return_has_none: when True, only report whether the string form
     includes ``None``.  When False, return the annotation with ``None``
     removed (``Never`` if nothing else remains); other forms that carry no
     recognizable ``None`` are returned as ``type_``.

    """

    annotation = type_.__forward_arg__

    mm = re.match(r"^(.+?)\[(.+)\]$", annotation)
    if mm:
        g1 = mm.group(1).split(".")[-1]
        if g1 == "Optional":
            return True if return_has_none else ForwardRef(mm.group(2))
        elif g1 == "Union":
            # bracket-aware split, so "Union[Dict[str, int], int, None]"
            # yields all three members, including the last one
            elements = _split_union_args(mm.group(2))
            if return_has_none:
                return "None" in elements
            parts = [ForwardRef(elem) for elem in elements if elem != "None"]
            return make_union_type(*parts) if parts else Never  # type: ignore[return-value] # noqa: E501
        else:
            return False if return_has_none else type_

    # PEP 604 style "X | None"
    pipe_tokens = re.split(r"\s*\|\s*", annotation)
    has_none = "None" in pipe_tokens
    if return_has_none:
        return has_none
    if has_none:
        anno_str = "|".join(p for p in pipe_tokens if p != "None")
        return ForwardRef(anno_str) if anno_str else Never  # type: ignore[return-value] # noqa: E501

    return type_
545
546
def make_union_type(*types: _AnnotationScanType) -> Type[Any]:
    """Make a Union type from the given member types.

    Plain subscription is used rather than calling ``Union.__getitem__``
    directly.  Subscription is the public spelling and does not depend on
    how ``Union`` is implemented: a ``_SpecialForm`` instance on older
    Pythons, and the ``types.UnionType`` class from Python 3.14 on, where an
    unbound ``__getitem__`` call is not the subscription hook.
    """
    return Union[types]  # type: ignore
550
551
def includes_none(type_: Any) -> bool:
    """Returns if the type annotation ``type_`` allows ``None``.

    Forms are checked in this order:

    * forward refs - inspected as strings only, never evaluated
    * unions - True if any member allows ``None``
    * pep593 - Annotated; only the annotated (first) type is inspected
    * pep695 - TypeAliasType (does not support looking into
      fw reference of other pep695)
    * NewType - the direct supertype is inspected
    * plain types like ``int``, ``None``, etc - compared against ``None``
      and ``NoneType``
    """
    if is_fwd_ref(type_):
        return _de_optionalize_fwd_ref_union_types(type_, True)

    candidates: Iterable[Any]
    if is_union(type_):
        candidates = get_args(type_)
    elif is_pep593(type_):
        candidates = (get_args(type_)[0],)
    elif is_pep695(type_):
        candidates = pep695_values(type_)
    elif is_newtype(type_):
        candidates = (type_.__supertype__,)
    else:
        return type_ in (NoneFwd, NoneType, None)

    return any(includes_none(candidate) for candidate in candidates)
575
576
def is_union(type_: Any) -> TypeGuard[ArgsTypeProtocol]:
    """Return True if ``type_`` is a union, i.e. its origin is named
    ``Union`` (``typing.Union[...]``) or ``UnionType`` (``X | Y``)."""
    return is_origin_of(type_, "Union", "UnionType")
579
580
def is_origin_of_cls(
    type_: Any, class_obj: Union[Tuple[Type[Any], ...], Type[Any]]
) -> bool:
    """return True if the given type has an __origin__ that shares a base
    with the given class

    ``class_obj`` may be a class or a tuple of classes, as accepted by
    ``issubclass()``.  Origins that are not classes (e.g. ``Union``) never
    match.
    """
    origin = get_origin(type_)
    if origin is None or not isinstance(origin, type):
        return False
    return issubclass(origin, class_obj)
592
593
def is_origin_of(
    type_: Any, *names: str, module: Optional[str] = None
) -> bool:
    """return True if the given type has an __origin__ with the given name
    and optional module.

    :param names: accepted origin names; any one of them matches.
    :param module: if given, the origin's ``__module__`` must start with it.
    """
    origin = get_origin(type_)
    if origin is None:
        return False
    if _get_type_name(origin) not in names:
        return False
    return module is None or origin.__module__.startswith(module)
607
608
def _get_type_name(type_: Type[Any]) -> str:
    """Return the name of a class or typing construct.

    From Python 3.10 ``__name__`` is read directly.  On older versions a
    typing construct may lack ``__name__``, so the private ``_name`` is
    used as a fallback (``None`` if neither attribute is present).
    """
    if compat.py310:
        return type_.__name__

    name = getattr(type_, "__name__", None)
    if name is not None:
        return name
    return getattr(type_, "_name", None)  # type: ignore
618
619
class DescriptorProto(Protocol):
    """structural type for an object implementing the full descriptor
    protocol (``__get__``, ``__set__`` and ``__delete__``); used as the
    bound of the ``_DESC`` / ``_DESC_co`` type variables"""

    def __get__(self, instance: object, owner: Any) -> Any: ...

    def __set__(self, instance: Any, value: Any) -> None: ...

    def __delete__(self, instance: Any) -> None: ...
626
627
# the descriptor type referred to by a DescriptorReference
_DESC = TypeVar("_DESC", bound=DescriptorProto)


class DescriptorReference(Generic[_DESC]):
    """a descriptor that refers to a descriptor.

    used for cases where we need to have an instance variable referring to an
    object that is itself a descriptor, which typically confuses typing tools
    as they don't know when they should use ``__get__`` or not when referring
    to the descriptor assignment as an instance variable. See
    sqlalchemy.orm.interfaces.PropComparator.prop

    """

    # declared for type checkers only; at runtime this class defines none
    # of these descriptor methods
    if TYPE_CHECKING:

        def __get__(self, instance: object, owner: Any) -> _DESC: ...

        def __set__(self, instance: Any, value: _DESC) -> None: ...

        def __delete__(self, instance: Any) -> None: ...
649
650
# covariant, so that subclasses may narrow the referenced descriptor type
_DESC_co = TypeVar("_DESC_co", bound=DescriptorProto, covariant=True)


class RODescriptorReference(Generic[_DESC_co]):
    """a descriptor that refers to a descriptor.

    same as :class:`.DescriptorReference` but is read-only, so that subclasses
    can define a subtype as the generically contained element

    """

    # declared for type checkers only; __set__ / __delete__ are typed as
    # NoReturn, which is what marks the reference as read-only to them
    if TYPE_CHECKING:

        def __get__(self, instance: object, owner: Any) -> _DESC_co: ...

        def __set__(self, instance: Any, value: Any) -> NoReturn: ...

        def __delete__(self, instance: Any) -> NoReturn: ...
669
670
# the referenced callable type; the bound admits None, so optional
# callables can be referenced as well
_FN = TypeVar("_FN", bound=Optional[Callable[..., Any]])


class CallableReference(Generic[_FN]):
    """a descriptor that refers to a callable.

    works around mypy's limitation of not allowing callables assigned
    as instance variables

    """

    # declared for type checkers only; at runtime this class defines none
    # of these descriptor methods
    if TYPE_CHECKING:

        def __get__(self, instance: object, owner: Any) -> _FN: ...

        def __set__(self, instance: Any, value: _FN) -> None: ...

        def __delete__(self, instance: Any) -> None: ...