1# util/typing.py
2# Copyright (C) 2022-2024 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: https://www.opensource.org/licenses/mit-license.php
7# mypy: allow-untyped-defs, allow-untyped-calls
8
9from __future__ import annotations
10
11import builtins
12import collections.abc as collections_abc
13import re
14import sys
15from typing import Any
16from typing import Callable
17from typing import cast
18from typing import Dict
19from typing import ForwardRef
20from typing import Generic
21from typing import Iterable
22from typing import Mapping
23from typing import NewType
24from typing import NoReturn
25from typing import Optional
26from typing import overload
27from typing import Protocol
28from typing import Set
29from typing import Tuple
30from typing import Type
31from typing import TYPE_CHECKING
32from typing import TypeVar
33from typing import Union
34
35from . import compat
36
37if True: # zimports removes the tailing comments
38 from typing_extensions import Annotated as Annotated # 3.9
39 from typing_extensions import Concatenate as Concatenate # 3.10
40 from typing_extensions import (
41 dataclass_transform as dataclass_transform, # 3.11,
42 )
43 from typing_extensions import get_args as get_args # 3.10
44 from typing_extensions import get_origin as get_origin # 3.10
45 from typing_extensions import (
46 Literal as Literal,
47 ) # 3.8 but has bugs before 3.10
48 from typing_extensions import NotRequired as NotRequired # 3.11
49 from typing_extensions import ParamSpec as ParamSpec # 3.10
50 from typing_extensions import TypeAlias as TypeAlias # 3.10
51 from typing_extensions import TypeGuard as TypeGuard # 3.10
52 from typing_extensions import TypeVarTuple as TypeVarTuple # 3.11
53 from typing_extensions import Self as Self # 3.11
54 from typing_extensions import TypeAliasType as TypeAliasType # 3.12
55 from typing_extensions import Unpack as Unpack # 3.11
56
57
# module-level type variables.  ``_KT`` / ``_VT_co`` are the key / value
# parameters of SupportsKeysAndGetItem below; the ``_co`` / ``_contra``
# variants are declared covariant / contravariant respectively.
_T = TypeVar("_T", bound=Any)
_KT = TypeVar("_KT")
_KT_co = TypeVar("_KT_co", covariant=True)
_KT_contra = TypeVar("_KT_contra", contravariant=True)
_VT = TypeVar("_VT")
_VT_co = TypeVar("_VT_co", covariant=True)

# shorthand for a variable-length tuple whose elements are untyped
TupleAny = Tuple[Any, ...]
66
67
if compat.py310:
    # ``types.NoneType`` only exists in the standard library from Python
    # 3.10 onwards; older interpreters derive it from ``None`` below.
    from types import NoneType as NoneType
else:
    NoneType = type(None)  # type: ignore

# a ForwardRef to "None"; de_optionalize_union_types() discards this from
# Union members alongside the real NoneType
NoneFwd = ForwardRef("None")

# local aliases for the get_args / get_origin imported from typing_extensions
typing_get_args = get_args
typing_get_origin = get_origin
79
80
# the kinds of annotation objects accepted by the helpers in this module:
# real classes, string annotations, ForwardRef objects, NewType objects,
# PEP 695 type aliases and generic aliases (see GenericProtocol)
_AnnotationScanType = Union[
    Type[Any], str, ForwardRef, NewType, TypeAliasType, "GenericProtocol[Any]"
]
84
85
class ArgsTypeProcotol(Protocol):
    """protocol for types that have ``__args__``

    there's no public interface for this AFAIK.  Used as the narrowed type
    returned by :func:`.is_optional` and :func:`.is_union`.

    """

    # note: the class name's spelling ("Procotol") is part of the public
    # interface of this module and is kept as-is

    # the type arguments of the generic / union, e.g. ``(int, str)``
    __args__: Tuple[_AnnotationScanType, ...]
94
95
class GenericProtocol(Protocol[_T]):
    """protocol for generic types.

    this is needed since Python's ``typing._GenericAlias`` is private, so
    there is no public class to annotate subscripted generics with.

    """

    # the type arguments, e.g. ``(str, int)`` for ``Dict[str, int]``
    __args__: Tuple[_AnnotationScanType, ...]
    # the unsubscripted origin, e.g. ``dict`` for ``Dict[str, int]``
    __origin__: Type[_T]

    # Python's builtin _GenericAlias has this method, however builtins like
    # list, dict, etc. do not, even though they have ``__origin__`` and
    # ``__args__``
    #
    # def copy_with(self, params: Tuple[_AnnotationScanType, ...]) -> Type[_T]:
    #     ...
112
113
# copied from TypeShed, required in order to implement
# MutableMapping.update()
class SupportsKeysAndGetItem(Protocol[_KT, _VT_co]):
    """structural type for mapping-like objects that provide ``keys()``
    and ``__getitem__``."""

    def keys(self) -> Iterable[_KT]: ...

    def __getitem__(self, __k: _KT) -> _VT_co: ...
120
121
# alias for the ``Literal["*"]`` type.
# work around https://github.com/microsoft/pyright/issues/3025
_LiteralStar = Literal["*"]
124
125
def de_stringify_annotation(
    cls: Type[Any],
    annotation: _AnnotationScanType,
    originating_module: str,
    locals_: Mapping[str, Any],
    *,
    str_cleanup_fn: Optional[Callable[[str, str], str]] = None,
    include_generic: bool = False,
    _already_seen: Optional[Set[Any]] = None,
) -> Type[Any]:
    """Resolve annotations that may be string based into real objects.

    This is particularly important if a module defines "from __future__ import
    annotations", as everything inside of __annotations__ is a string. We want
    to at least have generic containers like ``Mapped``, ``Union``, ``List``,
    etc.

    :param cls: class passed as ``in_class`` to :func:`.eval_expression`, so
     its namespace is visible while evaluating a string annotation.
    :param annotation: the annotation to resolve.  A ``ForwardRef`` is first
     unwrapped to its string argument; a string is evaluated; anything else
     is used as-is.
    :param originating_module: name of the module whose globals are used for
     evaluation (looked up in ``sys.modules`` by :func:`.eval_expression`).
    :param locals_: local namespace handed to :func:`.eval_expression`.
    :param str_cleanup_fn: optional ``(annotation_str, module_name) -> str``
     callable applied to a string annotation before it is evaluated.
    :param include_generic: when True, the ``__args__`` of a resolved generic
     (other than ``Literal[...]``) are resolved recursively and the generic is
     rebuilt from the results via ``_copy_generic_annotation_with()``.
    :param _already_seen: internal; generics visited so far during the
     recursion, used to stop on self-referential annotations.
    :return: the resolved annotation.  If a generic is reached a second time
     during recursion, the original (unresolved) annotation is returned for
     that element instead.
    :raises NameError: propagated from :func:`.eval_expression` when a string
     cannot be evaluated.
    """
    # looked at typing.get_type_hints(), looked at pydantic.  We need much
    # less here, and we here try to not use any private typing internals
    # or construct ForwardRef objects which is documented as something
    # that should be avoided.

    # kept so it can be returned unchanged on a recursion cycle below
    original_annotation = annotation

    if is_fwd_ref(annotation):
        annotation = annotation.__forward_arg__

    if isinstance(annotation, str):
        if str_cleanup_fn:
            annotation = str_cleanup_fn(annotation, originating_module)

        annotation = eval_expression(
            annotation, originating_module, locals_=locals_, in_class=cls
        )

    # Literal[...] arguments are values, not annotations, so they are never
    # resolved recursively
    if (
        include_generic
        and is_generic(annotation)
        and not is_literal(annotation)
    ):
        # one set is created at the outermost call and shared by every
        # recursive call below
        if _already_seen is None:
            _already_seen = set()

        if annotation in _already_seen:
            # only occurs recursively.  outermost return type
            # will always be Type.
            # the element here will be either ForwardRef or
            # Optional[ForwardRef]
            return original_annotation  # type: ignore
        else:
            _already_seen.add(annotation)

        elements = tuple(
            de_stringify_annotation(
                cls,
                elem,
                originating_module,
                locals_,
                str_cleanup_fn=str_cleanup_fn,
                include_generic=include_generic,
                _already_seen=_already_seen,
            )
            for elem in annotation.__args__
        )

        return _copy_generic_annotation_with(annotation, elements)

    return annotation  # type: ignore
195
196
def fixup_container_fwd_refs(
    type_: _AnnotationScanType,
) -> _AnnotationScanType:
    """Correct dict['x', 'y'] into dict[ForwardRef('x'), ForwardRef('y')]
    and similar for list, set

    Only builtin / ``collections.abc`` generic aliases whose origin is one of
    the containers listed below are rebuilt.  ``typing.Dict[...]``-style
    aliases, non-generic objects and any other generic are returned
    unchanged.

    :param type_: the annotation to inspect.
    :return: a new generic alias whose plain-string arguments are wrapped in
     ``ForwardRef``, or ``type_`` itself when no fix-up applies.
    """
    if not is_generic(type_):
        return type_

    origin = typing_get_origin(type_)
    if origin not in (
        dict,
        set,
        list,
        collections_abc.MutableSet,
        collections_abc.MutableMapping,
        collections_abc.MutableSequence,
        collections_abc.Mapping,
        collections_abc.Sequence,
    ):
        return type_

    # fight, kick and scream to struggle to tell the difference between
    # dict[] and typing.Dict[] which DO NOT compare the same and DO NOT
    # behave the same yet there is NO WAY to distinguish between which type
    # it is using public attributes.  The "." is escaped so that only a
    # literal "typing." prefix in the repr matches.
    if re.match(
        r"typing\.(?:Dict|List|Set|.*Mapping|.*Sequence|.*Set)", repr(type_)
    ):
        return type_

    # compat with py3.10 and earlier
    return origin.__class_getitem__(  # type: ignore
        tuple(
            ForwardRef(elem) if isinstance(elem, str) else elem
            for elem in typing_get_args(type_)
        )
    )
236
237
def _copy_generic_annotation_with(
    annotation: GenericProtocol[_T], elements: Tuple[_AnnotationScanType, ...]
) -> Type[_T]:
    """Rebuild the generic ``annotation`` with ``elements`` as its new type
    arguments.

    typing-module generics (``List``, ``Dict``, ...) expose ``copy_with()``.
    Builtin generic aliases (``list[int]``, ``dict[str, int]``, ...) do not,
    so for those the origin class is subscripted afresh.
    """
    if not hasattr(annotation, "copy_with"):
        # builtin alias: re-subscript the origin, e.g. list[...]
        return annotation.__origin__[elements]  # type: ignore

    # typing-module generic: let it clone itself with the new arguments
    return annotation.copy_with(elements)  # type: ignore
247
248
def eval_expression(
    expression: str,
    module_name: str,
    *,
    locals_: Optional[Mapping[str, Any]] = None,
    in_class: Optional[Type[Any]] = None,
) -> Any:
    """Evaluate ``expression`` against the namespace of ``module_name``.

    :param expression: Python expression source, e.g. a string annotation.
    :param module_name: key into ``sys.modules``; that module's ``__dict__``
     supplies the globals.
    :param locals_: optional local namespace handed straight to ``eval()``.
    :param in_class: optional class.  When given, its ``__dict__`` and its own
     name are made visible too, but module globals override class attributes
     that share a name (see #10899).
    :raises NameError: when the module is not in ``sys.modules``, or when
     evaluation raises any exception (the original is chained as the
     cause).
    """
    try:
        module_namespace: Dict[str, Any] = sys.modules[module_name].__dict__
    except KeyError as missing_module:
        raise NameError(
            f"Module {module_name} isn't present in sys.modules; can't "
            f"evaluate expression {expression}"
        ) from missing_module

    try:
        if in_class is None:
            eval_globals = module_namespace
        else:
            # see #10899.  Module globals are layered over the class
            # namespace so they win on a clash, even though that is not
            # how names would usually resolve.
            eval_globals = dict(in_class.__dict__)
            eval_globals.setdefault(in_class.__name__, in_class)
            eval_globals.update(module_namespace)
        result = eval(expression, eval_globals, locals_)
    except Exception as err:
        raise NameError(
            f"Could not de-stringify annotation {expression!r}"
        ) from err

    return result
283
284
def eval_name_only(
    name: str,
    module_name: str,
    *,
    locals_: Optional[Mapping[str, Any]] = None,
) -> Any:
    """Resolve ``name`` the way ``eval(name, module_globals, locals_)``
    would, but without the cost of ``eval()`` for plain names.

    Dotted names are delegated to :func:`.eval_expression`.  A plain name is
    looked up in ``locals_`` first (when given), then in the module's
    globals, then in ``builtins``.  This is the same precedence ``eval()``
    uses.

    :param name: the name to resolve.
    :param module_name: key into ``sys.modules`` supplying the globals.
    :param locals_: optional local namespace, consulted before globals.
    :raises NameError: if the module isn't in ``sys.modules`` or the name
     can't be located.
    """
    if "." in name:
        return eval_expression(name, module_name, locals_=locals_)

    try:
        base_globals: Dict[str, Any] = sys.modules[module_name].__dict__
    except KeyError as ke:
        raise NameError(
            f"Module {module_name} isn't present in sys.modules; can't "
            f"resolve name {name}"
        ) from ke

    # locals shadow globals, exactly as they would under eval(); previously
    # locals_ was accepted but silently ignored for plain names
    if locals_ is not None and name in locals_:
        return locals_[name]

    # name only, just look in globals.  eval() works perfectly fine here,
    # however we are seeking to have this be faster, as this occurs for
    # every Mapper[] keyword, etc. depending on configuration
    try:
        return base_globals[name]
    except KeyError as ke:
        # check in builtins as well to handle `list`, `set` or `dict`, etc.
        try:
            return builtins.__dict__[name]
        except KeyError:
            pass

        raise NameError(
            f"Could not locate name {name} in module {module_name}"
        ) from ke
317
318
def resolve_name_to_real_class_name(name: str, module_name: str) -> str:
    """Return the ``__name__`` of whatever ``name`` refers to in
    ``module_name``.

    ``name`` itself is returned unchanged in two cases:

    - :func:`.eval_name_only` raises ``NameError``.
    - The resolved object has no ``__name__`` attribute.
    """
    try:
        target = eval_name_only(name, module_name)
    except NameError:
        return name
    return getattr(target, "__name__", name)
326
327
def de_stringify_union_elements(
    cls: Type[Any],
    annotation: ArgsTypeProcotol,
    originating_module: str,
    locals_: Mapping[str, Any],
    *,
    str_cleanup_fn: Optional[Callable[[str, str], str]] = None,
) -> Type[Any]:
    """De-stringify every member of a ``Union`` and rebuild the ``Union``.

    :param cls: class whose namespace is visible while evaluating string
     members.
    :param annotation: a ``Union``-like object; each of its ``__args__`` is
     resolved with :func:`.de_stringify_annotation`.
    :param originating_module: module whose globals are used for evaluation.
    :param locals_: local namespace used for evaluation.  It is forwarded to
     every member (it was previously replaced by an empty dict, so names
     supplied here could not be resolved inside a ``Union``).
    :param str_cleanup_fn: optional pre-processing hook for string members.
    :raises NameError: if a member cannot be evaluated.
    """
    return make_union_type(
        *[
            de_stringify_annotation(
                cls,
                anno,
                originating_module,
                locals_,
                str_cleanup_fn=str_cleanup_fn,
            )
            for anno in annotation.__args__
        ]
    )
348
349
def is_pep593(type_: Optional[_AnnotationScanType]) -> bool:
    """Return True if ``type_`` is a PEP 593 ``Annotated[...]`` form.

    ``None`` is accepted and always yields False.
    """
    if type_ is None:
        return False
    return typing_get_origin(type_) is Annotated
352
353
def is_non_string_iterable(obj: Any) -> TypeGuard[Iterable[Any]]:
    """Return True for iterables other than ``str`` and ``bytes``.

    Text and byte strings are technically iterable but are treated here as
    scalar values.
    """
    if isinstance(obj, (str, bytes)):
        return False
    return isinstance(obj, collections_abc.Iterable)
358
359
def is_literal(type_: _AnnotationScanType) -> bool:
    """Return True if ``type_`` is a ``Literal[...]`` form."""
    origin = get_origin(type_)
    return origin is Literal
362
363
def is_newtype(type_: Optional[_AnnotationScanType]) -> TypeGuard[NewType]:
    """Return True if ``type_`` looks like a ``typing.NewType`` object.

    The check is duck-typed on the ``__supertype__`` attribute.  An
    ``isinstance(type_, NewType)`` test does not work on Python 3.8 / 3.7,
    where ``NewType`` produces a closure rather than an object instance.
    """
    absent = object()
    return getattr(type_, "__supertype__", absent) is not absent
370
371
def is_generic(type_: _AnnotationScanType) -> TypeGuard[GenericProtocol[Any]]:
    """Return True if ``type_`` carries both ``__args__`` and ``__origin__``,
    i.e. it is a subscripted generic such as ``List[int]`` or ``list[int]``.
    """
    return all(hasattr(type_, attr) for attr in ("__args__", "__origin__"))
374
375
def is_pep695(type_: _AnnotationScanType) -> TypeGuard[TypeAliasType]:
    """Return True if ``type_`` is a PEP 695 type alias.

    This covers an alias created by a ``type X = ...`` statement, or by
    ``TypeAliasType`` directly.
    """
    is_alias = isinstance(type_, TypeAliasType)
    return is_alias
378
379
def flatten_newtype(type_: NewType) -> Type[Any]:
    """Follow a chain of ``NewType`` definitions down to the first
    supertype that is not itself a ``NewType``.

    E.g. for ``A = NewType("A", int)`` and ``B = NewType("B", A)``,
    ``flatten_newtype(B)`` returns ``int``.
    """
    parent = type_.__supertype__
    if is_newtype(parent):
        return flatten_newtype(parent)
    return parent  # type: ignore[return-value]
385
386
def is_fwd_ref(
    type_: _AnnotationScanType, check_generic: bool = False
) -> TypeGuard[ForwardRef]:
    """Return True if ``type_`` is a ``ForwardRef``.

    With ``check_generic=True``, a generic is also reported as True when any
    of its arguments is (recursively) a ``ForwardRef``.
    """
    if isinstance(type_, ForwardRef):
        return True
    if not (check_generic and is_generic(type_)):
        return False
    return any(is_fwd_ref(arg, True) for arg in type_.__args__)
396
397
@overload
def de_optionalize_union_types(type_: str) -> str: ...


@overload
def de_optionalize_union_types(type_: Type[Any]) -> Type[Any]: ...


@overload
def de_optionalize_union_types(
    type_: _AnnotationScanType,
) -> _AnnotationScanType: ...


def de_optionalize_union_types(
    type_: _AnnotationScanType,
) -> _AnnotationScanType:
    """Given a type, filter out ``Union`` types that include ``NoneType``
    to not include the ``NoneType``.

    The input is handled in one of three ways:

    - ``ForwardRef`` objects are handled textually by
      :func:`.de_optionalize_fwd_ref_union_types`.
    - For ``Optional[...]`` / ``Union[...]`` / ``X | Y`` types, both
      ``NoneType`` and ``ForwardRef("None")`` are removed.  The remaining
      members are re-joined with :func:`.make_union_type`, keeping their
      original order.
    - Anything else is returned unchanged.
    """

    if is_fwd_ref(type_):
        return de_optionalize_fwd_ref_union_types(type_)

    elif is_optional(type_):
        # dict.fromkeys() de-duplicates with the same hash / equality rules
        # as a set, but preserves declaration order so the rebuilt Union is
        # the same from one interpreter run to the next
        members = dict.fromkeys(type_.__args__)

        members.pop(NoneType, None)
        members.pop(NoneFwd, None)

        return make_union_type(*members)

    else:
        return type_
433
434
def _split_top_level(expression: str, separator: str) -> Tuple[str, ...]:
    """Split ``expression`` on the single character ``separator``, ignoring
    separators nested inside square brackets; each part is stripped."""
    parts: list[str] = []
    current: list[str] = []
    depth = 0
    for char in expression:
        if char == "[":
            depth += 1
        elif char == "]":
            depth -= 1
        elif char == separator and depth == 0:
            parts.append("".join(current).strip())
            current = []
            continue
        current.append(char)
    parts.append("".join(current).strip())
    return tuple(parts)


def de_optionalize_fwd_ref_union_types(
    type_: ForwardRef,
) -> _AnnotationScanType:
    """return the non-optional type for Optional[], Union[None, ...], x|None,
    etc. without de-stringifying forward refs.

    unfortunately this seems to require lots of hardcoded heuristics

    Only separators at the top bracket level are considered.  So
    ``"Union[Dict[str, int], None]"`` yields ``ForwardRef("Dict[str, int]")``,
    and ``"Mapped[int | None]"`` is left alone.  Qualified spellings such as
    ``"typing.Optional[int]"`` are recognized.  A pipe union has its
    remaining members re-joined with ``"|"``.  If nothing applies,
    ``type_`` itself is returned.
    """

    annotation = type_.__forward_arg__

    # "X | None", "None | X[...]", ...
    pipe_tokens = _split_top_level(annotation, "|")
    if len(pipe_tokens) > 1:
        if "None" in pipe_tokens:
            return ForwardRef("|".join(p for p in pipe_tokens if p != "None"))
        return type_

    # "Optional[X]", "Union[X, None]", "typing.Optional[X]", ...
    mm = re.match(r"^([\w.]+)\[(.+)\]$", annotation.strip())
    if mm is None:
        return type_

    generic_name = mm.group(1).split(".")[-1]
    if generic_name == "Optional":
        return ForwardRef(mm.group(2))
    elif generic_name == "Union":
        elements = [
            ForwardRef(elem)
            for elem in _split_top_level(mm.group(2), ",")
            if elem and elem != "None"
        ]
        if not elements:
            # e.g. "Union[None]"; there is no non-None member to return
            return type_
        return make_union_type(*elements)
    else:
        return type_
464
465
def make_union_type(*types: _AnnotationScanType) -> Type[Any]:
    """Make a Union type.

    This is needed by :func:`.de_optionalize_union_types` which removes
    ``NoneType`` from a ``Union``.

    ``Union`` is subscripted directly instead of calling
    ``Union.__getitem__(types)``.  The two are equivalent while
    ``typing.Union`` is a special-form instance.  On Python 3.14, where
    ``typing.Union`` is an alias of the ``types.UnionType`` class, the
    explicit ``__getitem__`` call binds the tuple as ``self`` and fails.
    A single type collapses to itself, e.g. ``make_union_type(int)`` is
    ``int``.
    """
    return Union[types]  # type: ignore
474
475
def expand_unions(
    type_: Type[Any], include_union: bool = False, discard_none: bool = False
) -> Tuple[Type[Any], ...]:
    """Return a type as a tuple of individual types, expanding for
    ``Union`` types.

    :param type_: the type to expand.  A non-``Union`` is returned as
     ``(type_,)``.
    :param include_union: when True, ``type_`` itself is prepended to the
     expanded members.
    :param discard_none: when True, ``NoneType`` is left out of the members.
    :return: the members, in their declaration order.
    """

    if not is_union(type_):
        return (type_,)

    # dict.fromkeys() de-duplicates like a set but keeps declaration order,
    # so the returned tuple is stable across interpreter runs
    members = dict.fromkeys(type_.__args__)

    if discard_none:
        members.pop(NoneType, None)

    if include_union:
        return (type_,) + tuple(members)  # type: ignore
    return tuple(members)  # type: ignore
494
495
def is_optional(type_: Any) -> TypeGuard[ArgsTypeProcotol]:
    """Return True if ``type_`` is any union-like form.

    This covers ``Optional[...]``, ``Union[...]`` and ``X | Y``
    (``UnionType``), whether or not ``None`` is actually one of the members.
    """
    union_like_names = ("Optional", "Union", "UnionType")
    return is_origin_of(type_, *union_like_names)
503
504
def is_optional_union(type_: Any) -> bool:
    """Return True if ``type_`` is union-like and ``NoneType`` is one of
    its members."""
    if not is_optional(type_):
        return False
    return NoneType in typing_get_args(type_)
507
508
def is_union(type_: Any) -> TypeGuard[ArgsTypeProcotol]:
    """Return True if the origin of ``type_`` is named ``Union``.

    Unlike :func:`.is_optional`, the ``Optional`` and ``UnionType`` origin
    names are not matched here.
    """
    union_name = "Union"
    return is_origin_of(type_, union_name)
511
512
def is_origin_of_cls(
    type_: Any, class_obj: Union[Tuple[Type[Any], ...], Type[Any]]
) -> bool:
    """return True if the given type has an __origin__ that shares a base
    with the given class

    The origin must itself be a class; ``class_obj`` may be a single class or
    a tuple of classes, as accepted by ``issubclass()``.
    """
    origin = typing_get_origin(type_)
    if not isinstance(origin, type):
        # covers "no origin at all" (None) as well as non-class origins
        return False
    return issubclass(origin, class_obj)
524
525
def is_origin_of(
    type_: Any, *names: str, module: Optional[str] = None
) -> bool:
    """return True if the given type has an __origin__ with the given name
    and optional module.

    :param names: acceptable origin names, compared against
     ``_get_type_name(origin)``.
    :param module: if given, the origin's ``__module__`` must start with it.
    """
    origin = typing_get_origin(type_)
    if origin is None:
        return False
    if _get_type_name(origin) not in names:
        return False
    return module is None or origin.__module__.startswith(module)
539
540
def _get_type_name(type_: Type[Any]) -> str:
    """Return the name of ``type_`` (typically an ``__origin__`` object).

    On Python 3.10+ ``__name__`` is read directly.  On older interpreters,
    when ``__name__`` is missing or ``None``, the private ``_name``
    attribute is used instead.  The result may then be ``None`` if neither
    is available.
    """
    if not compat.py310:
        name = getattr(type_, "__name__", None)
        if name is None:
            name = getattr(type_, "_name", None)
        return name  # type: ignore
    return type_.__name__
550
551
class DescriptorProto(Protocol):
    """structural type for objects implementing the full descriptor
    protocol: ``__get__``, ``__set__`` and ``__delete__``."""

    def __get__(self, instance: object, owner: Any) -> Any: ...

    def __set__(self, instance: Any, value: Any) -> None: ...

    def __delete__(self, instance: Any) -> None: ...
558
559
# type of the descriptor held by a DescriptorReference
_DESC = TypeVar("_DESC", bound=DescriptorProto)
561
562
class DescriptorReference(Generic[_DESC]):
    """a descriptor that refers to a descriptor.

    used for cases where we need to have an instance variable referring to an
    object that is itself a descriptor, which typically confuses typing tools
    as they don't know when they should use ``__get__`` or not when referring
    to the descriptor assignment as an instance variable.  See
    sqlalchemy.orm.interfaces.PropComparator.prop

    """

    # these stubs exist only for type checkers; at runtime the class defines
    # no __get__ / __set__ / __delete__ of its own
    if TYPE_CHECKING:

        def __get__(self, instance: object, owner: Any) -> _DESC: ...

        def __set__(self, instance: Any, value: _DESC) -> None: ...

        def __delete__(self, instance: Any) -> None: ...
581
582
# covariant descriptor type, used by the read-only RODescriptorReference
_DESC_co = TypeVar("_DESC_co", bound=DescriptorProto, covariant=True)
584
585
class RODescriptorReference(Generic[_DESC_co]):
    """a descriptor that refers to a descriptor.

    same as :class:`.DescriptorReference` but is read-only, so that subclasses
    can define a subtype as the generically contained element

    """

    # type-checker-only stubs; __set__ and __delete__ are typed NoReturn so
    # that assigning or deleting through the reference is flagged
    if TYPE_CHECKING:

        def __get__(self, instance: object, owner: Any) -> _DESC_co: ...

        def __set__(self, instance: Any, value: Any) -> NoReturn: ...

        def __delete__(self, instance: Any) -> NoReturn: ...
601
602
# an optional callable type held by a CallableReference
_FN = TypeVar("_FN", bound=Optional[Callable[..., Any]])
604
605
class CallableReference(Generic[_FN]):
    """a descriptor that refers to a callable.

    works around mypy's limitation of not allowing callables assigned
    as instance variables

    """

    # type-checker-only stubs; at runtime the class defines no descriptor
    # methods of its own
    if TYPE_CHECKING:

        def __get__(self, instance: object, owner: Any) -> _FN: ...

        def __set__(self, instance: Any, value: _FN) -> None: ...

        def __delete__(self, instance: Any) -> None: ...
622
623
624# $def ro_descriptor_reference(fn: Callable[])