1import sys
2from collections.abc import Callable, Collection, Sequence
3from typing import TYPE_CHECKING, Any, Protocol, TypeAlias, TypeVar, runtime_checkable
4
5import numpy as np
6from numpy import dtype
7
8from ._nbit_base import _32Bit, _64Bit
9from ._nested_sequence import _NestedSequence
10from ._shape import _AnyShape
11
if TYPE_CHECKING:
    # Static type checkers resolve `StringDType` through `np.dtypes`.
    StringDType = np.dtypes.StringDType
else:
    # At runtime (i.e. outside of type checking) importing this from
    # `numpy.dtypes` would lead to a circular import, so the class is taken
    # from `numpy._core.multiarray` instead.
    from numpy._core.multiarray import StringDType
18
# Unconstrained type variable; used by the nested-sequence and dual
# array-like aliases in this module.
_T = TypeVar("_T")
# Scalar type variable; restricted to subclasses of `np.generic`.
_ScalarT = TypeVar("_ScalarT", bound=np.generic)
# Dtype type variables, both bound to `dtype[Any]`: an invariant one and a
# covariant one (the latter parametrizes the `_SupportsArray` protocol).
_DTypeT = TypeVar("_DTypeT", bound=dtype[Any])
_DTypeT_co = TypeVar("_DTypeT_co", covariant=True, bound=dtype[Any])

# Generic alias for an `np.ndarray` whose shape parameter is `_AnyShape` and
# whose dtype is `dtype[_ScalarT]`; only the scalar type is left open.
NDArray: TypeAlias = np.ndarray[_AnyShape, dtype[_ScalarT]]
25
@runtime_checkable
class _SupportsArray(Protocol[_DTypeT_co]):
    """A protocol class representing objects that implement `__array__`.

    The protocol only cares about the default dtype (i.e. `dtype=None` or
    no `dtype` parameter at all) of the to-be returned array.

    Concrete implementations of the protocol are responsible for adding
    any and all remaining overloads.
    """

    def __array__(self) -> np.ndarray[Any, _DTypeT_co]:
        ...
34
35
@runtime_checkable
class _SupportsArrayFunc(Protocol):
    """Structural type for objects that define `__array_function__`.

    The protocol is `runtime_checkable`, so `isinstance` checks only verify
    that an `__array_function__` attribute is present; the signature itself
    is enforced by static type checkers alone.
    """

    def __array_function__(
        self,
        func: Callable[..., Any],
        types: Collection[type[Any]],
        args: tuple[Any, ...],
        kwargs: dict[str, Any],
    ) -> object:
        ...
46
47
# TODO: Wait until mypy supports recursive objects in combination with typevars
# Until then the nesting is spelled out explicitly: either a bare `_T` or a
# `_T` wrapped in one to four levels of `Sequence`.
_FiniteNestedSequence: TypeAlias = (
    _T
    | Sequence[_T]
    | Sequence[Sequence[_T]]
    | Sequence[Sequence[Sequence[_T]]]
    | Sequence[Sequence[Sequence[Sequence[_T]]]]
)
56
# A subset of `npt.ArrayLike` that can be parametrized w.r.t. `np.generic`:
# either an object whose `__array__` returns an array of `dtype[_ScalarT]`,
# or a `_NestedSequence` of such objects.
# Unlike `_DualArrayLike`, it has no member for plain (non-array) objects.
_ArrayLike: TypeAlias = (
    _SupportsArray[dtype[_ScalarT]]
    | _NestedSequence[_SupportsArray[dtype[_ScalarT]]]
)
62
# A union representing array-like objects; consists of two typevars:
# One representing types that can be parametrized w.r.t. `np.dtype`
# (`_DTypeT`, matched through the `_SupportsArray` protocol) and another one
# for the rest (`_T`).
# Each of the two may appear either directly or inside a `_NestedSequence`.
_DualArrayLike: TypeAlias = (
    _SupportsArray[_DTypeT]
    | _NestedSequence[_SupportsArray[_DTypeT]]
    | _T
    | _NestedSequence[_T]
)
72
# `collections.abc.Buffer` is only imported on Python 3.12 and later; older
# interpreters fall back to a runtime-checkable protocol that declares the
# same `__buffer__` method.
if sys.version_info < (3, 12):
    @runtime_checkable
    class _Buffer(Protocol):
        """Fallback protocol for objects that define `__buffer__`."""

        def __buffer__(self, flags: int, /) -> memoryview:
            ...
else:
    from collections.abc import Buffer as _Buffer
79
# The general array-like alias (referred to as `npt.ArrayLike` elsewhere in
# this module): any `_Buffer` object, any `_SupportsArray` object regardless
# of dtype, a `complex | bytes | str` object, or a `_NestedSequence` of
# either of the latter two.
ArrayLike: TypeAlias = _Buffer | _DualArrayLike[dtype[Any], complex | bytes | str]
81
# `ArrayLike<X>_co`: array-like objects that can be coerced into `X`
# given the casting rules `same_kind`
# Each `_DualArrayLike[D, B]` alias pairs the accepted array dtypes `D` with
# the accepted builtin type(s) `B`.
_ArrayLikeBool_co: TypeAlias = _DualArrayLike[dtype[np.bool], bool]
_ArrayLikeUInt_co: TypeAlias = _DualArrayLike[dtype[np.bool | np.unsignedinteger], bool]
_ArrayLikeInt_co: TypeAlias = _DualArrayLike[dtype[np.bool | np.integer], int]
_ArrayLikeFloat_co: TypeAlias = _DualArrayLike[dtype[np.bool | np.integer | np.floating], float]
_ArrayLikeComplex_co: TypeAlias = _DualArrayLike[dtype[np.bool | np.number], complex]
# Same type as `_ArrayLikeComplex_co`, under a second name.
_ArrayLikeNumber_co: TypeAlias = _ArrayLikeComplex_co
_ArrayLikeTD64_co: TypeAlias = _DualArrayLike[dtype[np.bool | np.integer | np.timedelta64], int]
# The datetime64 and object aliases are built on `_ArrayLike`, so they
# have no builtin-type member.
_ArrayLikeDT64_co: TypeAlias = _ArrayLike[np.datetime64]
_ArrayLikeObject_co: TypeAlias = _ArrayLike[np.object_]

# The void alias is likewise built on `_ArrayLike`; the bytes/str/string
# aliases additionally accept the matching builtin `bytes` / `str` objects.
_ArrayLikeVoid_co: TypeAlias = _ArrayLike[np.void]
_ArrayLikeBytes_co: TypeAlias = _DualArrayLike[dtype[np.bytes_], bytes]
_ArrayLikeStr_co: TypeAlias = _DualArrayLike[dtype[np.str_], str]
# Parametrized directly by `StringDType` rather than by `dtype[...]`.
_ArrayLikeString_co: TypeAlias = _DualArrayLike[StringDType, str]
_ArrayLikeAnyString_co: TypeAlias = _DualArrayLike[dtype[np.character] | StringDType, bytes | str]
99
# Helper unions listing the scalar types admitted by the `float64` and
# `complex128` array-like aliases directly below.
__Float64_co: TypeAlias = np.floating[_64Bit] | np.float32 | np.float16 | np.integer | np.bool
__Complex128_co: TypeAlias = np.number[_64Bit] | np.number[_32Bit] | np.float16 | np.integer | np.bool
# Arrays of the scalar types above, or builtin `float` / `complex` objects.
_ArrayLikeFloat64_co: TypeAlias = _DualArrayLike[dtype[__Float64_co], float]
_ArrayLikeComplex128_co: TypeAlias = _DualArrayLike[dtype[__Complex128_co], complex]
104
# NOTE: This includes `builtins.bool`, but not `numpy.bool`.
# `builtins.bool` is admitted through the `int` member (it subclasses `int`),
# whereas the dtype member is `dtype[np.integer]` only; compare
# `_ArrayLikeInt_co`, which uses `dtype[np.bool | np.integer]`.
_ArrayLikeInt: TypeAlias = _DualArrayLike[dtype[np.integer], int]