1"""
2Array methods which are called by both the C-code for the method
3and the Python code for the NumPy-namespace function
4
5"""
6import warnings
7from contextlib import nullcontext
8
9from numpy.core import multiarray as mu
10from numpy.core import umath as um
11from numpy.core.multiarray import asanyarray
12from numpy.core import numerictypes as nt
13from numpy.core import _exceptions
14from numpy.core._ufunc_config import _no_nep50_warning
15from numpy._globals import _NoValue
16from numpy.compat import pickle, os_fspath
17
# Resolve each ufunc's ``reduce`` method a single time at import so the thin
# wrappers in this module skip the attribute lookups on every call
# (saves O(100) nanoseconds per call).
umr_maximum, umr_minimum = um.maximum.reduce, um.minimum.reduce
umr_sum, umr_prod = um.add.reduce, um.multiply.reduce
umr_any, umr_all = um.logical_or.reduce, um.logical_and.reduce
25
# Lookup table from each built-in complex dtype to the real dtype of its
# components.  _var() uses it to view complex data as (2,)-shaped float
# pairs on its fast path.
_complex_to_float = {
    nt.dtype(complex_type): nt.dtype(real_type)
    for complex_type, real_type in (
        (nt.csingle, nt.single),
        (nt.cdouble, nt.double),
    )
}
# Special case for windows: only register the long double pair when it is a
# dtype distinct from double, so that the double entry takes precedence.
if nt.dtype(nt.longdouble) != nt.dtype(nt.double):
    _complex_to_float[nt.dtype(nt.clongdouble)] = nt.dtype(nt.longdouble)
36
# The reduction wrappers below pass their arguments positionally: avoiding
# keyword arguments speeds up argument parsing, which saves about 15%-20%
# for very small reductions.
def _amax(a, axis=None, out=None, keepdims=False,
          initial=_NoValue, where=True):
    """Reduce `a` with the maximum ufunc over `axis`.

    Every argument is handed to ``umr_maximum`` positionally; the dtype
    position of the reduction is always None.
    """
    result = umr_maximum(a, axis, None, out, keepdims, initial, where)
    return result
42
def _amin(a, axis=None, out=None, keepdims=False,
          initial=_NoValue, where=True):
    """Reduce `a` with the minimum ufunc over `axis`.

    Every argument is handed to ``umr_minimum`` positionally; the dtype
    position of the reduction is always None.
    """
    result = umr_minimum(a, axis, None, out, keepdims, initial, where)
    return result
46
def _sum(a, axis=None, dtype=None, out=None, keepdims=False,
         initial=_NoValue, where=True):
    """Reduce `a` with the add ufunc over `axis`.

    Every argument is handed to ``umr_sum`` positionally, in the order
    (a, axis, dtype, out, keepdims, initial, where).
    """
    total = umr_sum(a, axis, dtype, out, keepdims, initial, where)
    return total
50
def _prod(a, axis=None, dtype=None, out=None, keepdims=False,
          initial=_NoValue, where=True):
    """Reduce `a` with the multiply ufunc over `axis`.

    Every argument is handed to ``umr_prod`` positionally, in the order
    (a, axis, dtype, out, keepdims, initial, where).
    """
    product = umr_prod(a, axis, dtype, out, keepdims, initial, where)
    return product
54
def _any(a, axis=None, dtype=None, out=None, keepdims=False, *, where=True):
    """Reduce `a` with the logical_or ufunc over `axis`.

    `where` is only forwarded (as a keyword) when it is not the literal
    ``True`` default.
    """
    # Parsing keyword arguments is currently fairly slow, so the keyword is
    # only used when a mask was actually supplied.
    if where is not True:
        return umr_any(a, axis, dtype, out, keepdims, where=where)
    return umr_any(a, axis, dtype, out, keepdims)
60
def _all(a, axis=None, dtype=None, out=None, keepdims=False, *, where=True):
    """Reduce `a` with the logical_and ufunc over `axis`.

    `where` is only forwarded (as a keyword) when it is not the literal
    ``True`` default.
    """
    # Parsing keyword arguments is currently fairly slow, so the keyword is
    # only used when a mask was actually supplied.
    if where is not True:
        return umr_all(a, axis, dtype, out, keepdims, where=where)
    return umr_all(a, axis, dtype, out, keepdims)
66
def _count_reduce_items(arr, axis, keepdims=False, where=True):
    """Count the elements of `arr` that take part in a reduction over `axis`.

    With the default ``where=True`` the result is the product of the lengths
    of the reduced axes, returned as an ``nt.intp`` scalar.  Otherwise
    `where` is broadcast to ``arr.shape`` and its True entries are summed
    along `axis` (honouring `keepdims`) with an ``nt.intp`` accumulator.
    """
    if where is not True:
        # TODO: Optimize case when `where` is broadcast along a non-reduction
        # axis and full sum is more excessive than needed.

        # guarded to protect circular imports
        from numpy.lib.stride_tricks import broadcast_to
        # count True values in (potentially broadcasted) boolean mask
        mask = broadcast_to(where, arr.shape)
        return umr_sum(mask, axis, nt.intp, None, keepdims)

    # fast-path for the default case: no boolean mask given, so the count
    # follows from the shape alone
    if axis is None:
        reduced_axes = tuple(range(arr.ndim))
    elif isinstance(axis, tuple):
        reduced_axes = axis
    else:
        reduced_axes = (axis,)

    count = 1
    for ax in reduced_axes:
        count *= arr.shape[mu.normalize_axis_index(ax, arr.ndim)]
    return nt.intp(count)
89
90# Numpy 1.17.0, 2019-02-24
91# Various clip behavior deprecations, marked with _clip_dep as a prefix.
92
def _clip_dep_is_scalar_nan(a):
    """Tell whether `a` is a zero-dimensional NaN.

    Returns False for anything with a non-zero ``ndim`` and for values that
    ``um.isnan`` rejects with a TypeError; otherwise returns the result of
    ``um.isnan(a)``.
    """
    # guarded to protect circular imports
    from numpy.core.fromnumeric import ndim
    if ndim(a) == 0:
        try:
            return um.isnan(a)
        except TypeError:
            # isnan is not defined for this type, so it cannot be a NaN
            pass
    return False
102
def _clip_dep_is_byte_swapped(a):
    """Return True only for an ndarray whose dtype is not native byte order."""
    return isinstance(a, mu.ndarray) and not a.dtype.isnative
107
def _clip_dep_invoke_with_casting(ufunc, *args, out=None, casting=None, **kwargs):
    """Call `ufunc`, falling back to unsafe casting with a deprecation warning.

    When `casting` is given it is passed straight through.  When it is None
    the ufunc is first called without a casting argument; if that raises
    ``_UFuncOutputCastingError`` a DeprecationWarning is issued and the call
    is repeated with ``casting="unsafe"``.
    """
    if casting is None:
        # try to deal with broken casting rules
        try:
            return ufunc(*args, out=out, **kwargs)
        except _exceptions._UFuncOutputCastingError as e:
            # Numpy 1.17.0, 2019-02-24
            message = (
                "Converting the output of clip from {!r} to {!r} is deprecated. "
                "Pass `casting=\"unsafe\"` explicitly to silence this warning, or "
                "correct the type of the variables.".format(e.from_, e.to)
            )
            warnings.warn(message, DeprecationWarning, stacklevel=2)
            return ufunc(*args, out=out, casting="unsafe", **kwargs)

    # normal path: the caller chose the casting rule explicitly
    return ufunc(*args, out=out, casting=casting, **kwargs)
126
def _clip(a, min=None, max=None, out=None, *, casting=None, **kwargs):
    """Limit the values of `a` to the interval given by `min` and `max`.

    A missing bound (None) selects ``um.minimum`` or ``um.maximum`` with the
    remaining bound; with both bounds ``um.clip`` is used.  The chosen ufunc
    is invoked through `_clip_dep_invoke_with_casting`.

    Raises
    ------
    ValueError
        If both `min` and `max` are None.
    """
    if min is None and max is None:
        raise ValueError("One of max or min must be given")

    # Numpy 1.17.0, 2019-02-24
    # This deprecation probably incurs a substantial slowdown for small arrays,
    # it will be good to get rid of it.
    if not (_clip_dep_is_byte_swapped(a) or _clip_dep_is_byte_swapped(out)):
        # A scalar nan bound is replaced by the matching infinity.
        replaced_nan_bound = False
        if _clip_dep_is_scalar_nan(min):
            min = -float('inf')
            replaced_nan_bound = True
        if _clip_dep_is_scalar_nan(max):
            max = float('inf')
            replaced_nan_bound = True
        if replaced_nan_bound:
            warnings.warn(
                "Passing `np.nan` to mean no clipping in np.clip has always "
                "been unreliable, and is now deprecated. "
                "In future, this will always return nan, like it already does "
                "when min or max are arrays that contain nan. "
                "To skip a bound, pass either None or an np.inf of an "
                "appropriate sign.",
                DeprecationWarning,
                stacklevel=2
            )

    # Pick the ufunc and its operands according to which bounds are present.
    if min is None:
        ufunc, operands = um.minimum, (a, max)
    elif max is None:
        ufunc, operands = um.maximum, (a, min)
    else:
        ufunc, operands = um.clip, (a, min, max)
    return _clip_dep_invoke_with_casting(
        ufunc, *operands, out=out, casting=casting, **kwargs)
163
def _mean(a, axis=None, dtype=None, out=None, keepdims=False, *, where=True):
    """Arithmetic mean of `a` over `axis`: the sum divided by the item count.

    When `dtype` is None, integer and boolean input is accumulated as 'f8'
    and float16 input as 'f4'; in the float16 case the result is converted
    back with the input's scalar type (for array results only when `out` is
    None).  A RuntimeWarning is issued when a reduced slice has no items.
    """
    arr = asanyarray(a)

    rcount = _count_reduce_items(arr, axis, keepdims=keepdims, where=where)
    if where is True:
        has_empty_slice = rcount == 0
    else:
        has_empty_slice = umr_any(rcount == 0, axis=None)
    if has_empty_slice:
        warnings.warn("Mean of empty slice.", RuntimeWarning, stacklevel=2)

    # Cast bool, unsigned int, and int to float64 by default
    is_float16_result = False
    if dtype is None:
        scalar_type = arr.dtype.type
        if issubclass(scalar_type, (nt.integer, nt.bool_)):
            dtype = mu.dtype('f8')
        elif issubclass(scalar_type, nt.float16):
            dtype = mu.dtype('f4')
            is_float16_result = True

    ret = umr_sum(arr, axis, dtype, out, keepdims, where=where)

    if isinstance(ret, mu.ndarray):
        # array result: divide in place
        with _no_nep50_warning():
            ret = um.true_divide(
                ret, rcount, out=ret, casting='unsafe', subok=False)
        if is_float16_result and out is None:
            ret = arr.dtype.type(ret)
        return ret

    if hasattr(ret, 'dtype'):
        # scalar with a dtype: keep a scalar type after the division
        if is_float16_result:
            return arr.dtype.type(ret / rcount)
        return ret.dtype.type(ret / rcount)

    return ret / rcount
197
def _var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False, *,
         where=True):
    """Variance of `a` over `axis`.

    The mean is computed first (sum divided by the item count), then the
    squared magnitudes of the deviations from that mean are summed and
    divided by ``max(count - ddof, 0)``.

    When `dtype` is None, integer and boolean input is accumulated as 'f8'.
    `out`, `keepdims` and `where` are forwarded to the final sum; `where`
    is also used for the item count and the mean.  A RuntimeWarning is
    issued when ``ddof >= count`` for the reduction (for any slice when a
    `where` mask is given).
    """
    arr = asanyarray(a)

    rcount = _count_reduce_items(arr, axis, keepdims=keepdims, where=where)
    # Make this warning show up on top.
    # With the default `where` the count is compared directly; with a mask
    # the comparison result is reduced with any().
    if ddof >= rcount if where is True else umr_any(ddof >= rcount, axis=None):
        warnings.warn("Degrees of freedom <= 0 for slice", RuntimeWarning,
                      stacklevel=2)

    # Cast bool, unsigned int, and int to float64 by default
    if dtype is None and issubclass(arr.dtype.type, (nt.integer, nt.bool_)):
        dtype = mu.dtype('f8')

    # Compute the mean.
    # Note that if dtype is not of inexact type then arraymean will
    # not be either.
    # keepdims=True here so the mean can be subtracted from arr below.
    arrmean = umr_sum(arr, axis, dtype, keepdims=True, where=where)
    # The shape of rcount has to match arrmean to not change the shape of out
    # in broadcasting. Otherwise, it cannot be stored back to arrmean.
    if rcount.ndim == 0:
        # fast-path for default case when where is True
        div = rcount
    else:
        # matching rcount to arrmean when where is specified as array
        div = rcount.reshape(arrmean.shape)
    if isinstance(arrmean, mu.ndarray):
        # array result: divide in place, reusing arrmean as the output
        with _no_nep50_warning():
            arrmean = um.true_divide(arrmean, div, out=arrmean,
                                     casting='unsafe', subok=False)
    elif hasattr(arrmean, "dtype"):
        # scalar with a dtype: convert the quotient back to its scalar type
        arrmean = arrmean.dtype.type(arrmean / rcount)
    else:
        arrmean = arrmean / rcount

    # Compute sum of squared deviations from mean
    # Note that x may not be inexact and that we need it to be an array,
    # not a scalar.
    x = asanyarray(arr - arrmean)

    if issubclass(arr.dtype.type, (nt.floating, nt.integer)):
        # real input: square the deviations in place
        x = um.multiply(x, x, out=x)
    # Fast-paths for built-in complex types
    elif x.dtype in _complex_to_float:
        # View each complex element as a pair of floats, square both parts
        # in place, then add them, writing the result through x.real.
        xv = x.view(dtype=(_complex_to_float[x.dtype], (2,)))
        um.multiply(xv, xv, out=xv)
        x = um.add(xv[..., 0], xv[..., 1], out=x.real).real
    # Most general case; includes handling object arrays containing imaginary
    # numbers and complex types with non-native byteorder
    else:
        x = um.multiply(x, um.conjugate(x), out=x).real

    ret = umr_sum(x, axis, dtype, out, keepdims=keepdims, where=where)

    # Compute degrees of freedom and make sure it is not negative.
    rcount = um.maximum(rcount - ddof, 0)

    # divide by degrees of freedom
    if isinstance(ret, mu.ndarray):
        # array result: divide in place, reusing ret as the output
        with _no_nep50_warning():
            ret = um.true_divide(
                ret, rcount, out=ret, casting='unsafe', subok=False)
    elif hasattr(ret, 'dtype'):
        # scalar with a dtype: convert the quotient back to its scalar type
        ret = ret.dtype.type(ret / rcount)
    else:
        ret = ret / rcount

    return ret
266
def _std(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False, *,
         where=True):
    """Standard deviation of `a` over `axis`: the square root of `_var`.

    All arguments are forwarded to `_var` by keyword.  An array result is
    square-rooted in place; a scalar result that has a ``dtype`` is converted
    back to that dtype's scalar type after the square root.
    """
    variance = _var(a, axis=axis, dtype=dtype, out=out, ddof=ddof,
                    keepdims=keepdims, where=where)

    if isinstance(variance, mu.ndarray):
        return um.sqrt(variance, out=variance)
    if hasattr(variance, 'dtype'):
        return variance.dtype.type(um.sqrt(variance))
    return um.sqrt(variance)
280
def _ptp(a, axis=None, out=None, keepdims=False):
    """Peak-to-peak range of `a` over `axis`: maximum minus minimum.

    `out` is passed to the maximum reduction and again as the output of the
    subtraction; the minimum reduction is called with None as its output.
    """
    highest = umr_maximum(a, axis, None, out, keepdims)
    lowest = umr_minimum(a, axis, None, None, keepdims)
    return um.subtract(highest, lowest, out)
287
def _dump(self, file, protocol=2):
    """Pickle `self` to `file` using the given pickle `protocol`.

    An object with a ``write`` attribute is written to as-is and is not
    closed here.  Anything else is converted with ``os_fspath``, opened in
    ``"wb"`` mode, and closed once the pickle has been written.
    """
    is_path = not hasattr(file, 'write')
    target = open(os_fspath(file), "wb") if is_path else nullcontext(file)
    with target as stream:
        pickle.dump(self, stream, protocol=protocol)
295
def _dumps(self, protocol=2):
    """Return the pickle of `self`, serialized with the given `protocol`."""
    payload = pickle.dumps(self, protocol=protocol)
    return payload