1"""
Array methods which are called by both the C-code for the method
and the Python code for the NumPy-namespace function.
4
5"""
6import warnings
7from contextlib import nullcontext
8
9from numpy.core import multiarray as mu
10from numpy.core import umath as um
11from numpy.core.multiarray import asanyarray
12from numpy.core import numerictypes as nt
13from numpy.core import _exceptions
14from numpy.core._ufunc_config import _no_nep50_warning
15from numpy._globals import _NoValue
16from numpy.compat import pickle, os_fspath
17
# Bind each ufunc's ``reduce`` method to a module-level name once, at import
# time, so the wrappers below skip the attribute lookups on every call
# (this saves on the order of 100 nanoseconds per call).

# arithmetic reductions
umr_sum = um.add.reduce
umr_prod = um.multiply.reduce

# extrema reductions
umr_maximum = um.maximum.reduce
umr_minimum = um.minimum.reduce

# logical reductions
umr_any = um.logical_or.reduce
umr_all = um.logical_and.reduce
25
# Lookup table used by the fast path in _var(): maps a complex dtype to the
# float dtype of its components, so a complex array can be viewed as a
# trailing (2,) array of floats.
_complex_to_float = {
    nt.dtype(nt.csingle): nt.dtype(nt.single),
    nt.dtype(nt.cdouble): nt.dtype(nt.double),
}
# Only register the long-double pair when long double is a dtype distinct
# from double; otherwise the double entry above must be the one that is kept
# (the original note calls this out as a special case for Windows).
if nt.dtype(nt.double) != nt.dtype(nt.longdouble):
    _complex_to_float[nt.dtype(nt.clongdouble)] = nt.dtype(nt.longdouble)
36
# Avoid keyword arguments to speed up parsing; this saves about 15%-20% for
# very small reductions.
def _amax(a, axis=None, out=None, keepdims=False,
          initial=_NoValue, where=True):
    """Reduce ``a`` with the ``maximum`` ufunc over ``axis``.

    All arguments are forwarded to ``umr_maximum`` positionally; the third
    positional slot (``dtype`` in ``ufunc.reduce``) is always ``None``.
    """
    reduced = umr_maximum(a, axis, None, out, keepdims, initial, where)
    return reduced
42
def _amin(a, axis=None, out=None, keepdims=False,
          initial=_NoValue, where=True):
    """Reduce ``a`` with the ``minimum`` ufunc over ``axis``.

    All arguments are forwarded to ``umr_minimum`` positionally; the third
    positional slot (``dtype`` in ``ufunc.reduce``) is always ``None``.
    """
    reduced = umr_minimum(a, axis, None, out, keepdims, initial, where)
    return reduced
46
def _sum(a, axis=None, dtype=None, out=None, keepdims=False,
         initial=_NoValue, where=True):
    """Add-reduce ``a`` over ``axis``.

    Every argument is handed to ``umr_sum`` positionally, in the order
    ``(a, axis, dtype, out, keepdims, initial, where)``.
    """
    total = umr_sum(a, axis, dtype, out, keepdims, initial, where)
    return total
50
def _prod(a, axis=None, dtype=None, out=None, keepdims=False,
          initial=_NoValue, where=True):
    """Multiply-reduce ``a`` over ``axis``.

    Every argument is handed to ``umr_prod`` positionally, in the order
    ``(a, axis, dtype, out, keepdims, initial, where)``.
    """
    product = umr_prod(a, axis, dtype, out, keepdims, initial, where)
    return product
54
def _any(a, axis=None, dtype=None, out=None, keepdims=False, *, where=True):
    """Logical-OR reduction of ``a`` over ``axis``.

    ``where`` is forwarded (as a keyword) only when it is not the default
    ``True`` object.
    """
    if where is not True:
        # A mask was supplied, so the keyword argument is unavoidable.
        return umr_any(a, axis, dtype, out, keepdims, where=where)
    # Default mask: stay purely positional, since keyword parsing is
    # currently fairly slow.
    return umr_any(a, axis, dtype, out, keepdims)
60
def _all(a, axis=None, dtype=None, out=None, keepdims=False, *, where=True):
    """Logical-AND reduction of ``a`` over ``axis``.

    ``where`` is forwarded (as a keyword) only when it is not the default
    ``True`` object.
    """
    if where is not True:
        # A mask was supplied, so the keyword argument is unavoidable.
        return umr_all(a, axis, dtype, out, keepdims, where=where)
    # Default mask: stay purely positional, since keyword parsing is
    # currently fairly slow.
    return umr_all(a, axis, dtype, out, keepdims)
66
def _count_reduce_items(arr, axis, keepdims=False, where=True):
    """Count the elements that take part in a reduction of ``arr``.

    With the default ``where=True`` the result is the product of the lengths
    of the reduced axes, returned as an ``intp`` scalar (``keepdims`` is not
    consulted on this path).  Otherwise ``where`` is broadcast to
    ``arr.shape`` and its entries are summed over ``axis`` with an ``intp``
    accumulator, honouring ``keepdims``.
    """
    if where is not True:
        # TODO: Optimize case when `where` is broadcast along a non-reduction
        # axis and full sum is more excessive than needed.

        # Imported here rather than at module level to avoid a circular
        # import.
        from numpy.lib.stride_tricks import broadcast_to
        # Number of True values in the (possibly broadcast) boolean mask.
        mask = broadcast_to(where, arr.shape)
        return umr_sum(mask, axis, nt.intp, None, keepdims)

    # Fast path: no mask, so the count depends only on the shape and axes.
    if axis is None:
        axes = tuple(range(arr.ndim))
    elif isinstance(axis, tuple):
        axes = axis
    else:
        axes = (axis,)

    count = 1
    for ax in axes:
        # normalize_axis_index validates ``ax`` and maps negatives into range
        count *= arr.shape[mu.normalize_axis_index(ax, arr.ndim)]
    return nt.intp(count)
89
def _clip(a, min=None, max=None, out=None, **kwargs):
    """Limit the values of ``a`` to the bounds ``min`` and/or ``max``.

    With only ``max`` given this is ``um.minimum(a, max)``; with only ``min``
    given it is ``um.maximum(a, min)``; with both it is ``um.clip``.  ``out``
    and any extra keyword arguments are forwarded to whichever ufunc runs.

    Raises
    ------
    ValueError
        If both ``min`` and ``max`` are ``None``.
    """
    no_lower = min is None
    no_upper = max is None

    if no_lower and no_upper:
        raise ValueError("One of max or min must be given")

    if no_lower:
        # Upper bound only.
        return um.minimum(a, max, out=out, **kwargs)
    if no_upper:
        # Lower bound only.
        return um.maximum(a, min, out=out, **kwargs)
    return um.clip(a, min, max, out=out, **kwargs)
100
def _mean(a, axis=None, dtype=None, out=None, keepdims=False, *, where=True):
    """Arithmetic mean of ``a`` over ``axis``: the sum divided by the count.

    When ``dtype`` is None, integer and boolean input is accumulated as
    ``'f8'`` and float16 input is accumulated as ``'f4'``; in the float16
    case the result is converted back to the input's scalar type unless
    ``out`` was supplied.  A ``RuntimeWarning`` ("Mean of empty slice.") is
    issued when the number of reduced items is zero.
    """
    arr = asanyarray(a)

    rcount = _count_reduce_items(arr, axis, keepdims=keepdims, where=where)
    if where is True:
        # rcount is a single count on the unmasked path
        empty_slice = rcount == 0
    else:
        # with a mask, rcount is reduced with logical-or to one flag
        empty_slice = umr_any(rcount == 0, axis=None)
    if empty_slice:
        # NOTE: must be raised from this frame so stacklevel=2 stays correct
        warnings.warn("Mean of empty slice.", RuntimeWarning, stacklevel=2)

    # Pick the accumulator type when the caller did not: bool, unsigned int
    # and int go to float64; float16 goes to float32 and is narrowed later.
    narrow_to_input_type = False
    if dtype is None:
        elem_type = arr.dtype.type
        if issubclass(elem_type, (nt.integer, nt.bool_)):
            dtype = mu.dtype('f8')
        elif issubclass(elem_type, nt.float16):
            dtype = mu.dtype('f4')
            narrow_to_input_type = True

    total = umr_sum(arr, axis, dtype, out, keepdims, where=where)

    if isinstance(total, mu.ndarray):
        # Divide in place, writing the quotient back into the sum's buffer.
        with _no_nep50_warning():
            quotient = um.true_divide(
                total, rcount, out=total, casting='unsafe', subok=False)
        if narrow_to_input_type and out is None:
            quotient = arr.dtype.type(quotient)
        return quotient

    if not hasattr(total, 'dtype'):
        # Plain Python object result: ordinary division.
        return total / rcount

    # Scalar with a dtype: wrap the quotient in the input's scalar type for
    # the float16 case, otherwise in the sum's own scalar type.
    scalar_type = arr.dtype.type if narrow_to_input_type else total.dtype.type
    return scalar_type(total / rcount)
134
def _var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False, *,
         where=True):
    """Variance of ``a`` over ``axis``.

    The mean is computed first, then the sum of squared (absolute)
    deviations from it is divided by ``max(count - ddof, 0)``.  When
    ``dtype`` is None, integer and boolean input is accumulated as ``'f8'``.
    A ``RuntimeWarning`` ("Degrees of freedom <= 0 for slice") is issued when
    ``ddof`` is not smaller than the number of reduced items.
    """
    arr = asanyarray(a)

    rcount = _count_reduce_items(arr, axis, keepdims=keepdims, where=where)
    # Checked before any arithmetic so this warning shows up on top.
    if where is True:
        too_few_items = ddof >= rcount
    else:
        too_few_items = umr_any(ddof >= rcount, axis=None)
    if too_few_items:
        # NOTE: must be raised from this frame so stacklevel=2 stays correct
        warnings.warn("Degrees of freedom <= 0 for slice", RuntimeWarning,
                      stacklevel=2)

    # Cast bool, unsigned int, and int to float64 by default
    if dtype is None and issubclass(arr.dtype.type, (nt.integer, nt.bool_)):
        dtype = mu.dtype('f8')

    # --- mean ---------------------------------------------------------
    # Note that if dtype is not of inexact type then the mean will not be
    # either.  keepdims=True keeps it broadcastable against ``arr``.
    center = umr_sum(arr, axis, dtype, keepdims=True, where=where)
    # The divisor must match the shape of ``center``; otherwise broadcasting
    # would change the output shape and the quotient could not be stored
    # back into ``center``.  A 0-d count (default ``where``) needs no reshape.
    divisor = rcount if rcount.ndim == 0 else rcount.reshape(center.shape)
    if isinstance(center, mu.ndarray):
        with _no_nep50_warning():
            center = um.true_divide(center, divisor, out=center,
                                    casting='unsafe', subok=False)
    elif hasattr(center, "dtype"):
        center = center.dtype.type(center / rcount)
    else:
        center = center / rcount

    # --- squared deviations -------------------------------------------
    # The deviations may not be inexact, and they must be an array rather
    # than a scalar so the in-place operations below work.
    dev = asanyarray(arr - center)

    if issubclass(arr.dtype.type, (nt.floating, nt.integer)):
        dev = um.multiply(dev, dev, out=dev)
    elif dev.dtype in _complex_to_float:
        # Fast path for built-in complex types: view as (..., 2) floats,
        # square both components in place, then add them into the real part.
        parts = dev.view(dtype=(_complex_to_float[dev.dtype], (2,)))
        um.multiply(parts, parts, out=parts)
        dev = um.add(parts[..., 0], parts[..., 1], out=dev.real).real
    else:
        # Most general case; includes object arrays containing imaginary
        # numbers and complex types with non-native byteorder.
        dev = um.multiply(dev, um.conjugate(dev), out=dev).real

    sq_sum = umr_sum(dev, axis, dtype, out, keepdims=keepdims, where=where)

    # --- normalisation ------------------------------------------------
    # Degrees of freedom, clamped so they are never negative.
    dof = um.maximum(rcount - ddof, 0)

    if isinstance(sq_sum, mu.ndarray):
        with _no_nep50_warning():
            return um.true_divide(
                sq_sum, dof, out=sq_sum, casting='unsafe', subok=False)
    if hasattr(sq_sum, 'dtype'):
        return sq_sum.dtype.type(sq_sum / dof)
    return sq_sum / dof
203
def _std(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False, *,
         where=True):
    """Standard deviation of ``a``: the square root of ``_var`` for the same
    arguments.
    """
    variance = _var(a, axis=axis, dtype=dtype, out=out, ddof=ddof,
                    keepdims=keepdims, where=where)

    if isinstance(variance, mu.ndarray):
        # Take the root in place, reusing the variance buffer.
        return um.sqrt(variance, out=variance)
    if hasattr(variance, 'dtype'):
        # Scalar with a dtype: re-wrap the root in that same scalar type.
        return variance.dtype.type(um.sqrt(variance))
    return um.sqrt(variance)
217
def _ptp(a, axis=None, out=None, keepdims=False):
    """Peak-to-peak range of ``a`` over ``axis`` (maximum minus minimum).

    ``out`` receives the maximum first and is then reused as the output of
    the subtraction; the minimum is always computed into a fresh result.
    """
    highest = umr_maximum(a, axis, None, out, keepdims)
    lowest = umr_minimum(a, axis, None, None, keepdims)
    return um.subtract(highest, lowest, out)
224
def _dump(self, file, protocol=2):
    """Pickle ``self`` into ``file`` using the given pickle ``protocol``.

    ``file`` may be an object with a ``write`` attribute, which is written to
    and left open, or a path, which is opened in binary write mode and closed
    afterwards.
    """
    already_open = hasattr(file, 'write')
    # nullcontext hands the caller's object through without closing it on
    # exit; a file opened here is closed by the ``with`` block.
    target = nullcontext(file) if already_open else open(os_fspath(file), "wb")
    with target as stream:
        pickle.dump(self, stream, protocol=protocol)
232
def _dumps(self, protocol=2):
    """Return the pickle of ``self`` as a bytes object, using ``protocol``."""
    payload = pickle.dumps(self, protocol=protocol)
    return payload