1# util/typing.py
2# Copyright (C) 2022-2026 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: https://www.opensource.org/licenses/mit-license.php
7# mypy: allow-untyped-defs, allow-untyped-calls
8
9from __future__ import annotations
10
11import builtins
12from collections import deque
13import collections.abc as collections_abc
14import re
15import sys
16from types import NoneType
17import typing
18from typing import Any
19from typing import Callable
20from typing import Dict
21from typing import ForwardRef
22from typing import Generic
23from typing import get_args
24from typing import get_origin
25from typing import Iterable
26from typing import Literal
27from typing import Mapping
28from typing import NewType
29from typing import NoReturn
30from typing import Optional
31from typing import overload
32from typing import Protocol
33from typing import Set
34from typing import Tuple
35from typing import Type
36from typing import TYPE_CHECKING
37from typing import TypeGuard
38from typing import Union
39
40import typing_extensions
41
42if True: # zimports removes the tailing comments
43 from typing_extensions import (
44 dataclass_transform as dataclass_transform, # 3.11,
45 )
46 from typing_extensions import NotRequired as NotRequired # 3.11
47 from typing_extensions import TypeVarTuple as TypeVarTuple # 3.11
48 from typing_extensions import Self as Self # 3.11
49 from typing_extensions import TypeAliasType as TypeAliasType # 3.12
50 from typing_extensions import Unpack as Unpack # 3.11
51 from typing_extensions import Never as Never # 3.11
52 from typing_extensions import LiteralString as LiteralString # 3.11
53 from typing_extensions import TypeVar as TypeVar # 3.13 for default
54
55
# general purpose type variable, bound to ``Any``
_T = TypeVar("_T", bound=Any)
# key type variable; parametrizes SupportsKeysAndGetItem
_KT = TypeVar("_KT")
# invariant value type variable
_VT = TypeVar("_VT")
# covariant value type variable; parametrizes SupportsKeysAndGetItem
_VT_co = TypeVar("_VT_co", covariant=True)

# shorthand for a tuple of any length holding elements of any type
TupleAny = Tuple[Any, ...]
62
63
def is_fwd_none(typ: Any) -> bool:
    """Return True if ``typ`` is a ``ForwardRef`` whose referenced text is
    exactly the string ``"None"``.

    Anything that is not a ``ForwardRef`` instance, including the plain
    string ``"None"`` and ``None`` itself, yields False.

    """
    if not isinstance(typ, ForwardRef):
        return False
    return typ.__forward_arg__ == "None"
66
67
# the kinds of objects accepted where an annotation, or an element of an
# annotation, is scanned: plain classes, not-yet-evaluated strings and
# ForwardRef objects, NewType, PEP 695 type aliases and generic aliases
_AnnotationScanType = Union[
    Type[Any], str, ForwardRef, NewType, TypeAliasType, "GenericProtocol[Any]"
]

# same members as _AnnotationScanType minus the not-yet-evaluated forms
# (``str`` and ``ForwardRef``)
_MatchedOnType = Union[
    "GenericProtocol[Any]", TypeAliasType, NewType, Type[Any]
]
75
76
class ArgsTypeProtocol(Protocol):
    """Structural protocol for typing constructs that expose ``__args__``.

    There is no public interface for this in ``typing`` as far as is known,
    so the attribute is declared here for the benefit of type checkers.

    """

    # the type parameters of the construct, e.g. the members of a Union
    __args__: Tuple[_AnnotationScanType, ...]
85
86
class GenericProtocol(Protocol[_T]):
    """Structural protocol for generic alias types.

    This exists because ``typing._GenericAlias`` is private and therefore
    can't be referred to directly in annotations.

    """

    # the type parameters the alias was subscripted with
    __args__: Tuple[_AnnotationScanType, ...]
    # the unsubscripted class the alias was built from
    __origin__: Type[_T]

    # Python's builtin _GenericAlias has this method, however builtins like
    # list, dict, etc. do not, even though they have ``__origin__`` and
    # ``__args__``
    #
    # def copy_with(self, params: Tuple[_AnnotationScanType, ...]) -> Type[_T]:
    #     ...
103
104
# copied from TypeShed, required in order to implement
# MutableMapping.update()
class SupportsKeysAndGetItem(Protocol[_KT, _VT_co]):
    """Structural protocol for objects that offer ``keys()`` and
    ``__getitem__()``."""

    # an iterable of the keys the object can be indexed with
    def keys(self) -> Iterable[_KT]: ...

    # the value stored under the given key
    def __getitem__(self, __k: _KT) -> _VT_co: ...
111
112
# work around https://github.com/microsoft/pyright/issues/3025
# named alias for the literal string "*"
_LiteralStar = Literal["*"]
115
116
def de_stringify_annotation(
    cls: Type[Any],
    annotation: _AnnotationScanType,
    originating_module: str,
    locals_: Mapping[str, Any],
    *,
    str_cleanup_fn: Optional[Callable[[str, str], str]] = None,
    include_generic: bool = False,
    _already_seen: Optional[Set[Any]] = None,
) -> Type[Any]:
    """Turn a string-based annotation into the object it refers to.

    Modules that use ``from __future__ import annotations`` store every
    entry of ``__annotations__`` as a string, so at minimum the generic
    containers (``Mapped``, ``Union``, ``List``, etc.) have to be
    evaluated.

    :param cls: class in whose namespace the expression is evaluated.
    :param annotation: the annotation; a string, a ``ForwardRef`` or an
     already evaluated object.
    :param originating_module: name of the module, as present in
     ``sys.modules``, that supplies the globals for evaluation.
    :param locals_: local namespace passed along to the evaluation.
    :param str_cleanup_fn: optional callable receiving the string and the
     module name, returning the string that is actually evaluated.
    :param include_generic: when True, the arguments of generic
     annotations (other than ``Literal``) are resolved recursively as
     well.
    :param _already_seen: internal; generic annotations already visited
     during the recursion.
    :return: the evaluated annotation.

    """
    # typing.get_type_hints() and pydantic were both considered. Much less
    # is needed here, and private typing internals as well as construction
    # of ForwardRef objects (documented as something to avoid) are
    # deliberately not used.

    unresolved = annotation

    # ForwardRef('X') is handled like the plain string 'X'
    if is_fwd_ref(annotation):
        annotation = annotation.__forward_arg__

    if isinstance(annotation, str):
        if str_cleanup_fn:
            annotation = str_cleanup_fn(annotation, originating_module)

        annotation = eval_expression(
            annotation, originating_module, locals_=locals_, in_class=cls
        )

    if (
        include_generic
        and is_generic(annotation)
        and not is_literal(annotation)
    ):
        if _already_seen is None:
            _already_seen = set()

        if annotation in _already_seen:
            # this can only happen while recursing, so the outermost
            # return value is always a Type. what is handed back here is
            # either a ForwardRef or an Optional[ForwardRef]
            return unresolved  # type: ignore

        _already_seen.add(annotation)

        resolved_args = []
        for arg in annotation.__args__:
            resolved_args.append(
                de_stringify_annotation(
                    cls,
                    arg,
                    originating_module,
                    locals_,
                    str_cleanup_fn=str_cleanup_fn,
                    include_generic=include_generic,
                    _already_seen=_already_seen,
                )
            )

        return _copy_generic_annotation_with(annotation, tuple(resolved_args))

    return annotation  # type: ignore
186
187
def fixup_container_fwd_refs(
    type_: _AnnotationScanType,
) -> _AnnotationScanType:
    """Rewrite plain string arguments of builtin / ``collections.abc``
    container generics as ``ForwardRef`` objects.

    E.g. ``dict['x', 'y']`` becomes
    ``dict[ForwardRef('x'), ForwardRef('y')]``; the same applies to list,
    set and the abstract mapping / sequence / set containers. Any other
    type, including the ``typing.Dict`` style aliases, is returned
    unchanged.

    """

    if not is_generic(type_):
        return type_

    container = get_origin(type_)
    if container not in (
        dict,
        set,
        list,
        collections_abc.MutableSet,
        collections_abc.MutableMapping,
        collections_abc.MutableSequence,
        collections_abc.Mapping,
        collections_abc.Sequence,
    ):
        return type_

    # dict[] and typing.Dict[] DO NOT compare the same and DO NOT behave
    # the same, yet public attributes offer no way to tell which of the two
    # is present; the repr() is the only thing left to look at
    if re.match(
        "typing.(?:Dict|List|Set|.*Mapping|.*Sequence|.*Set)", repr(type_)
    ):
        return type_

    fixed_args = tuple(
        ForwardRef(arg) if isinstance(arg, str) else arg
        for arg in get_args(type_)
    )

    # compat with py3.10 and earlier
    return container.__class_getitem__(fixed_args)  # type: ignore
227
228
def _copy_generic_annotation_with(
    annotation: GenericProtocol[_T], elements: Tuple[_AnnotationScanType, ...]
) -> Type[_T]:
    """Return a generic like ``annotation`` but parametrized with
    ``elements``.

    :param annotation: a generic alias, e.g. ``List[int]`` or ``list[int]``
    :param elements: the replacement type arguments
    :return: the re-parametrized generic

    """
    if not hasattr(annotation, "copy_with"):
        # Python builtins list, dict, etc. lack copy_with(); subscript the
        # origin class again instead
        return annotation.__origin__[elements]  # type: ignore

    # List, Dict, etc. real generics
    return annotation.copy_with(elements)  # type: ignore
238
239
def eval_expression(
    expression: str,
    module_name: str,
    *,
    locals_: Optional[Mapping[str, Any]] = None,
    in_class: Optional[Type[Any]] = None,
) -> Any:
    """Evaluate ``expression`` using the globals of the named module.

    :param expression: Python expression, given as a string.
    :param module_name: key into ``sys.modules`` of the module whose
     ``__dict__`` supplies the global names.
    :param locals_: optional local namespace for the evaluation.
    :param in_class: optional class; when given, its attributes and its
     own name are made available too, with the module globals winning
     over same-named class attributes.
    :return: the value the expression evaluates to.
    :raises NameError: if the module isn't present in ``sys.modules``, or
     if evaluation fails with any exception.

    """
    try:
        module_globals: Dict[str, Any] = sys.modules[module_name].__dict__
    except KeyError as ke:
        raise NameError(
            f"Module {module_name} isn't present in sys.modules; can't "
            f"evaluate expression {expression}"
        ) from ke

    try:
        if in_class is None:
            namespace = module_globals
        else:
            namespace = dict(in_class.__dict__)
            namespace.setdefault(in_class.__name__, in_class)

            # see #10899. the locals/globals are meant to win over the
            # class namespace here, even though variables would not
            # normally resolve in that order.
            namespace.update(module_globals)

        result = eval(expression, namespace, locals_)
    except Exception as err:
        raise NameError(
            f"Could not de-stringify annotation {expression!r}"
        ) from err

    return result
274
275
def eval_name_only(
    name: str,
    module_name: str,
    *,
    locals_: Optional[Mapping[str, Any]] = None,
) -> Any:
    """Resolve ``name`` against the globals of the named module.

    A name containing a dot is handed to :func:`eval_expression`. A plain
    name is looked up in the module's globals and then in ``builtins``;
    ``locals_`` is only consulted for dotted names.

    :param name: the name to resolve.
    :param module_name: key into ``sys.modules`` of the module to search.
    :param locals_: optional local namespace, used for dotted names.
    :return: the object the name refers to.
    :raises NameError: if the module isn't present in ``sys.modules`` or
     the name can't be located.

    """
    if "." in name:
        return eval_expression(name, module_name, locals_=locals_)

    try:
        module_globals: Dict[str, Any] = sys.modules[module_name].__dict__
    except KeyError as ke:
        raise NameError(
            f"Module {module_name} isn't present in sys.modules; can't "
            f"resolve name {name}"
        ) from ke

    # for a bare name, a dictionary lookup is enough. eval() would work
    # perfectly fine as well, but this runs for every Mapper[] keyword,
    # etc. depending on configuration, so the faster route is taken
    try:
        return module_globals[name]
    except KeyError as ke:
        # fall back to builtins so that `list`, `set`, `dict`, etc. resolve
        builtin_names = builtins.__dict__
        if name in builtin_names:
            return builtin_names[name]

        raise NameError(
            f"Could not locate name {name} in module {module_name}"
        ) from ke
308
309
def resolve_name_to_real_class_name(name: str, module_name: str) -> str:
    """Return the ``__name__`` of the object that ``name`` refers to in the
    given module.

    ``name`` itself is returned when it can't be resolved, or when the
    resolved object has no ``__name__`` attribute.

    """
    try:
        resolved = eval_name_only(name, module_name)
    except NameError:
        return name

    return getattr(resolved, "__name__", name)
317
318
def is_pep593(type_: Optional[Any]) -> bool:
    """Return True if ``type_`` is a PEP 593 ``Annotated[...]`` construct.

    ``None`` is never considered to be one.

    """
    if type_ is None:
        return False
    return get_origin(type_) in _type_tuples.Annotated
321
322
323def is_non_string_iterable(obj: Any) -> TypeGuard[Iterable[Any]]:
324 return isinstance(obj, collections_abc.Iterable) and not isinstance(
325 obj, (str, bytes)
326 )
327
328
def is_literal(type_: Any) -> bool:
    """Return True if ``type_`` is a ``Literal[...]`` construct."""
    origin = get_origin(type_)
    return origin in _type_tuples.Literal
331
332
def is_newtype(type_: Optional[_AnnotationScanType]) -> TypeGuard[NewType]:
    """Return True if ``type_`` is an instance of ``NewType``."""
    newtype_classes = _type_tuples.NewType
    return isinstance(type_, newtype_classes)
335
336
337def is_generic(type_: _AnnotationScanType) -> TypeGuard[GenericProtocol[Any]]:
338 return hasattr(type_, "__args__") and hasattr(type_, "__origin__")
339
340
def is_pep695(type_: _AnnotationScanType) -> TypeGuard[TypeAliasType]:
    """Return True if ``type_`` is a PEP 695 ``TypeAliasType``, or a
    generic whose ``__origin__`` is one."""
    # NOTE: a subscripted (generic) alias is tested through its
    # ``__origin__`` rather than directly; for sqlalchemy use cases it's
    # fine to consider such an object a TypeAliasType as well.
    # NOTE: things seem to work without this additional step too
    if not is_generic(type_):
        return isinstance(type_, _type_instances.TypeAliasType)

    return is_pep695(type_.__origin__)
349
350
def pep695_values(type_: _AnnotationScanType) -> Set[Any]:
    """Return the set of types a ``TypeAliasType`` stands for.

    Unions found as the value of an alias, as well as aliases nested
    inside of them, are expanded recursively and flattened into one set.
    Within an expanded union, ``NoneType`` and ``ForwardRef("None")`` are
    reported as ``None``.

    Forward references are not evaluated, so no exploration happens into
    them.
    """
    visited = set()

    def expand(candidate):
        # returns either a single object, or a (possibly nested) list of
        # objects when a union was expanded
        if candidate in visited:
            # recursive aliases are not supported (pyright at least flags
            # them as an error); this only prevents an infinite loop
            return candidate
        visited.add(candidate)

        if not is_pep695(candidate):
            return candidate

        target = candidate.__value__
        if is_union(target):
            return [expand(member) for member in target.__args__]
        return target

    expanded = expand(type_)
    if not isinstance(expanded, list):
        return {expanded}

    # flatten the nested lists without recursion
    flattened = set()
    pending = deque(expanded)
    while pending:
        item = pending.popleft()
        if isinstance(item, list):
            pending.extend(item)
        elif item is NoneType or is_fwd_none(item):
            flattened.add(None)
        else:
            flattened.add(item)
    return flattened
386
387
@overload
def is_fwd_ref(
    type_: _AnnotationScanType,
    check_generic: bool = ...,
    check_for_plain_string: Literal[False] = ...,
) -> TypeGuard[ForwardRef]: ...


@overload
def is_fwd_ref(
    type_: _AnnotationScanType,
    check_generic: bool = ...,
    check_for_plain_string: bool = ...,
) -> TypeGuard[Union[str, ForwardRef]]: ...


def is_fwd_ref(
    type_: _AnnotationScanType,
    check_generic: bool = False,
    check_for_plain_string: bool = False,
) -> TypeGuard[Union[str, ForwardRef]]:
    """Return True if ``type_`` is a forward reference.

    :param type_: the object to test.
    :param check_generic: when True, a generic also counts if any of its
     arguments, at any nesting depth, is a forward reference.
    :param check_for_plain_string: when True, a plain ``str`` counts as a
     forward reference as well.

    """
    if check_for_plain_string and isinstance(type_, str):
        return True

    if isinstance(type_, _type_instances.ForwardRef):
        return True

    if not (check_generic and is_generic(type_)):
        return False

    # look through the arguments of the generic, recursively
    for arg in type_.__args__:
        if is_fwd_ref(
            arg, True, check_for_plain_string=check_for_plain_string
        ):
            return True
    return False
422
423
@overload
def de_optionalize_union_types(type_: str) -> str: ...


@overload
def de_optionalize_union_types(type_: Type[Any]) -> Type[Any]: ...


@overload
def de_optionalize_union_types(type_: _MatchedOnType) -> _MatchedOnType: ...


@overload
def de_optionalize_union_types(
    type_: _AnnotationScanType,
) -> _AnnotationScanType: ...


def de_optionalize_union_types(
    type_: _AnnotationScanType,
) -> _AnnotationScanType:
    """Given a type, filter out ``Union`` types that include ``NoneType``
    to not include the ``NoneType``.

    :param type_: the annotation to process. Forward references are
     handled on their string form, without evaluating them.
    :return: for a union that allows ``None``, a union of the remaining
     members in their original order (a single remaining member is
     returned as is, and ``Never`` is returned if no member remains); any
     other type is returned unchanged.

    """

    if is_fwd_ref(type_):
        return _de_optionalize_fwd_ref_union_types(type_, False)

    elif is_union(type_) and includes_none(type_):
        # a list rather than a set, so that the member order of the new
        # union is that of the original and not dependent on hash values
        remaining = [
            t
            for t in type_.__args__
            if t is not NoneType and not is_fwd_none(t)
        ]

        if not remaining:
            # every member was a form of None; a Union of no types can't
            # be constructed, so report Never as the forward ref variant
            # of this function does
            return Never  # type: ignore[return-value]

        return make_union_type(*remaining)

    else:
        return type_
464
465
@overload
def _de_optionalize_fwd_ref_union_types(
    type_: ForwardRef, return_has_none: Literal[True]
) -> bool: ...


@overload
def _de_optionalize_fwd_ref_union_types(
    type_: ForwardRef, return_has_none: Literal[False]
) -> _AnnotationScanType: ...


def _de_optionalize_fwd_ref_union_types(
    type_: ForwardRef, return_has_none: bool
) -> Union[_AnnotationScanType, bool]:
    """return the non-optional type for Optional[], Union[None, ...], x|None,
    etc. without de-stringifying forward refs.

    unfortunately this seems to require lots of hardcoded heuristics

    :param type_: the forward reference whose string form is inspected.
    :param return_has_none: when True, only report whether the annotation
     allows ``None`` instead of returning the stripped annotation.
    :return: a boolean if ``return_has_none`` is True; otherwise the
     annotation with ``None`` removed, ``Never`` if nothing remains, or
     ``type_`` itself when there was nothing to remove.

    """

    annotation = type_.__forward_arg__

    # "<name>[<args>]", where <name> may be module-qualified
    mm = re.match(r"^(.+?)\[(.+)\]$", annotation)
    if mm:
        g1 = mm.group(1).split(".")[-1]
        if g1 == "Optional":
            return True if return_has_none else ForwardRef(mm.group(2))
        elif g1 == "Union":
            if "[" in mm.group(2):
                # cases like "Union[Dict[str, int], int, None]"; split on
                # the commas that are not nested inside of brackets
                elements: list[str] = []
                current: list[str] = []
                ignore_comma = 0
                for char in mm.group(2):
                    if char == "[":
                        ignore_comma += 1
                    elif char == "]":
                        ignore_comma -= 1
                    elif ignore_comma == 0 and char == ",":
                        elements.append("".join(current).strip())
                        current.clear()
                        continue
                    current.append(char)

                # the last element has no comma after it, so it has to be
                # flushed here; otherwise it would be silently dropped
                trailing = "".join(current).strip()
                if trailing:
                    elements.append(trailing)
            else:
                elements = re.split(r",\s*", mm.group(2))
            parts = [ForwardRef(elem) for elem in elements if elem != "None"]
            if return_has_none:
                return len(elements) != len(parts)
            else:
                return make_union_type(*parts) if parts else Never  # type: ignore[return-value] # noqa: E501
        else:
            # some other generic; not looked into
            return False if return_has_none else type_

    # PEP 604 style "X | None"
    pipe_tokens = re.split(r"\s*\|\s*", annotation)
    has_none = "None" in pipe_tokens
    if return_has_none:
        return has_none
    if has_none:
        anno_str = "|".join(p for p in pipe_tokens if p != "None")
        return ForwardRef(anno_str) if anno_str else Never  # type: ignore[return-value] # noqa: E501

    return type_
530
531
def make_union_type(*types: _AnnotationScanType) -> Type[Any]:
    """Build a ``Union`` out of the given types.

    The types are passed on to ``Union[...]`` as a single tuple, so the
    usual ``Union`` rules apply to the result.

    """
    members: Tuple[_AnnotationScanType, ...] = types
    return Union[members]  # type: ignore
536
537
def includes_none(type_: Any) -> bool:
    """Returns if the type annotation ``type_`` allows ``None``.

    This function supports:
    * forward refs
    * unions
    * pep593 - Annotated
    * pep695 - TypeAliasType (does not support looking into
    fw reference of other pep695)
    * NewType
    * plain types like ``int``, ``None``, etc
    """
    if is_fwd_ref(type_):
        return _de_optionalize_fwd_ref_union_types(type_, True)

    if is_union(type_):
        for member in get_args(type_):
            if includes_none(member):
                return True
        return False

    if is_pep593(type_):
        # only the first argument of Annotated is the type
        annotated_type = get_args(type_)[0]
        return includes_none(annotated_type)

    if is_pep695(type_):
        for value in pep695_values(type_):
            if includes_none(value):
                return True
        return False

    if is_newtype(type_):
        return includes_none(type_.__supertype__)

    try:
        if type_ in (NoneType, None):
            return True
        return is_fwd_none(type_)
    except TypeError:
        # for Column, mapped_column(), etc. the "in" test goes through
        # ``__eq__()``, which produces an expression object that can't be
        # turned into a boolean. the exception covers all such cases
        return False
567
568
def is_a_type(type_: Any) -> bool:
    """Return True if ``type_`` looks like a type or a typing construct.

    That is the case for classes, for anything ``get_origin()`` reports an
    origin for, and for objects that either declare themselves as part of,
    or whose class is defined in, ``typing`` or ``typing_extensions``.

    """
    if isinstance(type_, type):
        return True
    if get_origin(type_) is not None:
        return True

    typing_modules = ("typing", "typing_extensions")
    if getattr(type_, "__module__", None) in typing_modules:
        return True
    return type(type_).__mro__[0].__module__ in typing_modules
577
578
def is_union(type_: Any) -> TypeGuard[ArgsTypeProtocol]:
    """Return True if the origin of ``type_`` is named ``Union`` or
    ``UnionType``."""
    union_names = ("Union", "UnionType")
    return is_origin_of(type_, *union_names)
581
582
def is_origin_of_cls(
    type_: Any, class_obj: Union[Tuple[Type[Any], ...], Type[Any]]
) -> bool:
    """return True if the given type has an origin that is a class and is
    a subclass of the given class, or of one of the given classes"""

    origin = get_origin(type_)
    if origin is None or not isinstance(origin, type):
        return False

    return issubclass(origin, class_obj)
594
595
596def is_origin_of(
597 type_: Any, *names: str, module: Optional[str] = None
598) -> bool:
599 """return True if the given type has an __origin__ with the given name
600 and optional module."""
601
602 origin = get_origin(type_)
603 if origin is None:
604 return False
605
606 return origin.__name__ in names and (
607 module is None or origin.__module__.startswith(module)
608 )
609
610
class DescriptorProto(Protocol):
    """Structural protocol for a data descriptor, i.e. an object that
    implements ``__get__``, ``__set__`` and ``__delete__``."""

    def __get__(self, instance: object, owner: Any) -> Any: ...

    def __set__(self, instance: Any, value: Any) -> None: ...

    def __delete__(self, instance: Any) -> None: ...
617
618
# the descriptor type that a DescriptorReference refers to
_DESC = TypeVar("_DESC", bound=DescriptorProto)


class DescriptorReference(Generic[_DESC]):
    """a descriptor that refers to a descriptor.

    used for cases where we need to have an instance variable referring to an
    object that is itself a descriptor, which typically confuses typing tools
    as they don't know when they should use ``__get__`` or not when referring
    to the descriptor assignment as an instance variable. See
    sqlalchemy.orm.interfaces.PropComparator.prop

    """

    # the methods below are declared for type checkers only; nothing is
    # defined on the class at runtime
    if TYPE_CHECKING:

        def __get__(self, instance: object, owner: Any) -> _DESC: ...

        def __set__(self, instance: Any, value: _DESC) -> None: ...

        def __delete__(self, instance: Any) -> None: ...
640
641
# covariant counterpart of the descriptor type variable, for the read-only
# reference
_DESC_co = TypeVar("_DESC_co", bound=DescriptorProto, covariant=True)


class RODescriptorReference(Generic[_DESC_co]):
    """a descriptor that refers to a descriptor.

    same as :class:`.DescriptorReference` but is read-only, so that subclasses
    can define a subtype as the generically contained element

    """

    # declared for type checkers only; set and delete are typed as
    # ``NoReturn``
    if TYPE_CHECKING:

        def __get__(self, instance: object, owner: Any) -> _DESC_co: ...

        def __set__(self, instance: Any, value: Any) -> NoReturn: ...

        def __delete__(self, instance: Any) -> NoReturn: ...
660
661
# the callable type (or None) that a CallableReference refers to
_FN = TypeVar("_FN", bound=Optional[Callable[..., Any]])


class CallableReference(Generic[_FN]):
    """a descriptor that refers to a callable.

    works around mypy's limitation of not allowing callables assigned
    as instance variables


    """

    # declared for type checkers only; nothing is defined on the class at
    # runtime
    if TYPE_CHECKING:

        def __get__(self, instance: object, owner: Any) -> _FN: ...

        def __set__(self, instance: Any, value: _FN) -> None: ...

        def __delete__(self, instance: Any) -> None: ...
681
682
class _TypingInstances:
    """Attribute-based lookup of typing constructs by name.

    Accessing an attribute collects the objects of that name from both
    ``typing`` and ``typing_extensions`` into a tuple with duplicates
    removed. The tuple is stored in the instance ``__dict__``, so that
    ``__getattr__`` is not invoked again for the same name.

    """

    def __getattr__(self, key: str) -> tuple[type, ...]:
        found = set()
        for source in (typing, typing_extensions):
            candidate = getattr(source, key, None)
            if candidate is not None:
                found.add(candidate)

        if not found:
            # neither module knows the name
            raise AttributeError(key)

        resolved = tuple(found)
        self.__dict__[key] = resolved
        return resolved
699
700
# lookup object whose attributes are tuples holding the ``typing`` and / or
# ``typing_extensions`` object of the requested name
_type_tuples = _TypingInstances()
if TYPE_CHECKING:
    # type checkers see the names as those of typing_extensions
    _type_instances = typing_extensions
else:
    # at runtime the same lookup object is used under both names
    _type_instances = _type_tuples

# tuple of the ``Literal`` objects found by the lookup above
LITERAL_TYPES = _type_tuples.Literal