1# util/typing.py
2# Copyright (C) 2022-2025 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: https://www.opensource.org/licenses/mit-license.php
7# mypy: allow-untyped-defs, allow-untyped-calls
8
9from __future__ import annotations
10
11import builtins
12from collections import deque
13import collections.abc as collections_abc
14import re
15import sys
16from types import NoneType
17import typing
18from typing import Any
19from typing import Callable
20from typing import Dict
21from typing import ForwardRef
22from typing import Generic
23from typing import get_args
24from typing import get_origin
25from typing import Iterable
26from typing import Literal
27from typing import Mapping
28from typing import NewType
29from typing import NoReturn
30from typing import Optional
31from typing import overload
32from typing import Protocol
33from typing import Set
34from typing import Tuple
35from typing import Type
36from typing import TYPE_CHECKING
37from typing import TypeGuard
38from typing import TypeVar
39from typing import Union
40
41import typing_extensions
42
43from . import compat
44
45if True: # zimports removes the tailing comments
46 from typing_extensions import (
47 dataclass_transform as dataclass_transform, # 3.11,
48 )
49 from typing_extensions import NotRequired as NotRequired # 3.11
50 from typing_extensions import TypeVarTuple as TypeVarTuple # 3.11
51 from typing_extensions import Self as Self # 3.11
52 from typing_extensions import TypeAliasType as TypeAliasType # 3.12
53 from typing_extensions import Unpack as Unpack # 3.11
54 from typing_extensions import Never as Never # 3.11
55 from typing_extensions import LiteralString as LiteralString # 3.11
56
57
# generic "anything" TypeVar; parameterizes GenericProtocol in this module
_T = TypeVar("_T", bound=Any)

# key / value TypeVars in invariant, covariant and contravariant flavors;
# _KT and _VT_co parameterize SupportsKeysAndGetItem in this module
_KT = TypeVar("_KT")
_KT_co = TypeVar("_KT_co", covariant=True)
_KT_contra = TypeVar("_KT_contra", contravariant=True)
_VT = TypeVar("_VT")
_VT_co = TypeVar("_VT_co", covariant=True)

# a tuple of any length whose elements are of any type
TupleAny = Tuple[Any, ...]
66
67
def is_fwd_none(typ: Any) -> bool:
    """Return True if ``typ`` is a ``ForwardRef`` spelling out ``"None"``.

    Anything that is not a ``ForwardRef`` instance yields False.

    """
    if not isinstance(typ, ForwardRef):
        return False
    return typ.__forward_arg__ == "None"
70
71
# the forms an annotation may take while it is being scanned: a class, a
# plain string, a ForwardRef, a NewType, a TypeAliasType, or a generic alias.
# GenericProtocol is referenced by string as it is declared further down in
# this module.
_AnnotationScanType = Union[
    Type[Any], str, ForwardRef, NewType, TypeAliasType, "GenericProtocol[Any]"
]

# same as _AnnotationScanType, minus the string / ForwardRef forms
_MatchedOnType = Union[
    "GenericProtocol[Any]", TypeAliasType, NewType, Type[Any]
]
79
80
class ArgsTypeProtocol(Protocol):
    """Structural protocol for typing objects that expose ``__args__``.

    There's no public interface for this AFAIK, so the attribute is declared
    here to give type checkers something to narrow to (see ``is_union()``,
    which uses this protocol as its ``TypeGuard`` target).

    """

    # the type arguments carried by the object
    __args__: Tuple[_AnnotationScanType, ...]
89
90
class GenericProtocol(Protocol[_T]):
    """Structural protocol for generic alias objects, i.e. objects that
    expose both ``__args__`` and ``__origin__``.

    This is needed since Python's ``typing._GenericAlias`` is private.

    """

    # the type arguments the alias was parameterized with
    __args__: Tuple[_AnnotationScanType, ...]
    # the un-parameterized class the alias refers to
    __origin__: Type[_T]

    # Python's builtin _GenericAlias has this method, however builtins like
    # list, dict, etc. do not, even though they have ``__origin__`` and
    # ``__args__``
    #
    # def copy_with(self, params: Tuple[_AnnotationScanType, ...]) -> Type[_T]:
    #     ...
107
108
# copied from TypeShed, required in order to implement
# MutableMapping.update()
class SupportsKeysAndGetItem(Protocol[_KT, _VT_co]):
    """Structural protocol for objects offering ``keys()`` plus item access.

    Any object with a ``keys()`` method returning an iterable of keys and a
    ``__getitem__()`` accepting those keys matches this protocol.

    """

    def keys(self) -> Iterable[_KT]: ...

    def __getitem__(self, __k: _KT) -> _VT_co: ...
115
116
# module-level alias for the literal type of the string "*"; spelled out as
# a named alias to work around
# https://github.com/microsoft/pyright/issues/3025
_LiteralStar = Literal["*"]
119
120
def de_stringify_annotation(
    cls: Type[Any],
    annotation: _AnnotationScanType,
    originating_module: str,
    locals_: Mapping[str, Any],
    *,
    str_cleanup_fn: Optional[Callable[[str, str], str]] = None,
    include_generic: bool = False,
    _already_seen: Optional[Set[Any]] = None,
) -> Type[Any]:
    """Turn a possibly string-based annotation into the real object.

    Modules that use ``from __future__ import annotations`` store every
    entry of ``__annotations__`` as a string, so at minimum the generic
    containers such as ``Mapped``, ``Union``, ``List`` etc. need to be
    evaluated back into objects.

    :param cls: class the annotation belongs to; its namespace takes part
     in the evaluation of string annotations.
    :param annotation: the annotation; a string, a ``ForwardRef`` or an
     already-evaluated object.
    :param originating_module: name of the module, as present in
     ``sys.modules``, whose globals are used for evaluation.
    :param locals_: additional names made available to the evaluation.
    :param str_cleanup_fn: optional callable receiving the annotation string
     and the module name, returning the string that is actually evaluated.
    :param include_generic: when True, the arguments of generic (non
     ``Literal``) annotations are de-stringified recursively as well.
    :param _already_seen: internal; generic annotations visited so far, used
     to stop on recursive annotations.
    :return: the evaluated annotation.

    """
    # typing.get_type_hints() and pydantic were both considered.  Much less
    # is needed here, and the goal is to stay away from private typing
    # internals and from constructing ForwardRef objects, which is
    # documented as something to be avoided.

    unresolved = annotation

    if is_fwd_ref(annotation):
        annotation = annotation.__forward_arg__

    if isinstance(annotation, str):
        source = (
            str_cleanup_fn(annotation, originating_module)
            if str_cleanup_fn
            else annotation
        )
        annotation = eval_expression(
            source, originating_module, locals_=locals_, in_class=cls
        )

    if (
        include_generic
        and is_generic(annotation)
        and not is_literal(annotation)
    ):
        if _already_seen is None:
            _already_seen = set()

        if annotation in _already_seen:
            # only reachable from a recursive call, so the outermost return
            # value is still always a Type.  What is handed back here is
            # either a ForwardRef or an Optional[ForwardRef]
            return unresolved  # type: ignore

        _already_seen.add(annotation)

        resolved_args = []
        for arg in annotation.__args__:
            resolved_args.append(
                de_stringify_annotation(
                    cls,
                    arg,
                    originating_module,
                    locals_,
                    str_cleanup_fn=str_cleanup_fn,
                    include_generic=include_generic,
                    _already_seen=_already_seen,
                )
            )

        return _copy_generic_annotation_with(annotation, tuple(resolved_args))

    return annotation  # type: ignore
190
191
def fixup_container_fwd_refs(
    type_: _AnnotationScanType,
) -> _AnnotationScanType:
    """Correct dict['x', 'y'] into dict[ForwardRef('x'), ForwardRef('y')]
    and similar for list, set

    Only generic aliases whose origin is one of the handled container
    classes, and whose ``repr()`` does not look like one of the ``typing``
    module spellings, are rebuilt; anything else is returned unchanged.

    """
    if not is_generic(type_):
        return type_

    container = get_origin(type_)
    if container not in (
        dict,
        set,
        list,
        collections_abc.MutableSet,
        collections_abc.MutableMapping,
        collections_abc.MutableSequence,
        collections_abc.Mapping,
        collections_abc.Sequence,
    ):
        return type_

    # dict[] and typing.Dict[] DO NOT compare the same and DO NOT behave the
    # same, yet public attributes offer NO WAY to tell which of the two we
    # were given; looking at the repr() is what is left
    if re.match(
        "typing.(?:Dict|List|Set|.*Mapping|.*Sequence|.*Set)", repr(type_)
    ):
        return type_

    converted = tuple(
        ForwardRef(arg) if isinstance(arg, str) else arg
        for arg in get_args(type_)
    )
    # compat with py3.10 and earlier
    return container.__class_getitem__(converted)  # type: ignore
231
232
def _copy_generic_annotation_with(
    annotation: GenericProtocol[_T], elements: Tuple[_AnnotationScanType, ...]
) -> Type[_T]:
    """Return a copy of a generic annotation using ``elements`` as its
    arguments.

    :param annotation: generic alias to copy.
    :param elements: the arguments for the new alias.
    :return: the re-parameterized alias.

    """
    if not hasattr(annotation, "copy_with"):
        # Python builtins list, dict, etc.; these have no copy_with(), so
        # the origin class is subscripted again
        return annotation.__origin__[elements]  # type: ignore

    # List, Dict, etc. real generics
    return annotation.copy_with(elements)  # type: ignore
242
243
def eval_expression(
    expression: str,
    module_name: str,
    *,
    locals_: Optional[Mapping[str, Any]] = None,
    in_class: Optional[Type[Any]] = None,
) -> Any:
    """Evaluate ``expression`` against the globals of a loaded module.

    NOTE: this runs ``eval()``; the expression is expected to be an
    annotation string taken from application source code.

    :param expression: Python expression to evaluate.
    :param module_name: name of a module present in ``sys.modules``; its
     ``__dict__`` supplies the globals.
    :param locals_: optional mapping passed to ``eval()`` as the locals.
    :param in_class: optional class whose namespace, plus the class itself
     under its ``__name__``, is made available underneath the module
     globals.
    :return: the result of the evaluation.
    :raises NameError: if the module is not in ``sys.modules`` or if the
     evaluation fails for any reason.

    """
    try:
        module_globals: Dict[str, Any] = sys.modules[module_name].__dict__
    except KeyError as ke:
        raise NameError(
            f"Module {module_name} isn't present in sys.modules; can't "
            f"evaluate expression {expression}"
        ) from ke

    try:
        if in_class is None:
            namespace = module_globals
        else:
            class_ns = {**in_class.__dict__}
            class_ns.setdefault(in_class.__name__, in_class)

            # see #10899.  The locals/globals are meant to win over the
            # class namespace here, even though that is not how names
            # normally resolve, hence the module globals are layered on top.
            namespace = {**class_ns, **module_globals}

        result = eval(expression, namespace, locals_)
    except Exception as err:
        raise NameError(
            f"Could not de-stringify annotation {expression!r}"
        ) from err

    return result
278
279
def eval_name_only(
    name: str,
    module_name: str,
    *,
    locals_: Optional[Mapping[str, Any]] = None,
) -> Any:
    """Resolve a name against a loaded module's globals, then builtins.

    A dotted ``name`` is handed to :func:`eval_expression`; a plain name is
    looked up directly without ``eval()``.

    :param name: the name to resolve, optionally dotted.
    :param module_name: name of a module present in ``sys.modules``.
    :param locals_: only consulted for dotted names, where it is passed
     along to :func:`eval_expression`.
    :return: the object the name refers to.
    :raises NameError: if the module is not in ``sys.modules`` or the name
     can't be located.

    """
    if "." in name:
        return eval_expression(name, module_name, locals_=locals_)

    try:
        module_globals: Dict[str, Any] = sys.modules[module_name].__dict__
    except KeyError as ke:
        raise NameError(
            f"Module {module_name} isn't present in sys.modules; can't "
            f"resolve name {name}"
        ) from ke

    # for a bare name, eval() would work perfectly fine; a direct lookup
    # is used since this needs to be fast, as it occurs for every Mapper[]
    # keyword, etc. depending on configuration
    try:
        return module_globals[name]
    except KeyError as ke:
        # fall back to builtins so that `list`, `set`, `dict`, etc. resolve
        builtin_names = builtins.__dict__
        if name in builtin_names:
            return builtin_names[name]

        raise NameError(
            f"Could not locate name {name} in module {module_name}"
        ) from ke
312
313
def resolve_name_to_real_class_name(name: str, module_name: str) -> str:
    """Return the ``__name__`` of the object that ``name`` refers to.

    ``name`` itself is returned when it can't be resolved in the given
    module, or when the resolved object has no ``__name__`` attribute.

    """
    try:
        resolved = eval_name_only(name, module_name)
    except NameError:
        return name

    return getattr(resolved, "__name__", name)
321
322
def is_pep593(type_: Optional[Any]) -> bool:
    """Return True if ``type_`` is a pep-593 ``Annotated[...]`` construct.

    ``None`` is never considered to be one.

    """
    if type_ is None:
        return False
    return get_origin(type_) in _type_tuples.Annotated
325
326
327def is_non_string_iterable(obj: Any) -> TypeGuard[Iterable[Any]]:
328 return isinstance(obj, collections_abc.Iterable) and not isinstance(
329 obj, (str, bytes)
330 )
331
332
def is_literal(type_: Any) -> bool:
    """Return True if ``type_`` is a ``Literal[...]`` construct."""
    origin = get_origin(type_)
    return origin in _type_tuples.Literal
335
336
def is_newtype(type_: Optional[_AnnotationScanType]) -> TypeGuard[NewType]:
    """Return True if ``type_`` is an instance of ``NewType``."""
    newtype_classes = _type_tuples.NewType
    return isinstance(type_, newtype_classes)
339
340
341def is_generic(type_: _AnnotationScanType) -> TypeGuard[GenericProtocol[Any]]:
342 return hasattr(type_, "__args__") and hasattr(type_, "__origin__")
343
344
def is_pep695(type_: _AnnotationScanType) -> TypeGuard[TypeAliasType]:
    """Return True if ``type_`` is a pep-695 ``TypeAliasType``, or a generic
    alias whose (innermost) origin is one."""
    # NOTE: a generic TAT does not instance check as TypeAliasType outside of
    # python 3.10. For sqlalchemy use cases it's fine to consider it a TAT
    # though.
    # NOTE: things seems to work also without this additional check
    candidate: Any = type_
    while is_generic(candidate):
        candidate = candidate.__origin__
    return isinstance(candidate, _type_instances.TypeAliasType)
353
354
def pep695_values(type_: _AnnotationScanType) -> Set[Any]:
    """Return the value(s) of a TypeAliasType as a flat set.

    Unions found as the value of a TypeAliasType are expanded, and
    TypeAliasType members inside of those unions are expanded in turn.
    When the outermost value is a union, ``NoneType`` and ``None`` forward
    references are reported as ``None``.

    Forward references are not evaluated, so no recursive exploration happens
    into them.
    """
    visited = set()

    def expand(candidate):
        # yields either a single object or a (possibly nested) list
        if candidate in visited:
            # recursive aliases are not supported (pyright at least flags
            # them as an error); just make sure we don't loop forever
            return candidate
        visited.add(candidate)

        if not is_pep695(candidate):
            return candidate

        aliased = candidate.__value__
        if is_union(aliased):
            return [expand(member) for member in aliased.__args__]
        return aliased

    expanded = expand(type_)
    if not isinstance(expanded, list):
        return {expanded}

    flattened = set()
    pending = deque(expanded)
    while pending:
        item = pending.popleft()
        if isinstance(item, list):
            pending.extend(item)
        elif item is NoneType or is_fwd_none(item):
            flattened.add(None)
        else:
            flattened.add(item)
    return flattened
390
391
@overload
def is_fwd_ref(
    type_: _AnnotationScanType,
    check_generic: bool = ...,
    check_for_plain_string: Literal[False] = ...,
) -> TypeGuard[ForwardRef]: ...


@overload
def is_fwd_ref(
    type_: _AnnotationScanType,
    check_generic: bool = ...,
    check_for_plain_string: bool = ...,
) -> TypeGuard[Union[str, ForwardRef]]: ...


def is_fwd_ref(
    type_: _AnnotationScanType,
    check_generic: bool = False,
    check_for_plain_string: bool = False,
) -> TypeGuard[Union[str, ForwardRef]]:
    """Return True if ``type_`` is a forward reference.

    :param type_: the annotation to test.
    :param check_generic: when True, a generic alias also counts if any of
     its arguments, searched recursively, is a forward reference.
    :param check_for_plain_string: when True, a plain ``str`` counts as a
     forward reference as well.

    """
    if check_for_plain_string and isinstance(type_, str):
        return True

    if isinstance(type_, _type_instances.ForwardRef):
        return True

    if check_generic and is_generic(type_):
        for arg in type_.__args__:
            if is_fwd_ref(
                arg, True, check_for_plain_string=check_for_plain_string
            ):
                return True

    return False
426
427
@overload
def de_optionalize_union_types(type_: str) -> str: ...


@overload
def de_optionalize_union_types(type_: Type[Any]) -> Type[Any]: ...


@overload
def de_optionalize_union_types(type_: _MatchedOnType) -> _MatchedOnType: ...


@overload
def de_optionalize_union_types(
    type_: _AnnotationScanType,
) -> _AnnotationScanType: ...


def de_optionalize_union_types(
    type_: _AnnotationScanType,
) -> _AnnotationScanType:
    """Given a type, filter out ``Union`` types that include ``NoneType``
    to not include the ``NoneType``.

    :param type_: the annotation to process.  Forward references are
     handled on their string form, without being evaluated.
    :return: for a union that includes ``None``, a union of the remaining
     members in their original order, or ``Never`` if nothing remains;
     any other type is returned unchanged.

    """

    if is_fwd_ref(type_):
        return _de_optionalize_fwd_ref_union_types(type_, False)

    elif is_union(type_) and includes_none(type_):
        # collect into a tuple rather than a set so that the members keep
        # their declared order; a set would hand them to Union[] in
        # hash order, which is not stable from one run to the next.  Union[]
        # removes duplicate members on its own.
        remaining = tuple(
            t
            for t in type_.__args__
            if t is not NoneType and not is_fwd_none(t)
        )

        if not remaining:
            # every member was a spelling of None; Union[] can't be built
            # from zero members, so report Never, same as the forward
            # reference variant of this function does
            return Never  # type: ignore[return-value]

        return make_union_type(*remaining)

    else:
        return type_
468
469
@overload
def _de_optionalize_fwd_ref_union_types(
    type_: ForwardRef, return_has_none: Literal[True]
) -> bool: ...


@overload
def _de_optionalize_fwd_ref_union_types(
    type_: ForwardRef, return_has_none: Literal[False]
) -> _AnnotationScanType: ...


def _de_optionalize_fwd_ref_union_types(
    type_: ForwardRef, return_has_none: bool
) -> Union[_AnnotationScanType, bool]:
    """return the non-optional type for Optional[], Union[None, ...], x|None,
    etc. without de-stringifying forward refs.

    unfortunately this seems to require lots of hardcoded heuristics

    :param type_: the forward reference; only its string is inspected.
    :param return_has_none: when True, return a boolean indicating if the
     annotation string includes ``None``, instead of a type.
    :return: with ``return_has_none=True``, a boolean.  Otherwise a
     ``ForwardRef`` / union of ``ForwardRef`` for the annotation with
     ``None`` removed, ``Never`` if nothing but ``None`` was present, or
     ``type_`` itself if there was nothing to remove.

    """

    annotation = type_.__forward_arg__

    # "<name>[<contents>]"; the name may be module qualified
    mm = re.match(r"^(.+?)\[(.+)\]$", annotation)
    if mm:
        g1 = mm.group(1).split(".")[-1]
        if g1 == "Optional":
            return True if return_has_none else ForwardRef(mm.group(2))
        elif g1 == "Union":
            if "[" in mm.group(2):
                # cases like "Union[Dict[str, int], int, None]"; split on
                # the commas that are not nested inside of brackets
                elements: list[str] = []
                current: list[str] = []
                ignore_comma = 0
                for char in mm.group(2):
                    if char == "[":
                        ignore_comma += 1
                    elif char == "]":
                        ignore_comma -= 1
                    elif ignore_comma == 0 and char == ",":
                        elements.append("".join(current).strip())
                        current.clear()
                        continue
                    current.append(char)

                # the text after the last top-level comma is an element
                # too; without this the final member of the union is lost.
                # nothing is added for a trailing comma
                last_element = "".join(current).strip()
                if last_element:
                    elements.append(last_element)
            else:
                elements = re.split(r",\s*", mm.group(2))
            parts = [ForwardRef(elem) for elem in elements if elem != "None"]
            if return_has_none:
                # "None" elements are the only ones dropped from ``parts``
                return len(elements) != len(parts)
            else:
                return make_union_type(*parts) if parts else Never  # type: ignore[return-value] # noqa: E501
        else:
            # some other subscripted name, e.g. "List[int]"
            return False if return_has_none else type_

    # pep-604 style, "X | Y | None"
    pipe_tokens = re.split(r"\s*\|\s*", annotation)
    has_none = "None" in pipe_tokens
    if return_has_none:
        return has_none
    if has_none:
        anno_str = "|".join(p for p in pipe_tokens if p != "None")
        return ForwardRef(anno_str) if anno_str else Never  # type: ignore[return-value] # noqa: E501

    return type_
534
535
def make_union_type(*types: _AnnotationScanType) -> Type[Any]:
    """Make a Union type.

    The given types are passed as a single tuple to ``Union[...]``.

    """
    members = types
    return Union[members]  # type: ignore
540
541
def includes_none(type_: Any) -> bool:
    """Returns if the type annotation ``type_`` allows ``None``.

    This function supports:
    * forward refs
    * unions
    * pep593 - Annotated
    * pep695 - TypeAliasType (does not support looking into
    fw reference of other pep695)
    * NewType
    * plain types like ``int``, ``None``, etc
    """
    if is_fwd_ref(type_):
        return _de_optionalize_fwd_ref_union_types(type_, True)

    if is_union(type_):
        for member in get_args(type_):
            if includes_none(member):
                return True
        return False

    if is_pep593(type_):
        # only the annotated type itself matters, not the metadata
        annotated_type = get_args(type_)[0]
        return includes_none(annotated_type)

    if is_pep695(type_):
        for value in pep695_values(type_):
            if includes_none(value):
                return True
        return False

    if is_newtype(type_):
        return includes_none(type_.__supertype__)

    try:
        return type_ in (NoneType, None) or is_fwd_none(type_)
    except TypeError:
        # when type_ is a Column, mapped_column(), etc., "in" goes through
        # ``__eq__()``, producing an expression object that refuses to be
        # evaluated as a boolean.  The exception covers all such cases
        return False
571
572
def is_a_type(type_: Any) -> bool:
    """Return True if ``type_`` looks like a type or a typing construct.

    That is: a class, anything ``get_origin()`` reports an origin for, or
    an object which itself, or whose class, reports ``typing`` /
    ``typing_extensions`` as its module.

    """
    if isinstance(type_, type) or get_origin(type_) is not None:
        return True

    typing_modules = ("typing", "typing_extensions")
    if getattr(type_, "__module__", None) in typing_modules:
        return True

    return type(type_).__mro__[0].__module__ in typing_modules
581
582
def is_union(type_: Any) -> TypeGuard[ArgsTypeProtocol]:
    """Return True if the origin of ``type_`` is named ``Union`` or
    ``UnionType``."""
    union_names = ("Union", "UnionType")
    return is_origin_of(type_, *union_names)
585
586
def is_origin_of_cls(
    type_: Any, class_obj: Union[Tuple[Type[Any], ...], Type[Any]]
) -> bool:
    """return True if the given type has an origin, as reported by
    ``get_origin()``, that is a class and is a subclass of the given class
    or of any class in the given tuple"""

    base = get_origin(type_)
    if base is None or not isinstance(base, type):
        return False

    return issubclass(base, class_obj)
598
599
600def is_origin_of(
601 type_: Any, *names: str, module: Optional[str] = None
602) -> bool:
603 """return True if the given type has an __origin__ with the given name
604 and optional module."""
605
606 origin = get_origin(type_)
607 if origin is None:
608 return False
609
610 return origin.__name__ in names and (
611 module is None or origin.__module__.startswith(module)
612 )
613
614
class DescriptorProto(Protocol):
    """Structural protocol for a full descriptor, i.e. an object that
    implements ``__get__()``, ``__set__()`` and ``__delete__()``.

    Used as the bound of the TypeVars that parameterize
    ``DescriptorReference`` and ``RODescriptorReference``.

    """

    def __get__(self, instance: object, owner: Any) -> Any: ...

    def __set__(self, instance: Any, value: Any) -> None: ...

    def __delete__(self, instance: Any) -> None: ...
621
622
# the descriptor type that a DescriptorReference refers to
_DESC = TypeVar("_DESC", bound=DescriptorProto)


class DescriptorReference(Generic[_DESC]):
    """a descriptor that refers to a descriptor.

    used for cases where we need to have an instance variable referring to an
    object that is itself a descriptor, which typically confuses typing tools
    as they don't know when they should use ``__get__`` or not when referring
    to the descriptor assignment as an instance variable. See
    sqlalchemy.orm.interfaces.PropComparator.prop

    """

    # the descriptor methods are declared for type checkers only; outside of
    # type checking this class body defines no methods
    if TYPE_CHECKING:

        def __get__(self, instance: object, owner: Any) -> _DESC: ...

        def __set__(self, instance: Any, value: _DESC) -> None: ...

        def __delete__(self, instance: Any) -> None: ...
644
645
# covariant counterpart of _DESC, for the read-only reference
_DESC_co = TypeVar("_DESC_co", bound=DescriptorProto, covariant=True)


class RODescriptorReference(Generic[_DESC_co]):
    """a descriptor that refers to a descriptor.

    same as :class:`.DescriptorReference` but is read-only, so that subclasses
    can define a subtype as the generically contained element

    """

    # declared for type checkers only; set and delete are typed as
    # NoReturn, which is how the "read-only" nature is expressed
    if TYPE_CHECKING:

        def __get__(self, instance: object, owner: Any) -> _DESC_co: ...

        def __set__(self, instance: Any, value: Any) -> NoReturn: ...

        def __delete__(self, instance: Any) -> NoReturn: ...
664
665
# the callable type that a CallableReference refers to; may also be None
_FN = TypeVar("_FN", bound=Optional[Callable[..., Any]])


class CallableReference(Generic[_FN]):
    """a descriptor that refers to a callable.

    works around mypy's limitation of not allowing callables assigned
    as instance variables


    """

    # the descriptor methods are declared for type checkers only; outside of
    # type checking this class body defines no methods
    if TYPE_CHECKING:

        def __get__(self, instance: object, owner: Any) -> _FN: ...

        def __set__(self, instance: Any, value: _FN) -> None: ...

        def __delete__(self, instance: Any) -> None: ...
685
686
class _TypingInstances:
    """Lazy lookup of a name in both ``typing`` and ``typing_extensions``.

    Attribute access returns a tuple of the distinct objects found under
    that name in the two modules; the tuple is stored on the instance so
    that ``__getattr__()`` runs only once per name.

    """

    def __getattr__(self, key: str) -> tuple[type, ...]:
        """Return the objects named ``key``, raising ``AttributeError`` if
        neither module has that name."""
        candidates = [
            getattr(typing, key, None),
            getattr(typing_extensions, key, None),
        ]

        # a set, as both modules may expose the very same object
        found = set()
        for candidate in candidates:
            if candidate is not None:
                found.add(candidate)

        if not found:
            raise AttributeError(key)

        types = tuple(found)
        self.__dict__[key] = types
        return types
703
704
# attribute access on this object gives a tuple of the typing and
# typing_extensions objects of that name, for use with isinstance() and
# "in" checks
_type_tuples = _TypingInstances()

# same object at runtime; type checkers however are told that this is the
# typing_extensions module, so that isinstance(x, _type_instances.<name>)
# both type checks and narrows
if TYPE_CHECKING:
    _type_instances = typing_extensions
else:
    _type_instances = _type_tuples

# the Literal construct(s) found in typing / typing_extensions, as a tuple
LITERAL_TYPES = _type_tuples.Literal