1"""
2Array methods which are called by both the C-code for the method
3and the Python code for the NumPy-namespace function
4
5"""
6import os
7import pickle
8import warnings
9from contextlib import nullcontext
10
11import numpy as np
12from numpy._core import multiarray as mu
13from numpy._core import umath as um
14from numpy._core.multiarray import asanyarray
15from numpy._core import numerictypes as nt
16from numpy._core import _exceptions
17from numpy._globals import _NoValue
18
19# save those O(100) nanoseconds!
20bool_dt = mu.dtype("bool")
21umr_maximum = um.maximum.reduce
22umr_minimum = um.minimum.reduce
23umr_sum = um.add.reduce
24umr_prod = um.multiply.reduce
25umr_bitwise_count = um.bitwise_count
26umr_any = um.logical_or.reduce
27umr_all = um.logical_and.reduce
28
29# Complex types to -> (2,)float view for fast-path computation in _var()
30_complex_to_float = {
31 nt.dtype(nt.csingle) : nt.dtype(nt.single),
32 nt.dtype(nt.cdouble) : nt.dtype(nt.double),
33}
34# Special case for windows: ensure double takes precedence
35if nt.dtype(nt.longdouble) != nt.dtype(nt.double):
36 _complex_to_float.update({
37 nt.dtype(nt.clongdouble) : nt.dtype(nt.longdouble),
38 })
39
# Avoid keyword arguments in the reduction calls below: skipping keyword
# parsing saves about 15%-20% for very small reductions.
def _amax(a, axis=None, out=None, keepdims=False,
          initial=_NoValue, where=True):
    """Return the maximum of `a` over `axis` using ``maximum.reduce``.

    `axis`, `out`, `keepdims`, `initial` and `where` are handed to the
    reduction unchanged.
    """
    # All arguments are positional on purpose (keyword parsing is slower).
    # The third positional slot of a reduction is ``dtype``, which this
    # function does not expose, hence the explicit None.
    no_dtype = None
    return umr_maximum(a, axis, no_dtype, out, keepdims, initial, where)
45
def _amin(a, axis=None, out=None, keepdims=False,
          initial=_NoValue, where=True):
    """Return the minimum of `a` over `axis` using ``minimum.reduce``.

    `axis`, `out`, `keepdims`, `initial` and `where` are handed to the
    reduction unchanged.
    """
    # All arguments are positional on purpose (keyword parsing is slower).
    # The third positional slot of a reduction is ``dtype``, which this
    # function does not expose, hence the explicit None.
    no_dtype = None
    return umr_minimum(a, axis, no_dtype, out, keepdims, initial, where)
49
def _sum(a, axis=None, dtype=None, out=None, keepdims=False,
         initial=_NoValue, where=True):
    """Return the sum of `a` over `axis` using ``add.reduce``.

    Every parameter is handed to the reduction unchanged and in the same
    order as it appears in this signature.
    """
    # Positional call on purpose: keyword parsing is slower.
    total = umr_sum(a, axis, dtype, out, keepdims, initial, where)
    return total
53
def _prod(a, axis=None, dtype=None, out=None, keepdims=False,
          initial=_NoValue, where=True):
    """Return the product of `a` over `axis` using ``multiply.reduce``.

    Every parameter is handed to the reduction unchanged and in the same
    order as it appears in this signature.
    """
    # Positional call on purpose: keyword parsing is slower.
    product = umr_prod(a, axis, dtype, out, keepdims, initial, where)
    return product
57
def _any(a, axis=None, dtype=None, out=None, keepdims=False, *, where=True):
    """Reduce `a` over `axis` with ``logical_or``.

    When `dtype` is None the result dtype is boolean (``bool_dt``);
    otherwise the requested `dtype` is used.  `where` is keyword-only and
    is only forwarded when it is not the literal ``True``.
    """
    result_dtype = bool_dt if dtype is None else dtype
    if where is not True:
        # An explicit mask has to be passed by keyword.
        return umr_any(a, axis, result_dtype, out, keepdims, where=where)
    # Default mask: stay fully positional, since parsing keyword arguments
    # is currently fairly slow.
    return umr_any(a, axis, result_dtype, out, keepdims)
66
def _all(a, axis=None, dtype=None, out=None, keepdims=False, *, where=True):
    """Reduce `a` over `axis` with ``logical_and``.

    When `dtype` is None the result dtype is boolean (``bool_dt``);
    otherwise the requested `dtype` is used.  `where` is keyword-only and
    is only forwarded when it is not the literal ``True``.
    """
    result_dtype = bool_dt if dtype is None else dtype
    if where is not True:
        # An explicit mask has to be passed by keyword.
        return umr_all(a, axis, result_dtype, out, keepdims, where=where)
    # Default mask: stay fully positional, since parsing keyword arguments
    # is currently fairly slow.
    return umr_all(a, axis, result_dtype, out, keepdims)
75
76def _count_reduce_items(arr, axis, keepdims=False, where=True):
77 # fast-path for the default case
78 if where is True:
79 # no boolean mask given, calculate items according to axis
80 if axis is None:
81 axis = tuple(range(arr.ndim))
82 elif not isinstance(axis, tuple):
83 axis = (axis,)
84 items = 1
85 for ax in axis:
86 items *= arr.shape[mu.normalize_axis_index(ax, arr.ndim)]
87 items = nt.intp(items)
88 else:
89 # TODO: Optimize case when `where` is broadcast along a non-reduction
90 # axis and full sum is more excessive than needed.
91
92 # guarded to protect circular imports
93 from numpy.lib._stride_tricks_impl import broadcast_to
94 # count True values in (potentially broadcasted) boolean mask
95 items = umr_sum(broadcast_to(where, arr.shape), axis, nt.intp, None,
96 keepdims)
97 return items
98
99def _clip(a, min=None, max=None, out=None, **kwargs):
100 if a.dtype.kind in "iu":
101 # If min/max is a Python integer, deal with out-of-bound values here.
102 # (This enforces NEP 50 rules as no value based promotion is done.)
103 if type(min) is int and min <= np.iinfo(a.dtype).min:
104 min = None
105 if type(max) is int and max >= np.iinfo(a.dtype).max:
106 max = None
107
108 if min is None and max is None:
109 # return identity
110 return um.positive(a, out=out, **kwargs)
111 elif min is None:
112 return um.minimum(a, max, out=out, **kwargs)
113 elif max is None:
114 return um.maximum(a, min, out=out, **kwargs)
115 else:
116 return um.clip(a, min, max, out=out, **kwargs)
117
def _mean(a, axis=None, dtype=None, out=None, keepdims=False, *, where=True):
    """Return the arithmetic mean of `a` over `axis`.

    The mean is the sum reduction divided by the number of reduced items.
    When `dtype` is None, integer and boolean input is accumulated in
    float64, and float16 input is accumulated in float32 and converted back
    to the input's scalar type afterwards (skipped for an array result when
    `out` was supplied).

    Emits ``RuntimeWarning("Mean of empty slice.")`` when any reduced slice
    contains no items.
    """
    arr = asanyarray(a)

    count = _count_reduce_items(arr, axis, keepdims=keepdims, where=where)
    # Without a mask the count is a scalar; with one it can be an array, and
    # then a single empty slice is enough to warn.
    if where is True:
        has_empty_slice = count == 0
    else:
        has_empty_slice = umr_any(count == 0, axis=None)
    if has_empty_slice:
        warnings.warn("Mean of empty slice.", RuntimeWarning, stacklevel=2)

    # Cast bool, unsigned int, and int to float64 by default
    downcast_to_half = False
    if dtype is None:
        if issubclass(arr.dtype.type, (nt.integer, nt.bool)):
            dtype = mu.dtype('f8')
        elif issubclass(arr.dtype.type, nt.float16):
            # Accumulate in float32 and convert back once the mean is known.
            dtype = mu.dtype('f4')
            downcast_to_half = True

    total = umr_sum(arr, axis, dtype, out, keepdims, where=where)

    if isinstance(total, mu.ndarray):
        # Divide in place so that a caller-supplied `out` receives the mean.
        result = um.true_divide(
            total, count, out=total, casting='unsafe', subok=False)
        if downcast_to_half and out is None:
            result = arr.dtype.type(result)
        return result

    if not hasattr(total, 'dtype'):
        # The reduction returned something without a dtype attribute.
        return total / count

    # Scalar result: keep the sum's scalar type, except for the float16 case
    # where the input's own scalar type is restored.
    scalar_type = arr.dtype.type if downcast_to_half else total.dtype.type
    return scalar_type(total / count)
150
def _var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False, *,
         where=True, mean=None):
    """Return the variance of `a` over `axis`.

    The result is the sum of squared absolute deviations from the mean,
    divided by ``max(N - ddof, 0)`` where ``N`` is the number of reduced
    items.  When `dtype` is None, integer and boolean input is accumulated
    in float64.  If `mean` is given it is used as the mean instead of
    computing one (presumably it has to broadcast against `a` -- confirm
    with the caller).

    Emits ``RuntimeWarning("Degrees of freedom <= 0 for slice")`` when
    ``ddof >= N`` for any reduced slice.
    """
    arr = asanyarray(a)

    count = _count_reduce_items(arr, axis, keepdims=keepdims, where=where)
    # Make this warning show up on top.
    if where is True:
        dof_exhausted = ddof >= count
    else:
        dof_exhausted = umr_any(ddof >= count, axis=None)
    if dof_exhausted:
        warnings.warn("Degrees of freedom <= 0 for slice", RuntimeWarning,
                      stacklevel=2)

    # Cast bool, unsigned int, and int to float64 by default
    if dtype is None and issubclass(arr.dtype.type, (nt.integer, nt.bool)):
        dtype = mu.dtype('f8')

    if mean is None:
        # Compute the mean.
        # Note that if dtype is not of inexact type then the mean will
        # not be either.
        center = umr_sum(arr, axis, dtype, keepdims=True, where=where)
        # The shape of the count has to match the mean to not change the
        # shape of out in broadcasting. Otherwise, it cannot be stored back
        # into the mean.
        if count.ndim == 0:
            # fast-path for default case when where is True
            divisor = count
        else:
            # matching the count to the mean when where is given as array
            divisor = count.reshape(center.shape)
        if isinstance(center, mu.ndarray):
            center = um.true_divide(center, divisor, out=center,
                                    casting='unsafe', subok=False)
        elif hasattr(center, "dtype"):
            center = center.dtype.type(center / count)
        else:
            center = center / count
    else:
        center = mean

    # Compute sum of squared deviations from mean
    # Note that the deviations may not be inexact and that we need them to
    # be an array, not a scalar.
    dev = asanyarray(arr - center)

    if issubclass(arr.dtype.type, (nt.floating, nt.integer)):
        # Real input: square in place.
        dev = um.multiply(dev, dev, out=dev)
    # Fast-paths for built-in complex types
    elif dev.dtype in _complex_to_float:
        # View every complex element as a (real, imag) pair of floats, square
        # both in place, then add them into the real part of `dev`.
        pairs = dev.view(dtype=(_complex_to_float[dev.dtype], (2,)))
        um.multiply(pairs, pairs, out=pairs)
        dev = um.add(pairs[..., 0], pairs[..., 1], out=dev.real).real
    # Most general case; includes handling object arrays containing imaginary
    # numbers and complex types with non-native byteorder
    else:
        dev = um.multiply(dev, um.conjugate(dev), out=dev).real

    ret = umr_sum(dev, axis, dtype, out, keepdims=keepdims, where=where)

    # Compute degrees of freedom and make sure it is not negative.
    dof = um.maximum(count - ddof, 0)

    # divide by degrees of freedom
    if isinstance(ret, mu.ndarray):
        # In place, so that a caller-supplied `out` receives the variance.
        return um.true_divide(
            ret, dof, out=ret, casting='unsafe', subok=False)
    if hasattr(ret, 'dtype'):
        return ret.dtype.type(ret / dof)
    return ret / dof
220
def _std(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False, *,
         where=True, mean=None):
    """Return the standard deviation of `a` over `axis`.

    This is the square root of ``_var`` called with the same arguments.
    An array result has the root taken in place; a scalar result that has a
    ``dtype`` is converted back to that dtype's scalar type.
    """
    variance = _var(a, axis=axis, dtype=dtype, out=out, ddof=ddof,
                    keepdims=keepdims, where=where, mean=mean)

    if isinstance(variance, mu.ndarray):
        # Reuse the variance buffer, which is `out` when the caller gave one.
        return um.sqrt(variance, out=variance)
    if hasattr(variance, 'dtype'):
        return variance.dtype.type(um.sqrt(variance))
    return um.sqrt(variance)
234
def _ptp(a, axis=None, out=None, keepdims=False):
    """Return the peak-to-peak range (maximum minus minimum) of `a`.

    `out`, when given, is first used as the output of the maximum reduction
    and then as the output of the subtraction.
    """
    # The maximum is evaluated first and lands in `out`; the minimum always
    # goes to a fresh result (its `out` slot is None).
    highest = umr_maximum(a, axis, None, out, keepdims)
    lowest = umr_minimum(a, axis, None, None, keepdims)
    return um.subtract(highest, lowest, out)
241
def _dump(self, file, protocol=2):
    """Pickle `self` to `file` using the given pickle `protocol`.

    `file` is either an object with a ``write`` attribute, which is written
    to and left open, or a path (anything ``os.fspath`` accepts), which is
    opened in binary write mode and closed again afterwards.
    """
    if hasattr(file, 'write'):
        # Caller-owned handle: nullcontext hands it back without closing it.
        target = nullcontext(file)
    else:
        target = open(os.fspath(file), "wb")

    with target as stream:
        pickle.dump(self, stream, protocol=protocol)
249
def _dumps(self, protocol=2):
    """Return the pickle of `self`, serialised with `protocol`, as bytes."""
    payload = pickle.dumps(self, protocol=protocol)
    return payload
252
253def _bitwise_count(a, out=None, *, where=True, casting='same_kind',
254 order='K', dtype=None, subok=True):
255 return umr_bitwise_count(a, out, where=where, casting=casting,
256 order=order, dtype=dtype, subok=subok)