1from __future__ import annotations
2
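# NOTE: Import `Sequence` from `typing` because it is needed for a type alias,
# not an annotation: type aliases are evaluated at runtime (unlike annotations,
# which `from __future__ import annotations` leaves unevaluated), and
# `collections.abc.Sequence` cannot be subscripted at runtime before Python 3.9.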
5import sys
6from collections.abc import Collection, Callable
7from typing import Any, Sequence, Protocol, Union, TypeVar, runtime_checkable
8from numpy import (
9 ndarray,
10 dtype,
11 generic,
12 bool_,
13 unsignedinteger,
14 integer,
15 floating,
16 complexfloating,
17 number,
18 timedelta64,
19 datetime64,
20 object_,
21 void,
22 str_,
23 bytes_,
24)
25from ._nested_sequence import _NestedSequence
26
# Type variables shared by the protocols and type aliases in this module.
# NOTE: the `dtype[Any]` bounds are given as strings, presumably so that
# `dtype` is never subscripted at runtime on older Pythons (compare the
# Python < 3.9 special case in the `ArrayLike` definition) -- TODO confirm.
_T = TypeVar("_T")  # unconstrained; used for the "non-`__array__`" part of unions
_ScalarType = TypeVar("_ScalarType", bound=generic)  # any NumPy scalar type
_DType = TypeVar("_DType", bound="dtype[Any]")  # any NumPy dtype
# Covariant variant of `_DType`, required for use as a protocol parameter.
_DType_co = TypeVar("_DType_co", covariant=True, bound="dtype[Any]")
31
@runtime_checkable
class _SupportsArray(Protocol[_DType_co]):
    """Protocol for objects whose ``__array__()`` returns an ``ndarray``.

    The protocol only cares about the default dtype of the to-be returned
    array, i.e. ``dtype=None`` or no ``dtype`` parameter at all; it is
    parametrized by that dtype (``_DType_co``, covariant).

    Concrete implementations of the protocol are responsible for adding any
    and all remaining overloads (e.g. ones taking an explicit ``dtype``).

    Because the class is ``runtime_checkable``, ``isinstance`` checks against
    it only verify that an ``__array__`` attribute exists; the signature and
    return type are not checked at runtime.
    """

    def __array__(self) -> ndarray[Any, _DType_co]: ...
40
41
@runtime_checkable
class _SupportsArrayFunc(Protocol):
    """A protocol class representing `~class.__array_function__`.

    Matches objects implementing NumPy's ``__array_function__`` override
    hook (NEP 18). Being ``runtime_checkable``, ``isinstance`` only checks
    that the method is present, not its signature.

    Parameters of ``__array_function__`` (per NEP 18):

    func
        The NumPy public-API callable that was invoked.
    types
        The unique types of the call's arguments that implement
        ``__array_function__``.
    args, kwargs
        The positional and keyword arguments passed to the original call.

    The return value is unconstrained (``object``); per NEP 18 an
    implementation may return ``NotImplemented`` to decline handling.
    """
    # Builtin generics (`type[...]`, `tuple[...]`, `dict[...]`) are safe in these
    # annotations on older Pythons thanks to `from __future__ import annotations`.
    def __array_function__(
        self,
        func: Callable[..., Any],
        types: Collection[type[Any]],
        args: tuple[Any, ...],
        kwargs: dict[str, Any],
    ) -> object: ...
52
53
# `_T` itself, or `_T` wrapped in one to four levels of `Sequence`.
# The nesting depth is spelled out by hand instead of using a recursive alias
# because mypy did not support recursive aliases combined with type variables.
# TODO: Wait until mypy supports recursive objects in combination with typevars
_FiniteNestedSequence = Union[
    _T,
    Sequence[_T],
    Sequence[Sequence[_T]],
    Sequence[Sequence[Sequence[_T]]],
    Sequence[Sequence[Sequence[Sequence[_T]]]],
]
62
# A subset of `npt.ArrayLike` that can be parametrized w.r.t. `np.generic`:
# objects whose `__array__()` yields an array of `dtype[_ScalarType]`, or a
# `_NestedSequence` of such objects. Builtin Python scalars are not included.
# NOTE(review): `_ScalarType` only appears inside string forward references
# here, so at runtime the alias presumably exposes no type parameters and can
# only be subscripted by static type checkers (e.g. in `.pyi` stubs) -- verify.
_ArrayLike = Union[
    _SupportsArray["dtype[_ScalarType]"],
    _NestedSequence[_SupportsArray["dtype[_ScalarType]"]],
]
68
# A union representing array-like objects, parametrized by two typevars:
# * `_DType`: the dtype of objects implementing `__array__` (and of
#   `_NestedSequence`s of such objects), i.e. the part of the union that can be
#   parametrized w.r.t. `np.dtype`;
# * `_T`: everything else (in this module: builtin Python scalar types), plus
#   `_NestedSequence`s thereof.
# Usage: `_DualArrayLike[<dtype>, <builtin scalar types>]`.
_DualArrayLike = Union[
    _SupportsArray[_DType],
    _NestedSequence[_SupportsArray[_DType]],
    _T,
    _NestedSequence[_T],
]
78
# The public `ArrayLike` alias: objects implementing `__array__` (any dtype),
# the builtin scalars bool/int/float/complex/str/bytes, and nested sequences
# of either.
#
# TODO: support buffer protocols once
#
# https://bugs.python.org/issue27501
#
# is resolved. See also the python/typing issue:
#
# https://github.com/python/typing/issues/593
if sys.version_info[:2] < (3, 9):
    # Before Python 3.9 `dtype` is passed unsubscripted, presumably because
    # `dtype[Any]` cannot be evaluated at runtime there -- TODO confirm.
    ArrayLike = _DualArrayLike[
        dtype,
        Union[bool, int, float, complex, str, bytes],
    ]
else:
    ArrayLike = _DualArrayLike[
        dtype[Any],
        Union[bool, int, float, complex, str, bytes],
    ]
96
# `ArrayLike<X>_co`: array-like objects that can be coerced into `X`
# given the casting rules `same_kind`.
#
# Most aliases pair a dtype union (for `__array__` implementers, written as a
# string so it is not evaluated at runtime) with the builtin Python scalar
# types accepted for the same kind.

# Booleans: `np.bool_` dtype and Python `bool`.
_ArrayLikeBool_co = _DualArrayLike[
    "dtype[bool_]",
    bool,
]
# Unsigned integers (and booleans). Among builtins only `bool` is listed,
# presumably because a Python `int` may be negative.
_ArrayLikeUInt_co = _DualArrayLike[
    "dtype[Union[bool_, unsignedinteger[Any]]]",
    bool,
]
# Integers of any signedness (`np.integer`) and booleans.
_ArrayLikeInt_co = _DualArrayLike[
    "dtype[Union[bool_, integer[Any]]]",
    Union[bool, int],
]
# Real floating point, integers and booleans.
_ArrayLikeFloat_co = _DualArrayLike[
    "dtype[Union[bool_, integer[Any], floating[Any]]]",
    Union[bool, int, float],
]
# Complex floating point plus everything listed for `_ArrayLikeFloat_co`.
_ArrayLikeComplex_co = _DualArrayLike[
    "dtype[Union[bool_, integer[Any], floating[Any], complexfloating[Any, Any]]]",
    Union[bool, int, float, complex],
]
# Any numeric (`np.number`) dtype plus booleans; same builtins as complex.
_ArrayLikeNumber_co = _DualArrayLike[
    "dtype[Union[bool_, number[Any]]]",
    Union[bool, int, float, complex],
]
# Timedeltas, which also accept integers and booleans.
_ArrayLikeTD64_co = _DualArrayLike[
    "dtype[Union[bool_, integer[Any], timedelta64]]",
    Union[bool, int],
]
# The next three have no builtin-scalar component: they are plain unions of
# `__array__` implementers and nested sequences of them.
_ArrayLikeDT64_co = Union[
    _SupportsArray["dtype[datetime64]"],
    _NestedSequence[_SupportsArray["dtype[datetime64]"]],
]
_ArrayLikeObject_co = Union[
    _SupportsArray["dtype[object_]"],
    _NestedSequence[_SupportsArray["dtype[object_]"]],
]

_ArrayLikeVoid_co = Union[
    _SupportsArray["dtype[void]"],
    _NestedSequence[_SupportsArray["dtype[void]"]],
]
# Unicode strings: `np.str_` dtype and Python `str`.
_ArrayLikeStr_co = _DualArrayLike[
    "dtype[str_]",
    str,
]
# Byte strings: `np.bytes_` dtype and Python `bytes`.
_ArrayLikeBytes_co = _DualArrayLike[
    "dtype[bytes_]",
    bytes,
]

# Integers only (no `_co` suffix). Unlike `_ArrayLikeInt_co`, neither `bool_`
# nor `bool` is listed. Note that Python `bool` subclasses `int`, so static
# checkers will still accept a `bool` where `int` is expected.
_ArrayLikeInt = _DualArrayLike[
    "dtype[integer[Any]]",
    int,
]
153
154# Extra ArrayLike type so that pyright can deal with NDArray[Any]
155# Used as the first overload, should only match NDArray[Any],
156# not any actual types.
157# https://github.com/numpy/numpy/pull/22193
158class _UnknownType:
159 ...
160
161
# Extra ArrayLike type so that pyright can deal with `NDArray[Any]`.
# Used as the first overload; it should only match `NDArray[Any]`, not any
# actual types, hence both slots are filled with the private `_UnknownType`.
# https://github.com/numpy/numpy/pull/22193
_ArrayLikeUnknown = _DualArrayLike[
    "dtype[_UnknownType]",
    _UnknownType,
]