1# util/typing.py
2# Copyright (C) 2022-2024 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: https://www.opensource.org/licenses/mit-license.php
7# mypy: allow-untyped-defs, allow-untyped-calls
8
9from __future__ import annotations
10
11import builtins
12import collections.abc as collections_abc
13import re
14import sys
15from typing import Any
16from typing import Callable
17from typing import cast
18from typing import Dict
19from typing import ForwardRef
20from typing import Generic
21from typing import Iterable
22from typing import Mapping
23from typing import NewType
24from typing import NoReturn
25from typing import Optional
26from typing import overload
27from typing import Protocol
28from typing import Set
29from typing import Tuple
30from typing import Type
31from typing import TYPE_CHECKING
32from typing import TypeVar
33from typing import Union
34
35from . import compat
36
if True:  # zimports removes the trailing comments
38 from typing_extensions import Annotated as Annotated # 3.9
39 from typing_extensions import Concatenate as Concatenate # 3.10
40 from typing_extensions import (
41 dataclass_transform as dataclass_transform, # 3.11,
42 )
43 from typing_extensions import get_args as get_args # 3.10
44 from typing_extensions import get_origin as get_origin # 3.10
45 from typing_extensions import (
46 Literal as Literal,
47 ) # 3.8 but has bugs before 3.10
48 from typing_extensions import NotRequired as NotRequired # 3.11
49 from typing_extensions import ParamSpec as ParamSpec # 3.10
50 from typing_extensions import TypeAlias as TypeAlias # 3.10
51 from typing_extensions import TypeGuard as TypeGuard # 3.10
52 from typing_extensions import TypeVarTuple as TypeVarTuple # 3.11
53 from typing_extensions import Self as Self # 3.11
54 from typing_extensions import TypeAliasType as TypeAliasType # 3.12
55 from typing_extensions import Unpack as Unpack # 3.11
56
57
# Generic type variables shared by the protocols and helpers in this module.
# ``bound=Any`` places no practical constraint on ``_T``.
_T = TypeVar("_T", bound=Any)
_KT = TypeVar("_KT")
# variance-annotated variants for use as Protocol / Generic parameters
_KT_co = TypeVar("_KT_co", covariant=True)
_KT_contra = TypeVar("_KT_contra", contravariant=True)
_VT = TypeVar("_VT")
_VT_co = TypeVar("_VT_co", covariant=True)

# shorthand for a variable-length tuple whose elements may be anything
TupleAny = Tuple[Any, ...]
66
67
if compat.py310:
    # why they took until py310 to put this in stdlib is beyond me,
    # I've been wanting it since py27
    from types import NoneType as NoneType
else:
    # on older interpreters, derive the same class from the None singleton
    NoneType = type(None)  # type: ignore

# a ForwardRef wrapping the string "None"; this is what ``None`` becomes when
# it is written as a string inside a generic, e.g. ``Union["Foo", "None"]``.
# Union members equal to this are stripped by de_optionalize_union_types().
NoneFwd = ForwardRef("None")

# module-level aliases for the typing_extensions get_args / get_origin
# imported at the top of this module
typing_get_args = get_args
typing_get_origin = get_origin
79
80
# Everything the annotation-scanning helpers in this module may be handed:
# real classes, un-evaluated string annotations, ForwardRef objects, NewType
# callables, PEP 695 type aliases, and subscripted generics.
_AnnotationScanType = Union[
    Type[Any], str, ForwardRef, NewType, TypeAliasType, "GenericProtocol[Any]"
]
84
85
class ArgsTypeProcotol(Protocol):
    """protocol for types that have ``__args__``

    there's no public interface for this AFAIK

    This is the type narrowed to by union checks such as ``is_optional()``
    and ``is_union()``, whose callers then iterate ``__args__``.

    .. note:: the class name is misspelled ("Procotol"); it is kept as-is
       since code elsewhere may import it under this name.

    """

    # the type parameters of the form, e.g. ``(int, str)`` for
    # ``Union[int, str]``
    __args__: Tuple[_AnnotationScanType, ...]
94
95
class GenericProtocol(Protocol[_T]):
    """protocol for generic types.

    This is needed since Python's ``typing._GenericAlias`` is private.
    A subscripted generic such as ``List[int]`` exposes ``__origin__``
    (here ``list``) and ``__args__`` (here ``(int,)``).

    """

    # the type parameters, e.g. ``(int,)`` for ``List[int]``
    __args__: Tuple[_AnnotationScanType, ...]
    # the unsubscripted class, e.g. ``list`` for ``List[int]``
    __origin__: Type[_T]

    # Python's builtin _GenericAlias has this method, however builtins like
    # list, dict, etc. do not, even though they have ``__origin__`` and
    # ``__args__``; _copy_generic_annotation_with() handles both cases.
    #
    # def copy_with(self, params: Tuple[_AnnotationScanType, ...]) -> Type[_T]:
    #     ...
112
113
# copied from TypeShed, required in order to implement
# MutableMapping.update()
class SupportsKeysAndGetItem(Protocol[_KT, _VT_co]):
    """Structural type for mapping-like arguments: anything providing
    ``keys()`` and ``__getitem__``."""

    def keys(self) -> Iterable[_KT]: ...

    def __getitem__(self, __k: _KT) -> _VT_co: ...
120
121
# work around https://github.com/microsoft/pyright/issues/3025
# by naming ``Literal["*"]`` once as a module-level alias for annotations
_LiteralStar = Literal["*"]
124
125
def de_stringify_annotation(
    cls: Type[Any],
    annotation: _AnnotationScanType,
    originating_module: str,
    locals_: Mapping[str, Any],
    *,
    str_cleanup_fn: Optional[Callable[[str, str], str]] = None,
    include_generic: bool = False,
    _already_seen: Optional[Set[Any]] = None,
) -> Type[Any]:
    """Resolve annotations that may be string based into real objects.

    This is particularly important if a module defines "from __future__ import
    annotations", as everything inside of __annotations__ is a string. We want
    to at least have generic containers like ``Mapped``, ``Union``, ``List``,
    etc.

    :param cls: class whose namespace is also visible (with lower precedence
     than the module globals) when a string annotation is evaluated.
    :param annotation: the annotation to resolve.  A ``ForwardRef`` or string
     is evaluated; anything else is passed through.
    :param originating_module: name of the module in ``sys.modules`` whose
     globals are used for evaluation.
    :param locals_: extra names made available to the evaluation.
    :param str_cleanup_fn: optional ``fn(annotation_str, module_name)`` that
     rewrites a string annotation before it is evaluated.
    :param include_generic: when True, the arguments of a (non-``Literal``)
     generic are resolved recursively and the generic is rebuilt from them.
    :param _already_seen: internal; the generics currently being expanded on
     the path from the outermost call, used to stop infinite recursion for
     self-referencing annotations.
    :raises NameError: if a string annotation cannot be evaluated.

    """
    # looked at typing.get_type_hints(), looked at pydantic.  We need much
    # less here, and we here try to not use any private typing internals
    # or construct ForwardRef objects which is documented as something
    # that should be avoided.

    original_annotation = annotation

    if is_fwd_ref(annotation):
        annotation = annotation.__forward_arg__

    if isinstance(annotation, str):
        if str_cleanup_fn:
            annotation = str_cleanup_fn(annotation, originating_module)

        annotation = eval_expression(
            annotation, originating_module, locals_=locals_, in_class=cls
        )

    if (
        include_generic
        and is_generic(annotation)
        and not is_literal(annotation)
    ):
        if _already_seen is None:
            _already_seen = set()

        if annotation in _already_seen:
            # only occurs recursively, when an element refers back to a
            # generic that is still being expanded further up the stack.
            # outermost return type will always be Type.
            # the element here will be either ForwardRef or
            # Optional[ForwardRef]
            return original_annotation  # type: ignore

        # track only the *ancestors* of each element: the annotation is
        # removed again once its arguments are resolved.  Otherwise an
        # identical generic appearing as a sibling, e.g. the second member
        # of ``Tuple[List["A"], List["A"]]``, would be mistaken for a
        # recursion and returned unresolved.
        _already_seen.add(annotation)
        try:
            elements = tuple(
                de_stringify_annotation(
                    cls,
                    elem,
                    originating_module,
                    locals_,
                    str_cleanup_fn=str_cleanup_fn,
                    include_generic=include_generic,
                    _already_seen=_already_seen,
                )
                for elem in annotation.__args__
            )
        finally:
            _already_seen.discard(annotation)

        return _copy_generic_annotation_with(annotation, elements)
    return annotation  # type: ignore
194
195
def _copy_generic_annotation_with(
    annotation: GenericProtocol[_T], elements: Tuple[_AnnotationScanType, ...]
) -> Type[_T]:
    """Return ``annotation`` re-parameterized with ``elements`` in place of
    its current ``__args__``.

    ``typing`` generics such as ``List[int]`` provide ``copy_with()``;
    builtin generics such as ``list[int]`` do not, so for those the
    ``__origin__`` is simply subscripted again.
    """
    try:
        rebuild = annotation.copy_with  # type: ignore[attr-defined]
    except AttributeError:
        # Python builtins list, dict, etc.
        return annotation.__origin__[elements]  # type: ignore
    # List, Dict, etc. real generics
    return rebuild(elements)  # type: ignore
205
206
def eval_expression(
    expression: str,
    module_name: str,
    *,
    locals_: Optional[Mapping[str, Any]] = None,
    in_class: Optional[Type[Any]] = None,
) -> Any:
    """Evaluate ``expression`` using the globals of module ``module_name``.

    :param expression: source text of a Python expression, typically a
     stringified annotation.
    :param module_name: key into ``sys.modules``; that module's ``__dict__``
     supplies the globals.
    :param locals_: optional local namespace handed to ``eval()``.
    :param in_class: when given, the class's own ``__dict__`` and its
     ``__name__`` are visible too, but module globals take precedence over
     them (see #10899).
    :raises NameError: if the module is not in ``sys.modules``, or if the
     evaluation raises any exception (chained as the cause).
    """
    try:
        module_globals: Dict[str, Any] = sys.modules[module_name].__dict__
    except KeyError as missing:
        raise NameError(
            f"Module {module_name} isn't present in sys.modules; can't "
            f"evaluate expression {expression}"
        ) from missing

    namespace: Dict[str, Any]
    try:
        if in_class is None:
            namespace = module_globals
        else:
            # layered as: class attributes, then the class's own name if not
            # already present, then module globals overwriting both.  see
            # #10899 -- globals deliberately win over the class namespace
            # here, even though that's not how names usually resolve.
            namespace = dict(in_class.__dict__)
            namespace.setdefault(in_class.__name__, in_class)
            namespace.update(module_globals)
        result = eval(expression, namespace, locals_)
    except Exception as err:
        raise NameError(
            f"Could not de-stringify annotation {expression!r}"
        ) from err
    return result
241
242
def eval_name_only(
    name: str,
    module_name: str,
    *,
    locals_: Optional[Mapping[str, Any]] = None,
) -> Any:
    """Resolve a bare ``name`` in the globals of module ``module_name``,
    falling back to builtins.

    Dotted names are delegated to :func:`.eval_expression`, which is the
    only path on which ``locals_`` is consulted.

    :raises NameError: if the module is not in ``sys.modules`` or the name
     cannot be found.
    """
    if "." in name:
        return eval_expression(name, module_name, locals_=locals_)

    try:
        module_globals: Dict[str, Any] = sys.modules[module_name].__dict__
    except KeyError as missing:
        raise NameError(
            f"Module {module_name} isn't present in sys.modules; can't "
            f"resolve name {name}"
        ) from missing

    # for a bare name a dictionary lookup gives the same answer as eval(),
    # only faster; this occurs for every Mapper[] keyword, etc. depending
    # on configuration
    try:
        return module_globals[name]
    except KeyError as missing:
        # names such as ``list``, ``set`` or ``dict`` live in builtins
        builtin_names = builtins.__dict__
        if name in builtin_names:
            return builtin_names[name]
        raise NameError(
            f"Could not locate name {name} in module {module_name}"
        ) from missing
275
276
def resolve_name_to_real_class_name(name: str, module_name: str) -> str:
    """Return the ``__name__`` of whatever ``name`` refers to in module
    ``module_name``, so that an alias such as ``Foo = Bar`` yields ``"Bar"``.

    ``name`` itself is returned when it cannot be resolved, or when the
    object it resolves to has no ``__name__``.
    """
    try:
        resolved = eval_name_only(name, module_name)
    except NameError:
        return name
    return getattr(resolved, "__name__", name)
284
285
def de_stringify_union_elements(
    cls: Type[Any],
    annotation: ArgsTypeProcotol,
    originating_module: str,
    locals_: Mapping[str, Any],
    *,
    str_cleanup_fn: Optional[Callable[[str, str], str]] = None,
) -> Type[Any]:
    """Resolve each member of a union annotation and rebuild the ``Union``.

    Every element of ``annotation.__args__`` is passed through
    :func:`.de_stringify_annotation` (which evaluates string / ``ForwardRef``
    members), and the results are recombined with :func:`.make_union_type`.

    :param locals_: extra names made available when evaluating string
     members.
    """
    return make_union_type(
        *[
            de_stringify_annotation(
                cls,
                anno,
                originating_module,
                # forward the caller's locals; an empty mapping here would
                # silently ignore names supplied through this parameter
                locals_,
                str_cleanup_fn=str_cleanup_fn,
            )
            for anno in annotation.__args__
        ]
    )
306
307
def is_pep593(type_: Optional[_AnnotationScanType]) -> bool:
    """Return True if ``type_`` is a PEP 593 ``Annotated[...]`` form."""
    if type_ is None:
        return False
    return typing_get_origin(type_) is Annotated
310
311
def is_non_string_iterable(obj: Any) -> TypeGuard[Iterable[Any]]:
    """Return True if ``obj`` is iterable but is not a ``str`` or ``bytes``
    (which are iterable yet usually meant as scalar values)."""
    if isinstance(obj, (str, bytes)):
        return False
    return isinstance(obj, collections_abc.Iterable)
316
317
def is_literal(type_: _AnnotationScanType) -> bool:
    """Return True if ``type_`` is a ``Literal[...]`` form.

    Both the ``Literal`` this module imports from ``typing_extensions`` and
    the standard library's ``typing.Literal`` are recognized.  On older
    interpreters ``typing_extensions`` can ship its own ``Literal`` backport
    (this module notes Literal "has bugs before 3.10"), in which case the two
    are distinct objects and ``typing.Literal[...]`` would otherwise be
    missed.
    """
    origin = get_origin(type_)
    if origin is None:
        return False

    import typing  # only names are imported from typing at module level

    return origin is Literal or origin is typing.Literal
320
321
def is_newtype(type_: Optional[_AnnotationScanType]) -> TypeGuard[NewType]:
    """Return True if ``type_`` was produced by ``typing.NewType``.

    Detection is by the ``__supertype__`` attribute.
    ``isinstance(type_, NewType)`` is not used because on Python 3.7 / 3.8
    ``NewType()`` returns a closure rather than an object instance.
    """
    try:
        type_.__supertype__  # type: ignore[union-attr]  # noqa: B018
    except AttributeError:
        return False
    return True
328
329
def is_generic(type_: _AnnotationScanType) -> TypeGuard[GenericProtocol[Any]]:
    """Return True if ``type_`` looks like a subscripted generic, i.e. it
    exposes both ``__args__`` and ``__origin__`` (``List[int]``,
    ``list[int]``, ...)."""
    return all(hasattr(type_, attr) for attr in ("__args__", "__origin__"))
332
333
def is_pep695(type_: _AnnotationScanType) -> TypeGuard[TypeAliasType]:
    """Return True if ``type_`` is an instance of the ``TypeAliasType``
    imported from ``typing_extensions`` (a PEP 695 type alias).

    NOTE(review): whether aliases made by a ``type X = ...`` statement pass
    depends on typing_extensions re-exporting the stdlib class -- confirm.
    """
    alias_cls = TypeAliasType
    return isinstance(type_, alias_cls)
336
337
def flatten_newtype(type_: NewType) -> Type[Any]:
    """Follow a chain of ``NewType`` definitions down to the first
    supertype that is not itself a NewType.

    e.g. with ``A = NewType("A", int)`` and ``B = NewType("B", A)``,
    ``flatten_newtype(B)`` returns ``int``.
    """
    current: Any = type_
    while True:
        current = current.__supertype__
        if not is_newtype(current):
            return current  # type: ignore[no-any-return]
343
344
def is_fwd_ref(
    type_: _AnnotationScanType, check_generic: bool = False
) -> TypeGuard[ForwardRef]:
    """Return True if ``type_`` is a ``ForwardRef``.

    With ``check_generic=True`` a generic also qualifies when any of its
    ``__args__``, searched recursively through nested generics, is a
    ``ForwardRef``.
    """
    if isinstance(type_, ForwardRef):
        return True
    if not (check_generic and is_generic(type_)):
        return False
    for arg in type_.__args__:
        if is_fwd_ref(arg, True):
            return True
    return False
354
355
@overload
def de_optionalize_union_types(type_: str) -> str: ...


@overload
def de_optionalize_union_types(type_: Type[Any]) -> Type[Any]: ...


@overload
def de_optionalize_union_types(
    type_: _AnnotationScanType,
) -> _AnnotationScanType: ...


def de_optionalize_union_types(
    type_: _AnnotationScanType,
) -> _AnnotationScanType:
    """Given a type, filter out ``Union`` types that include ``NoneType``
    to not include the ``NoneType``.

    * a ``ForwardRef`` is handled textually by
      :func:`.de_optionalize_fwd_ref_union_types`;
    * a union form (``Optional[X]``, ``Union[...]``, ``X | Y``) has its
      ``NoneType`` and ``ForwardRef("None")`` members removed, and the rest
      are recombined with :func:`.make_union_type` in their original order
      (this also applies to unions that contain no ``None``);
    * anything else is returned unchanged.

    """

    if is_fwd_ref(type_):
        return de_optionalize_fwd_ref_union_types(type_)

    elif is_optional(type_):
        # dict.fromkeys() has the same hash/equality membership as a set, but
        # keeps declaration order; iterating a set of classes gives an order
        # that can vary between processes, since their hashes are identity
        # based
        typ = dict.fromkeys(type_.__args__)

        typ.pop(NoneType, None)
        typ.pop(NoneFwd, None)

        return make_union_type(*typ)

    else:
        return type_
391
392
def _split_fwd_ref_top_level(text: str, sep: str) -> Optional[Tuple[str, ...]]:
    """Split ``text`` on ``sep`` characters that are not nested inside
    square brackets, stripping whitespace from each piece.

    Returns ``None`` when the square brackets in ``text`` don't balance.
    """
    pieces: list[str] = []
    current: list[str] = []
    depth = 0
    for char in text:
        if char == "[":
            depth += 1
        elif char == "]":
            depth -= 1
            if depth < 0:
                return None
        elif char == sep and depth == 0:
            pieces.append("".join(current).strip())
            current = []
            continue
        current.append(char)
    if depth != 0:
        return None
    pieces.append("".join(current).strip())
    return tuple(pieces)


def de_optionalize_fwd_ref_union_types(
    type_: ForwardRef,
) -> _AnnotationScanType:
    """return the non-optional type for Optional[], Union[None, ...], x|None,
    etc. without de-stringifying forward refs.

    unfortunately this seems to require lots of hardcoded heuristics.  The
    string is inspected textually:

    * ``Optional[X]`` (including a module-qualified spelling such as
      ``typing.Optional[X]``) returns ``ForwardRef("X")``;
    * ``Union[A, B, None]`` returns a ``Union`` of ``ForwardRef`` objects for
      the non-``None`` members; commas nested in brackets, as in
      ``Union[Dict[str, int], None]``, do not split members;
    * any other single ``Name[...]`` subscript is returned unchanged;
    * ``A | None`` (splitting only on ``|`` outside brackets) returns
      ``ForwardRef("A")``;
    * anything else is returned unchanged.

    """

    annotation = type_.__forward_arg__

    mm = re.match(r"^([\w\.]+)\[(.+)\]$", annotation)
    if mm:
        members = _split_fwd_ref_top_level(mm.group(2), ",")
        # members is None when the first "[" and the final "]" are not a
        # matching pair, e.g. "Union[int, str] | List[None]"; that is not a
        # single subscript expression, so fall through to the "|" handling
        if members is not None:
            generic_name = mm.group(1).split(".")[-1]
            if generic_name == "Optional":
                return ForwardRef(mm.group(2).strip())
            elif generic_name == "Union":
                return make_union_type(
                    *[ForwardRef(elem) for elem in members if elem != "None"]
                )
            else:
                return type_

    pipe_tokens = _split_fwd_ref_top_level(annotation, "|")
    if pipe_tokens is not None and "None" in pipe_tokens:
        return ForwardRef("|".join(p for p in pipe_tokens if p != "None"))

    return type_
422
423
def make_union_type(*types: _AnnotationScanType) -> Type[Any]:
    """Make a Union type.

    This is needed by :func:`.de_optionalize_union_types` which removes
    ``NoneType`` from a ``Union``.  ``Union`` is subscripted with the tuple
    of ``types``, so the usual ``typing`` normalization applies: a single
    member yields that member itself, and duplicates collapse.

    """
    # subscript ``Union`` rather than calling ``Union.__getitem__(types)``:
    # on interpreters where ``typing.Union`` is itself a class, the unbound
    # ``__getitem__`` is the per-instance method and rejects a bare tuple,
    # while subscription works on every version
    return cast(Any, Union)[types]  # type: ignore[no-any-return]
432
433
def expand_unions(
    type_: Type[Any], include_union: bool = False, discard_none: bool = False
) -> Tuple[Type[Any], ...]:
    """Return a type as a tuple of individual types, expanding for
    ``Union`` types.

    :param type_: the type to expand; anything :func:`.is_union` does not
     accept is returned as the single-element tuple ``(type_,)``.
    :param include_union: when True, the union itself comes first in the
     result, ahead of its members.
    :param discard_none: when True, ``NoneType`` is left out of the members.
    :return: the members in the order they appear in ``type_.__args__``.
    """

    if is_union(type_):
        # dict.fromkeys() instead of set(): same membership semantics, but
        # a deterministic, declaration-ordered result -- iteration order of
        # a set of classes can vary between processes, as their hashes are
        # identity based
        members = dict.fromkeys(type_.__args__)

        if discard_none:
            members.pop(NoneType, None)

        if include_union:
            return (type_,) + tuple(members)  # type: ignore
        else:
            return tuple(members)  # type: ignore
    else:
        return (type_,)
452
453
def is_optional(type_: Any) -> TypeGuard[ArgsTypeProcotol]:
    """Return True if ``type_`` is a union form: ``Optional[...]``,
    ``Union[...]`` or a PEP 604 ``X | Y`` union (origin ``UnionType``).

    This does not check whether ``None`` is actually a member; see
    :func:`.is_optional_union` for that.
    """
    union_origin_names = ("Optional", "Union", "UnionType")
    return is_origin_of(type_, *union_origin_names)
461
462
def is_optional_union(type_: Any) -> bool:
    """Return True if ``type_`` is a union form that has ``NoneType`` among
    its members, e.g. ``Optional[int]`` or ``int | None``."""
    if not is_optional(type_):
        return False
    return NoneType in typing_get_args(type_)
465
466
def is_union(type_: Any) -> TypeGuard[ArgsTypeProcotol]:
    """Return True if ``type_`` is a union, spelled either ``Union[...]``
    (which includes ``Optional[...]``) or as a PEP 604 ``X | Y`` union.

    ``"UnionType"`` is the origin name of an ``X | Y`` union; it is accepted
    here for consistency with :func:`.is_optional`, so that e.g.
    :func:`.expand_unions` expands ``int | str`` just like ``Union[int, str]``.
    """
    return is_origin_of(type_, "Union", "UnionType")
469
470
def is_origin_of_cls(
    type_: Any, class_obj: Union[Tuple[Type[Any], ...], Type[Any]]
) -> bool:
    """Return True if ``type_`` has an ``__origin__`` that is a class and a
    subclass of ``class_obj`` (a class or a tuple of classes).

    e.g. ``List[int]`` has origin ``list``, so it matches ``class_obj=list``.
    """
    origin = typing_get_origin(type_)
    if origin is None or not isinstance(origin, type):
        return False
    return issubclass(origin, class_obj)
482
483
def is_origin_of(
    type_: Any, *names: str, module: Optional[str] = None
) -> bool:
    """return True if the given type has an __origin__ with the given name
    and optional module.

    :param names: acceptable origin names, compared with the name that
     :func:`._get_type_name` reports for the origin.
    :param module: if given, the origin's ``__module__`` must start with
     this string.
    """
    origin = typing_get_origin(type_)
    if origin is None:
        return False
    if _get_type_name(origin) not in names:
        return False
    return module is None or origin.__module__.startswith(module)
497
498
def _get_type_name(type_: Type[Any]) -> str:
    """Return the name of ``type_``.

    On Python 3.10+ (``compat.py310``) this is just ``__name__``.  Otherwise
    ``__name__`` is tried first and ``_name`` is the fallback (presumably for
    typing special forms lacking ``__name__`` on older versions); if neither
    is available, ``None`` is returned despite the annotation.
    """
    if compat.py310:
        return type_.__name__

    found = getattr(type_, "__name__", None)
    return found if found is not None else getattr(type_, "_name", None)  # type: ignore
508
509
class DescriptorProto(Protocol):
    """Structural type for an object implementing the full descriptor
    protocol (``__get__``, ``__set__`` and ``__delete__``); serves as the
    bound of the ``_DESC`` / ``_DESC_co`` type variables."""

    def __get__(self, instance: object, owner: Any) -> Any: ...

    def __set__(self, instance: Any, value: Any) -> None: ...

    def __delete__(self, instance: Any) -> None: ...
516
517
# the descriptor type referred to by a DescriptorReference
_DESC = TypeVar("_DESC", bound=DescriptorProto)
519
520
class DescriptorReference(Generic[_DESC]):
    """a descriptor that refers to a descriptor.

    used for cases where we need to have an instance variable referring to an
    object that is itself a descriptor, which typically confuses typing tools
    as they don't know when they should use ``__get__`` or not when referring
    to the descriptor assignment as an instance variable. See
    sqlalchemy.orm.interfaces.PropComparator.prop

    The methods below are defined only under ``TYPE_CHECKING``; at runtime
    the class defines no methods of its own, so they only inform static type
    checkers.

    """

    if TYPE_CHECKING:

        def __get__(self, instance: object, owner: Any) -> _DESC: ...

        def __set__(self, instance: Any, value: _DESC) -> None: ...

        def __delete__(self, instance: Any) -> None: ...
539
540
# covariant counterpart of _DESC, used by the read-only RODescriptorReference
_DESC_co = TypeVar("_DESC_co", bound=DescriptorProto, covariant=True)
542
543
class RODescriptorReference(Generic[_DESC_co]):
    """a descriptor that refers to a descriptor.

    same as :class:`.DescriptorReference` but is read-only, so that subclasses
    can define a subtype as the generically contained element

    As with DescriptorReference, the methods exist only under
    ``TYPE_CHECKING``.  ``__set__`` / ``__delete__`` are annotated
    ``NoReturn``, presumably to mark writes and deletes as invalid for
    type checkers.

    """

    if TYPE_CHECKING:

        def __get__(self, instance: object, owner: Any) -> _DESC_co: ...

        def __set__(self, instance: Any, value: Any) -> NoReturn: ...

        def __delete__(self, instance: Any) -> NoReturn: ...
559
560
# the callable (or None) held by a CallableReference
_FN = TypeVar("_FN", bound=Optional[Callable[..., Any]])
562
563
class CallableReference(Generic[_FN]):
    """a descriptor that refers to a callable.

    works around mypy's limitation of not allowing callables assigned
    as instance variables

    The methods below are defined only under ``TYPE_CHECKING``, so this has
    no effect at runtime; it only tells type checkers that reading the
    attribute yields the callable type ``_FN``.

    """

    if TYPE_CHECKING:

        def __get__(self, instance: object, owner: Any) -> _FN: ...

        def __set__(self, instance: Any, value: _FN) -> None: ...

        def __delete__(self, instance: Any) -> None: ...
580
581
582# $def ro_descriptor_reference(fn: Callable[])