1from __future__ import annotations
2
3import sys
4from collections.abc import Collection, Callable, Sequence
5from typing import Any, Protocol, TypeAlias, TypeVar, runtime_checkable, TYPE_CHECKING
6
7import numpy as np
8from numpy import (
9 ndarray,
10 dtype,
11 generic,
12 unsignedinteger,
13 integer,
14 floating,
15 complexfloating,
16 number,
17 timedelta64,
18 datetime64,
19 object_,
20 void,
21 str_,
22 bytes_,
23)
24from ._nested_sequence import _NestedSequence
25from ._shape import _Shape
26
27if TYPE_CHECKING:
28 StringDType = np.dtypes.StringDType
29else:
30 # at runtime outside of type checking importing this from numpy.dtypes
31 # would lead to a circular import
32 from numpy._core.multiarray import StringDType
33
# Unconstrained type variable.
_T = TypeVar("_T")

# Type variables bounded by `np.generic`: invariant and covariant flavours.
_ScalarType = TypeVar("_ScalarType", bound=generic)
_ScalarType_co = TypeVar(
    "_ScalarType_co",
    bound=generic,
    covariant=True,
)

# Type variables bounded by `np.dtype[Any]`: invariant and covariant flavours.
_DType = TypeVar("_DType", bound=dtype[Any])
_DType_co = TypeVar(
    "_DType_co",
    bound=dtype[Any],
    covariant=True,
)
39
# Generic `ndarray` alias: the shape parameter is fixed to `_Shape`
# (imported from `._shape`), the dtype is parametrized by its scalar type.
NDArray: TypeAlias = ndarray[
    _Shape,
    dtype[_ScalarType_co],
]
41
# `_SupportsArray` only describes the dtype of the array returned by
# default, i.e. when `__array__` is called with `dtype=None` or has no
# `dtype` parameter at all.
# Any remaining overloads are the responsibility of the concrete
# implementations of the protocol.
@runtime_checkable
class _SupportsArray(Protocol[_DType_co]):
    """A protocol class representing objects that define ``__array__``."""

    def __array__(self) -> ndarray[Any, _DType_co]:
        ...
50
51
@runtime_checkable
class _SupportsArrayFunc(Protocol):
    """Protocol for objects that define ``__array_function__``.

    The method receives the callable being dispatched (``func``), a
    collection of types (``types``), and the positional and keyword
    arguments (``args`` and ``kwargs``); its return value is typed as
    a plain ``object``.
    """

    def __array_function__(
        self,
        func: Callable[..., Any],
        types: Collection[type[Any]],
        args: tuple[Any, ...],
        kwargs: dict[str, Any],
    ) -> object:
        ...
62
63
# TODO: Wait until mypy supports recursive objects in combination with typevars
# Until then the nesting is spelled out explicitly: a bare `_T`, or `_T`
# wrapped in one to four levels of `Sequence`.
_FiniteNestedSequence: TypeAlias = (
    _T |
    Sequence[_T] |
    Sequence[Sequence[_T]] |
    Sequence[Sequence[Sequence[_T]]] |
    Sequence[Sequence[Sequence[Sequence[_T]]]]
)
72
# A subset of `npt.ArrayLike` that can be parametrized w.r.t. `np.generic`:
# either an object supporting `__array__` with the given scalar type, or a
# nested sequence of such objects.
_ArrayLike: TypeAlias = (
    _SupportsArray[dtype[_ScalarType]] |
    _NestedSequence[_SupportsArray[dtype[_ScalarType]]]
)
78
# A union representing array-like objects, parametrized by two typevars:
# * `_DType` covers the members that can be parametrized w.r.t. `np.dtype`
# * `_T` covers everything else
# Each of the two may appear on its own or inside a nested sequence.
_DualArrayLike: TypeAlias = (
    _SupportsArray[_DType] |
    _NestedSequence[_SupportsArray[_DType]] |
    _T |
    _NestedSequence[_T]
)
88
# `ArrayLike` accepts any dtype plus the builtin scalar types; on
# Python 3.12 and later it additionally accepts `collections.abc.Buffer`.
if sys.version_info < (3, 12):
    ArrayLike: TypeAlias = _DualArrayLike[
        dtype[Any],
        bool | int | float | complex | str | bytes,
    ]
else:
    from collections.abc import Buffer

    ArrayLike: TypeAlias = Buffer | _DualArrayLike[
        dtype[Any],
        bool | int | float | complex | str | bytes,
    ]
101
# `ArrayLike<X>_co`: array-like objects that can be coerced into `X`
# given the casting rules `same_kind`.
# The first parameter lists the accepted dtypes, the second the accepted
# builtin scalar types.
_ArrayLikeBool_co: TypeAlias = _DualArrayLike[dtype[np.bool], bool]
_ArrayLikeUInt_co: TypeAlias = _DualArrayLike[
    (
        dtype[np.bool]
        | dtype[unsignedinteger[Any]]
    ),
    bool,
]
_ArrayLikeInt_co: TypeAlias = _DualArrayLike[
    (
        dtype[np.bool]
        | dtype[integer[Any]]
    ),
    bool | int,
]
_ArrayLikeFloat_co: TypeAlias = _DualArrayLike[
    (
        dtype[np.bool]
        | dtype[integer[Any]]
        | dtype[floating[Any]]
    ),
    bool | int | float,
]
_ArrayLikeComplex_co: TypeAlias = _DualArrayLike[
    (
        dtype[np.bool]
        | dtype[integer[Any]]
        | dtype[floating[Any]]
        | dtype[complexfloating[Any, Any]]
    ),
    bool | int | float | complex,
]
_ArrayLikeNumber_co: TypeAlias = _DualArrayLike[
    (
        dtype[np.bool]
        | dtype[number[Any]]
    ),
    bool | int | float | complex,
]
_ArrayLikeTD64_co: TypeAlias = _DualArrayLike[
    (
        dtype[np.bool]
        | dtype[integer[Any]]
        | dtype[timedelta64]
    ),
    bool | int,
]
# These three accept no builtin scalar types, only objects supporting
# `__array__` (possibly nested in sequences). That is exactly what
# `_ArrayLike` expresses, so it is parametrized directly instead of
# spelling the same union out by hand for every scalar type.
_ArrayLikeDT64_co: TypeAlias = _ArrayLike[datetime64]
_ArrayLikeObject_co: TypeAlias = _ArrayLike[object_]

_ArrayLikeVoid_co: TypeAlias = _ArrayLike[void]
# String-like aliases: fixed-width unicode (`str_`), bytes (`bytes_`) and
# the `StringDType` dtype, each paired with the matching builtin scalar.
_ArrayLikeStr_co: TypeAlias = _DualArrayLike[dtype[str_], str]
_ArrayLikeBytes_co: TypeAlias = _DualArrayLike[dtype[bytes_], bytes]
_ArrayLikeString_co: TypeAlias = _DualArrayLike[StringDType, str]

# Union of all three string-like aliases above.
_ArrayLikeAnyString_co: TypeAlias = (
    _ArrayLikeStr_co
    | _ArrayLikeBytes_co
    | _ArrayLikeString_co
)
167
# NOTE: `builtins.bool` is accepted here (it is a subclass of `int`),
# but `numpy.bool` is not.
_ArrayLikeInt: TypeAlias = _DualArrayLike[dtype[integer[Any]], int]
173
# Extra ArrayLike type so that pyright can deal with NDArray[Any].
# It is meant to be used as the first overload and should only match
# NDArray[Any], not any actual types.
# https://github.com/numpy/numpy/pull/22193
if sys.version_info < (3, 11):
    from typing import NoReturn as _UnknownType
else:
    from typing import Never as _UnknownType


_ArrayLikeUnknown: TypeAlias = _DualArrayLike[dtype[_UnknownType], _UnknownType]