1from __future__ import annotations
2
3import sys
4import types
5from collections.abc import Generator, Iterable, Iterator
6from typing import (
7 Any,
8 ClassVar,
9 NoReturn,
10 TypeVar,
11 TYPE_CHECKING,
12)
13
14import numpy as np
15
16__all__ = ["_GenericAlias", "NDArray"]
17
18_T = TypeVar("_T", bound="_GenericAlias")
19
20
def _to_str(obj: object) -> str:
    """Render a single alias component for `_GenericAlias.__repr__`.

    ``Ellipsis`` is shown as ``...``.  Plain classes are shown by their
    qualified name, prefixed with their module unless that module is
    ``builtins``.  Generic-alias objects are excluded from the class path
    even if they pass an ``isinstance(obj, type)`` check, and, like every
    other object, are rendered with ``repr``.
    """
    if obj is Ellipsis:
        return '...'
    if not isinstance(obj, type) or isinstance(obj, _GENERIC_ALIAS_TYPE):
        return repr(obj)
    module, qualname = obj.__module__, obj.__qualname__
    return qualname if module == 'builtins' else f'{module}.{qualname}'
32
33
def _parse_parameters(args: Iterable[Any]) -> Generator[TypeVar, None, None]:
    """Yield the type variables referenced by `args`, in order.

    Each element of `args` is handled as follows:

    * if it has a ``__parameters__`` attribute (e.g. another generic alias),
      every entry of that attribute is yielded;
    * otherwise, if it is itself a `TypeVar`, it is yielded;
    * anything else contributes nothing.

    Repeated type variables are yielded once per occurrence (no
    de-duplication).

    Helper function for `_GenericAlias.__init__`.

    """
    for arg in args:
        if not hasattr(arg, "__parameters__"):
            if isinstance(arg, TypeVar):
                yield arg
            continue
        yield from arg.__parameters__
45
46
def _reconstruct_alias(alias: _T, parameters: Iterator[TypeVar]) -> _T:
    """Recursively replace all typevars with those from `parameters`.

    Helper function for `_GenericAlias.__getitem__`.

    ``alias.__args__`` is walked left to right and one item is consumed
    from `parameters` for every type variable encountered.  This is the
    same order in which `_parse_parameters` collected them:

    * a bare `TypeVar` consumes one item;
    * a nested `_GenericAlias` is rebuilt recursively;
    * any other object with a non-empty ``__parameters__`` is re-subscripted
      with as many items as it has parameters;
    * everything else, including fully specialized aliases such as
      ``list[int]``, is kept unchanged.

    Parameters
    ----------
    alias
        The alias whose type variables are substituted.
    parameters
        Iterator of replacement values; it is advanced in place.

    Returns
    -------
    A new instance of ``type(alias)`` with the same ``__origin__`` and
    ``__unpacked__`` and the substituted ``__args__``.

    """
    args = []
    for i in alias.__args__:
        if isinstance(i, TypeVar):
            value: Any = next(parameters)
        elif isinstance(i, _GenericAlias):
            value = _reconstruct_alias(i, parameters)
        elif getattr(i, "__parameters__", ()):
            # Only re-subscript objects that still have free type variables.
            # Fully specialized aliases (e.g. ``list[int]``) expose an empty
            # ``__parameters__``; subscripting them again with ``()`` raises
            # TypeError, and `_parse_parameters` contributed nothing for
            # them, so they must be kept as-is.
            prm_tup = tuple(next(parameters) for _ in i.__parameters__)
            value = i[prm_tup]
        else:
            value = i
        args.append(value)

    cls = type(alias)
    return cls(alias.__origin__, tuple(args), alias.__unpacked__)
68
69
class _GenericAlias:
    """A python-based backport of the `types.GenericAlias` class.

    E.g. for ``t = list[int]``, ``t.__origin__`` is ``list`` and
    ``t.__args__`` is ``(int,)``.

    Apart from the names listed in ``_ATTR_EXCEPTIONS``, attribute lookups
    on an instance are forwarded to ``__origin__``.  Calling an instance
    calls ``__origin__``.

    See Also
    --------
    :pep:`585`
        The PEP responsible for introducing `types.GenericAlias`.

    """

    __slots__ = (
        "__weakref__",
        "_origin",
        "_args",
        "_parameters",
        "_hash",  # unset until the first `__hash__` call, then cached
        "_starred",
    )

    # The read-only properties below fetch their backing slot through
    # `object.__getattribute__` (via `super()`).  This bypasses the
    # forwarding `__getattribute__` defined at the bottom of the class.

    @property
    def __origin__(self) -> type:
        """The un-parametrized class, e.g. ``list`` for ``list[int]``."""
        return super().__getattribute__("_origin")

    @property
    def __args__(self) -> tuple[object, ...]:
        """The type arguments, e.g. ``(int,)`` for ``list[int]``."""
        return super().__getattribute__("_args")

    @property
    def __parameters__(self) -> tuple[TypeVar, ...]:
        """Type variables in the ``GenericAlias``."""
        return super().__getattribute__("_parameters")

    @property
    def __unpacked__(self) -> bool:
        """Whether the alias is starred (see `__iter__`); shown as ``*``."""
        return super().__getattribute__("_starred")

    @property
    def __typing_unpacked_tuple_args__(self) -> tuple[object, ...] | None:
        # NOTE: This should return `__args__` if `__origin__` is a tuple,
        # which should never be the case with how `_GenericAlias` is used
        # within numpy
        return None

    def __init__(
        self,
        origin: type,
        args: object | tuple[object, ...],
        starred: bool = False,
    ) -> None:
        """Create the alias ``origin[args]``.

        Parameters
        ----------
        origin
            The class being parametrized.
        args
            The type argument(s); a non-tuple value is wrapped in a 1-tuple.
        starred
            Whether the alias is unpacked (rendered as ``*origin[args]``).

        """
        self._origin = origin
        self._args = args if isinstance(args, tuple) else (args,)
        # Collect the free type variables once, at construction time.
        self._parameters = tuple(_parse_parameters(self.__args__))
        self._starred = starred

    # Defined as a property on purpose.  Accessing ``__call__`` returns
    # ``__origin__`` itself, so ``alias(*args)`` ends up calling
    # ``__origin__(*args)``.
    @property
    def __call__(self) -> type[Any]:
        return self.__origin__

    def __reduce__(self: _T) -> tuple[
        type[_T],
        tuple[type[Any], tuple[object, ...], bool],
    ]:
        """Pickle support: rebuild from origin, args and the starred flag."""
        cls = type(self)
        return cls, (self.__origin__, self.__args__, self.__unpacked__)

    def __mro_entries__(self, bases: Iterable[object]) -> tuple[type[Any]]:
        """When used as a base class, substitute the un-parametrized origin."""
        return (self.__origin__,)

    def __dir__(self) -> list[str]:
        """Implement ``dir(self)``.

        Returns the origin's attributes plus the names served by the alias
        itself, mirroring how `__getattribute__` resolves them.
        """
        cls = type(self)
        dir_origin = set(dir(self.__origin__))
        return sorted(cls._ATTR_EXCEPTIONS | dir_origin)

    def __hash__(self) -> int:
        """Return ``hash(self)``.

        The value is computed on first use and cached in the ``_hash`` slot.
        Reading the slot before it is set raises `AttributeError`.
        """
        # Attempt to use the cached hash
        try:
            return super().__getattribute__("_hash")
        except AttributeError:
            self._hash: int = (
                hash(self.__origin__) ^
                hash(self.__args__) ^
                hash(self.__unpacked__)
            )
            return super().__getattribute__("_hash")

    def __instancecheck__(self, obj: object) -> NoReturn:
        """Check if an `obj` is an instance (always rejected, as in typing)."""
        raise TypeError("isinstance() argument 2 cannot be a "
                        "parameterized generic")

    def __subclasscheck__(self, cls: type) -> NoReturn:
        """Check if a `cls` is a subclass (always rejected, as in typing)."""
        raise TypeError("issubclass() argument 2 cannot be a "
                        "parameterized generic")

    def __repr__(self) -> str:
        """Return ``repr(self)``, e.g. ``numpy.ndarray[typing.Any, ...]``."""
        args = ", ".join(_to_str(i) for i in self.__args__)
        origin = _to_str(self.__origin__)
        prefix = "*" if self.__unpacked__ else ""
        return f"{prefix}{origin}[{args}]"

    def __getitem__(self: _T, key: object | tuple[object, ...]) -> _T:
        """Return ``self[key]``, substituting the free type variables.

        `key` may be a single object or a tuple.  Its length must equal
        ``len(self.__parameters__)``; otherwise `TypeError` is raised.  A
        `TypeError` is also raised when no type variables remain.
        """
        key_tup = key if isinstance(key, tuple) else (key,)

        if len(self.__parameters__) == 0:
            raise TypeError(f"There are no type variables left in {self}")
        elif len(key_tup) > len(self.__parameters__):
            raise TypeError(f"Too many arguments for {self}")
        elif len(key_tup) < len(self.__parameters__):
            raise TypeError(f"Too few arguments for {self}")

        key_iter = iter(key_tup)
        return _reconstruct_alias(self, key_iter)

    def __eq__(self, value: object) -> bool:
        """Return ``self == value``.

        Comparison is supported against other `_GenericAlias` instances and,
        where available, `types.GenericAlias`.  Anything else returns
        ``NotImplemented``.
        """
        if not isinstance(value, _GENERIC_ALIAS_TYPE):
            return NotImplemented
        # `types.GenericAlias` only has ``__unpacked__`` from Python 3.11 on.
        # Older versions cannot be starred at all, so a missing attribute
        # means "not unpacked".  Defaulting to ``self.__unpacked__`` instead
        # would let a starred alias compare equal to an unstarred one even
        # though `__hash__` mixes the unpacked flag into the hash.
        return (
            self.__origin__ == value.__origin__ and
            self.__args__ == value.__args__ and
            self.__unpacked__ == getattr(value, "__unpacked__", False)
        )

    def __iter__(self: _T) -> Generator[_T, None, None]:
        """Return ``iter(self)``.

        Yields a single starred copy of the alias, so ``*alias`` produces
        the unpacked form.
        """
        cls = type(self)
        yield cls(self.__origin__, self.__args__, True)

    # Names resolved on the alias itself rather than forwarded to
    # ``__origin__``.  ``__copy__``/``__deepcopy__`` are not defined here, so
    # looking them up raises AttributeError.  Presumably this makes `copy`
    # fall back to ``__reduce_ex__`` instead of the origin's hooks.
    _ATTR_EXCEPTIONS: ClassVar[frozenset[str]] = frozenset({
        "__origin__",
        "__args__",
        "__parameters__",
        "__mro_entries__",
        "__reduce__",
        "__reduce_ex__",
        "__copy__",
        "__deepcopy__",
        "__unpacked__",
        "__typing_unpacked_tuple_args__",
        "__class__",
    })

    def __getattribute__(self, name: str) -> Any:
        """Return ``getattr(self, name)``."""
        # Pull the attribute from `__origin__` unless its
        # name is in `_ATTR_EXCEPTIONS`
        cls = type(self)
        if name in cls._ATTR_EXCEPTIONS:
            return super().__getattribute__(name)
        return getattr(self.__origin__, name)
230
231
# Alias classes accepted as comparable by `_GenericAlias.__eq__`.  `_to_str`
# also uses this tuple to avoid rendering aliases as plain classes.
# `types.GenericAlias` is only included on Python >= 3.9.
if sys.version_info < (3, 9):
    _GENERIC_ALIAS_TYPE = (_GenericAlias,)
else:
    _GENERIC_ALIAS_TYPE = (_GenericAlias, types.GenericAlias)
237
# Type variable for an array's scalar type: bound to `np.generic`, covariant.
ScalarType = TypeVar("ScalarType", bound=np.generic, covariant=True)

# Static type checkers, and Python >= 3.9 at runtime, subscript `np.dtype` /
# `np.ndarray` directly.  Older interpreters fall back to the pure-python
# `_GenericAlias` backport.  The first `np.ndarray` parameter is left as
# `Any` (presumably the shape type -- confirm against the numpy stubs).
# NOTE(review): keep the literal condition and subscriptions below as they
# are; type checkers pattern-match this exact form.
if TYPE_CHECKING or sys.version_info >= (3, 9):
    _DType = np.dtype[ScalarType]
    NDArray = np.ndarray[Any, np.dtype[ScalarType]]
else:
    _DType = _GenericAlias(np.dtype, (ScalarType,))
    NDArray = _GenericAlias(np.ndarray, (Any, _DType))