1from __future__ import annotations
2
3import sys
4from collections.abc import Collection, Callable, Sequence
5from typing import Any, Protocol, TypeAlias, TypeVar, runtime_checkable, TYPE_CHECKING
6
7import numpy as np
8from numpy import (
9 ndarray,
10 dtype,
11 generic,
12 unsignedinteger,
13 integer,
14 floating,
15 complexfloating,
16 number,
17 timedelta64,
18 datetime64,
19 object_,
20 void,
21 str_,
22 bytes_,
23)
24from ._nbit_base import _32Bit, _64Bit
25from ._nested_sequence import _NestedSequence
26from ._shape import _Shape
27
28if TYPE_CHECKING:
29 StringDType = np.dtypes.StringDType
30else:
31 # at runtime outside of type checking importing this from numpy.dtypes
32 # would lead to a circular import
33 from numpy._core.multiarray import StringDType
34
# Unconstrained, invariant typevar.
_T = TypeVar("_T")

# Typevars bound to `np.generic` (invariant and covariant flavours).
_ScalarType = TypeVar("_ScalarType", bound=generic)
_ScalarType_co = TypeVar(
    "_ScalarType_co",
    bound=generic,
    covariant=True,
)

# Typevars bound to `np.dtype[Any]` (invariant and covariant flavours).
_DType = TypeVar("_DType", bound=dtype[Any])
_DType_co = TypeVar(
    "_DType_co",
    bound=dtype[Any],
    covariant=True,
)
40
# Generic alias for an `ndarray` whose shape type is `_Shape` and whose
# dtype is parametrized by a `np.generic` subclass (via `_ScalarType_co`).
NDArray: TypeAlias = ndarray[_Shape, dtype[_ScalarType_co]]
42
# The `_SupportsArray` protocol only cares about the default dtype
# (i.e. `dtype=None` or no `dtype` parameter at all) of the to-be returned
# array.
# Concrete implementations of the protocol are responsible for adding
# any and all remaining overloads.
@runtime_checkable
class _SupportsArray(Protocol[_DType_co]):
    """Structural type for objects that define a no-argument ``__array__``.

    The type parameter is the dtype of the array that ``__array__`` returns.
    """

    def __array__(self) -> ndarray[Any, _DType_co]:
        ...
51
52
@runtime_checkable
class _SupportsArrayFunc(Protocol):
    """Structural type for objects that define ``__array_function__``.

    Because the protocol is ``runtime_checkable``, an ``isinstance`` check
    against it only verifies that the method is present, not its signature.
    """

    def __array_function__(
        self,
        func: Callable[..., Any],
        types: Collection[type[Any]],
        args: tuple[Any, ...],
        kwargs: dict[str, Any],
    ) -> object:
        ...
63
64
# TODO: Wait until mypy supports recursive objects in combination with typevars
# Until then the nesting is written out by hand: a bare `_T`, or `Sequence`s
# of `_T` nested up to four levels deep.
_FiniteNestedSequence: TypeAlias = (
    _T
    | Sequence[_T]
    | Sequence[Sequence[_T]]
    | Sequence[Sequence[Sequence[_T]]]
    | Sequence[Sequence[Sequence[Sequence[_T]]]]
)
73
# A subset of `npt.ArrayLike` that can be parametrized w.r.t. `np.generic`.
# It covers objects whose `__array__` returns an array of the given scalar
# type, as well as `_NestedSequence`s of such objects.
_ArrayLike: TypeAlias = (
    _SupportsArray[dtype[_ScalarType]]
    | _NestedSequence[_SupportsArray[dtype[_ScalarType]]]
)
79
# A union representing array-like objects; consists of two typevars:
# one representing types that can be parametrized w.r.t. `np.dtype`
# and another one for the rest.
# Subscript it as `_DualArrayLike[<dtype type>, <other types>]`; each slot is
# accepted both directly and wrapped in a `_NestedSequence`.
_DualArrayLike: TypeAlias = (
    _SupportsArray[_DType]
    | _NestedSequence[_SupportsArray[_DType]]
    | _T
    | _NestedSequence[_T]
)
89
# `collections.abc.Buffer` is imported on Python 3.12 and newer; older
# interpreters get a local protocol declaring `__buffer__` instead.
if sys.version_info >= (3, 12):
    from collections.abc import Buffer as _Buffer
else:
    @runtime_checkable
    class _Buffer(Protocol):
        """Structural type for objects that define ``__buffer__``."""

        def __buffer__(self, flags: int, /) -> memoryview:
            ...
96
# The general array-like union: `_Buffer` objects, plus `_DualArrayLike`
# instantiated with arrays of any dtype and the builtin scalar types
# listed below.
ArrayLike: TypeAlias = _Buffer | _DualArrayLike[
    dtype[Any],
    bool | int | float | complex | str | bytes,
]
101
# `_ArrayLike<X>_co`: array-like objects that can be coerced into `X`
# given the casting rules `same_kind`.
# In each alias the first slot lists the accepted array dtypes and the
# second slot the accepted builtin Python scalar types.
_ArrayLikeBool_co: TypeAlias = _DualArrayLike[dtype[np.bool], bool]
_ArrayLikeUInt_co: TypeAlias = _DualArrayLike[
    dtype[np.bool] | dtype[unsignedinteger[Any]],
    bool,
]
_ArrayLikeInt_co: TypeAlias = _DualArrayLike[
    dtype[np.bool] | dtype[integer[Any]],
    bool | int,
]
_ArrayLikeFloat_co: TypeAlias = _DualArrayLike[
    (
        dtype[np.bool]
        | dtype[integer[Any]]
        | dtype[floating[Any]]
    ),
    bool | int | float,
]
_ArrayLikeComplex_co: TypeAlias = _DualArrayLike[
    (
        dtype[np.bool]
        | dtype[integer[Any]]
        | dtype[floating[Any]]
        | dtype[complexfloating[Any, Any]]
    ),
    bool | int | float | complex,
]
_ArrayLikeNumber_co: TypeAlias = _DualArrayLike[
    dtype[np.bool] | dtype[number[Any]],
    bool | int | float | complex,
]
_ArrayLikeTD64_co: TypeAlias = _DualArrayLike[
    (
        dtype[np.bool]
        | dtype[integer[Any]]
        | dtype[timedelta64]
    ),
    bool | int,
]
# These aliases accept no builtin Python scalars, so each one is simply the
# scalar-parametrized `_ArrayLike` alias instantiated with one scalar type.
_ArrayLikeDT64_co: TypeAlias = _ArrayLike[datetime64]
_ArrayLikeObject_co: TypeAlias = _ArrayLike[object_]

_ArrayLikeVoid_co: TypeAlias = _ArrayLike[void]
# String-flavoured array-likes: `np.str_` arrays / `str`, `np.bytes_`
# arrays / `bytes`, and `StringDType` arrays / `str`.
_ArrayLikeStr_co: TypeAlias = _DualArrayLike[dtype[str_], str]
_ArrayLikeBytes_co: TypeAlias = _DualArrayLike[dtype[bytes_], bytes]
_ArrayLikeString_co: TypeAlias = _DualArrayLike[StringDType, str]

# Union of the three string-flavoured aliases above.
_ArrayLikeAnyString_co: TypeAlias = (
    _ArrayLikeStr_co
    | _ArrayLikeBytes_co
    | _ArrayLikeString_co
)
167
# Scalar-type unions used only by the two aliases below.
# NOTE(review): presumably the scalar types that cast to `float64` and
# `complex128` respectively -- confirm against the call sites.
__Float64_co: TypeAlias = (
    np.floating[_64Bit]
    | np.float32
    | np.float16
    | np.integer
    | np.bool
)
__Complex128_co: TypeAlias = (
    np.number[_64Bit]
    | np.number[_32Bit]
    | np.float16
    | np.integer
    | np.bool
)
_ArrayLikeFloat64_co: TypeAlias = _DualArrayLike[
    dtype[__Float64_co],
    float | int,
]
_ArrayLikeComplex128_co: TypeAlias = _DualArrayLike[
    dtype[__Complex128_co],
    complex | float | int,
]
172
# NOTE: This includes `builtins.bool` (a subclass of `int`), but not
# `numpy.bool`, since only `integer` dtypes are listed.
_ArrayLikeInt: TypeAlias = _DualArrayLike[dtype[integer[Any]], int]
178
# Extra ArrayLike type so that pyright can deal with NDArray[Any].
# Used as the first overload, should only match NDArray[Any],
# not any actual types.
# https://github.com/numpy/numpy/pull/22193
#
# `typing.Never` is imported on Python 3.11 and newer; older versions
# fall back to `typing.NoReturn`.
if sys.version_info >= (3, 11):
    from typing import Never as _UnknownType
else:
    from typing import NoReturn as _UnknownType


# Both slots of `_DualArrayLike` are filled with `_UnknownType`.
_ArrayLikeUnknown: TypeAlias = _DualArrayLike[
    dtype[_UnknownType],
    _UnknownType,
]