Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.9/dist-packages/scipy/sparse/_sputils.py: 16%
189 statements
« prev ^ index » next coverage.py v7.3.1, created at 2023-09-23 06:43 +0000
« prev ^ index » next coverage.py v7.3.1, created at 2023-09-23 06:43 +0000
1""" Utility functions for sparse matrix module
2"""
4import sys
5import operator
6import numpy as np
7from math import prod
8import scipy.sparse as sp
# Public API of this helper module; other names are internal to scipy.sparse.
__all__ = ['upcast', 'getdtype', 'getdata', 'isscalarlike', 'isintlike',
           'isshape', 'issequence', 'isdense', 'ismatrix', 'get_sum_dtype']
# Dtypes usable by the sparse containers, ordered from most restrictive to
# most permissive so that `upcast` returns the smallest safe fit.
supported_dtypes = [np.bool_, np.byte, np.ubyte, np.short, np.ushort, np.intc,
                    np.uintc, np.int_, np.uint, np.longlong, np.ulonglong,
                    np.float32, np.float64, np.longdouble,
                    np.complex64, np.complex128, np.clongdouble]

# Cache of previously computed upcasts, keyed by the argument tuple
# (from `upcast`) or by the tuple of dtype chars (from `upcast_char`).
_upcast_memo = {}


def upcast(*args):
    """Returns the nearest supported sparse dtype for the
    combination of one or more types.

    upcast(t0, t1, ..., tn) -> T where T is a supported dtype

    Examples
    --------
    >>> from scipy.sparse._sputils import upcast
    >>> upcast('int32')
    <class 'numpy.int32'>
    >>> upcast('bool')
    <class 'numpy.bool_'>
    >>> upcast('int32','float32')
    <class 'numpy.float64'>
    >>> upcast('bool',complex,float)
    <class 'numpy.complex128'>

    """
    # Key the memo on the args tuple itself rather than hash(args): the
    # arguments must be hashable either way, and a raw hash could collide
    # between distinct tuples and silently return the wrong dtype.
    t = _upcast_memo.get(args)
    if t is not None:
        return t

    # NumPy's common type for the inputs; find the smallest supported
    # dtype it can be safely cast to.  (Renamed from `upcast`, which
    # shadowed this function's own name.)
    common = np.result_type(*args)

    for t in supported_dtypes:
        if np.can_cast(common, t):
            _upcast_memo[args] = t
            return t

    raise TypeError(f'no supported conversion for types: {args!r}')
def upcast_char(*args):
    """Same as `upcast` but taking dtype.char as input (faster)."""
    # The char tuple itself is the memo key, so repeat calls skip the
    # dtype construction entirely.
    cached = _upcast_memo.get(args)
    if cached is None:
        cached = upcast(*map(np.dtype, args))
        _upcast_memo[args] = cached
    return cached
def upcast_scalar(dtype, scalar):
    """Determine data type for binary operation between an array of
    type `dtype` and a scalar.
    """
    # Let NumPy's own promotion rules decide by probing with a tiny array.
    probe = np.array([0], dtype=dtype)
    return (probe * scalar).dtype
def downcast_intp_index(arr):
    """
    Down-cast index array to np.intp dtype if it is of a larger dtype.

    Raise an error if the array contains a value that is too large for
    intp.
    """
    # Smaller-or-equal itemsize: nothing to do, return the array unchanged.
    if arr.dtype.itemsize <= np.dtype(np.intp).itemsize:
        return arr
    if arr.size == 0:
        return arr.astype(np.intp)
    bounds = np.iinfo(np.intp)
    if arr.max() > bounds.max or arr.min() < bounds.min:
        raise ValueError("Cannot deal with arrays with indices larger "
                         "than the machine maximum address size "
                         "(e.g. 64-bit indices on 32-bit machine).")
    return arr.astype(np.intp)
def to_native(A):
    """
    Ensure that the data type of the NumPy array `A` has native byte order.

    `A` must be a NumPy array.  If its dtype already has native byte order
    the array itself is returned (no view is created); otherwise a copy
    with a native-byte-order dtype is returned.
    """
    if A.dtype.isnative:
        # Skip asarray() entirely so the caller gets back the same object.
        return A
    return np.asarray(A, dtype=A.dtype.newbyteorder('native'))
def getdtype(dtype, a=None, default=None):
    """Resolve a dtype for argument processing.

    When `dtype` is given it is converted via np.dtype and object dtype
    is rejected.  When `dtype` is None, `a.dtype` is used if available,
    falling back to `default`; if neither exists a TypeError is raised.
    """
    if dtype is not None:
        newdtype = np.dtype(dtype)
        # Note: the object-dtype rejection applies only to an explicitly
        # requested dtype, matching the original control flow.
        if newdtype == np.object_:
            raise ValueError(
                "object dtype is not supported by sparse matrices"
            )
        return newdtype
    try:
        return a.dtype
    except AttributeError as e:
        if default is not None:
            return np.dtype(default)
        raise TypeError("could not interpret data type") from e
def getdata(obj, dtype=None, copy=False) -> np.ndarray:
    """
    Wrapper of `np.array(obj, dtype=dtype, copy=copy)` that rejects
    object-dtype results (via getdtype's validation).

    NOTE(review): under NumPy 2.0, `copy=False` means "never copy" and
    raises when a copy is required — confirm the intended semantics here.
    """
    arr = np.array(obj, dtype=dtype, copy=copy)
    # Validation only: getdtype raises ValueError for object dtype.
    getdtype(arr.dtype)
    return arr
def get_index_dtype(arrays=(), maxval=None, check_contents=False):
    """
    Based on input (integer) arrays `a`, determine a suitable index data
    type that can hold the data in the arrays.

    Parameters
    ----------
    arrays : tuple of array_like
        Input arrays whose types/contents to check
    maxval : float, optional
        Maximum value needed
    check_contents : bool, optional
        Whether to check the values in the arrays and not just their types.
        Default: False (check only the types)

    Returns
    -------
    dtype : dtype
        Suitable index data type (int32 or int64)

    """
    int32min = np.int32(np.iinfo(np.int32).min)
    int32max = np.int32(np.iinfo(np.int32).max)

    # not using intc directly due to misinteractions with pythran
    dtype = np.int32 if np.intc().itemsize == 4 else np.int64
    if maxval is not None and np.int64(maxval) > int32max:
        dtype = np.int64

    # Allow a bare ndarray in place of a tuple of arrays.
    if isinstance(arrays, np.ndarray):
        arrays = (arrays,)

    for arr in arrays:
        arr = np.asarray(arr)
        if np.can_cast(arr.dtype, np.int32):
            continue
        if check_contents:
            if arr.size == 0:
                # a bigger type not needed
                continue
            if np.issubdtype(arr.dtype, np.integer):
                if arr.min() >= int32min and arr.max() <= int32max:
                    # a bigger type not needed
                    continue
        # Type (or contents) don't fit in int32: settle on int64.
        dtype = np.int64
        break

    return dtype
def get_sum_dtype(dtype):
    """Mimic numpy's casting for np.sum"""
    # Unsigned inputs accumulate into np.uint when that's a safe cast.
    is_unsigned = dtype.kind == 'u'
    if is_unsigned and np.can_cast(dtype, np.uint):
        return np.uint
    # Small signed ints accumulate into the platform int; otherwise keep.
    return np.int_ if np.can_cast(dtype, np.int_) else dtype
def isscalarlike(x) -> bool:
    """Is x either a scalar, an array scalar, or a 0-dim array?"""
    if np.isscalar(x):
        return True
    return isdense(x) and x.ndim == 0
def isintlike(x) -> bool:
    """Is x appropriate as an index into a sparse matrix? Returns True
    if it can be cast safely to a machine int.
    """
    # Fast-path check to eliminate non-scalar values. operator.index would
    # catch this case too, but the exception catching is slow.
    if np.ndim(x) != 0:
        return False
    try:
        operator.index(x)
        return True
    except (TypeError, ValueError):
        pass
    # Not a true integer type; see whether it at least round-trips via int().
    try:
        as_int = int(x)
    except (TypeError, ValueError):
        return False
    if as_int == x:
        # e.g. 3.0 — numerically integral but an inexact type.
        raise ValueError("Inexact indices into sparse matrices are not allowed")
    return False
def isshape(x, nonneg=False, allow_ndim=False) -> bool:
    """Is x a valid tuple of dimensions?

    If nonneg, also checks that the dimensions are non-negative.
    If allow_ndim, shapes of any dimensionality are allowed.
    """
    if not allow_ndim and len(x) != 2:
        return False
    for dim in x:
        if not isintlike(dim):
            return False
        if nonneg and dim < 0:
            return False
    return True
def issequence(t) -> bool:
    """True for a 1-D ndarray, or a list/tuple that is empty or whose
    first element is a scalar."""
    if isinstance(t, (list, tuple)):
        return len(t) == 0 or np.isscalar(t[0])
    return isinstance(t, np.ndarray) and t.ndim == 1
def ismatrix(t) -> bool:
    """True for a 2-D ndarray, or a non-empty list/tuple whose first
    element is itself a sequence."""
    if isinstance(t, (list, tuple)):
        return len(t) > 0 and issequence(t[0])
    return isinstance(t, np.ndarray) and t.ndim == 2
def isdense(x) -> bool:
    """True when x is a plain NumPy ndarray (a dense array)."""
    return isinstance(x, np.ndarray)
def validateaxis(axis) -> None:
    """Raise TypeError/ValueError unless axis is one of {-2, -1, 0, 1, None}."""
    if axis is None:
        return

    # In NumPy, you can pass in tuples for 'axis', but they are
    # not very useful for sparse matrices given their limited
    # dimensions, so let's make it explicit that they are not
    # allowed to be passed in
    if type(axis) is tuple:
        raise TypeError("Tuples are not accepted for the 'axis' parameter. "
                        "Please pass in one of the following: "
                        "{-2, -1, 0, 1, None}.")

    # If not a tuple, check that the provided axis is actually
    # an integer and raise a TypeError similar to NumPy's
    if not np.issubdtype(np.dtype(type(axis)), np.integer):
        raise TypeError(f"axis must be an integer, not {type(axis).__name__}")

    if axis < -2 or axis > 1:
        raise ValueError("axis out of range")
def check_shape(args, current_shape=None):
    """Imitate numpy.matrix handling of shape arguments"""
    if not args:
        raise TypeError("function missing 1 required positional argument: "
                        "'shape'")

    # Accept either check_shape((r, c)) or check_shape(((r, c),)).
    if len(args) == 1:
        try:
            dims = iter(args[0])
        except TypeError:
            new_shape = (operator.index(args[0]),)
        else:
            new_shape = tuple(operator.index(d) for d in dims)
    else:
        new_shape = tuple(operator.index(d) for d in args)

    if current_shape is None:
        if len(new_shape) != 2:
            raise ValueError('shape must be a 2-tuple of positive integers')
        if any(d < 0 for d in new_shape):
            raise ValueError("'shape' elements cannot be negative")
    else:
        # Check the current size only if needed
        current_size = prod(current_shape)

        # Negative entries mark dimensions to be inferred, numpy-style.
        unknown = [i for i, d in enumerate(new_shape) if d < 0]
        if not unknown:
            if prod(new_shape) != current_size:
                raise ValueError(
                    f'cannot reshape array of size {current_size} '
                    f'into shape {new_shape}')
        elif len(unknown) == 1:
            skip = unknown[0]
            specified = prod(new_shape[:skip] + new_shape[skip + 1:])
            inferred, remainder = divmod(current_size, specified)
            if remainder != 0:
                err_shape = tuple('newshape' if d < 0 else d
                                  for d in new_shape)
                raise ValueError(
                    f'cannot reshape array of size {current_size} '
                    f'into shape {err_shape}')
            new_shape = new_shape[:skip] + (inferred,) + new_shape[skip + 1:]
        else:
            raise ValueError('can only specify one unknown dimension')

    if len(new_shape) != 2:
        raise ValueError('matrix shape must be two-dimensional')

    return new_shape
def check_reshape_kwargs(kwargs):
    """Unpack keyword arguments for reshape function.

    Pops 'order' (default 'C') and 'copy' (default False) from `kwargs`
    and raises TypeError if anything else remains.  Note: mutates the
    dict passed in.
    """
    order = kwargs.pop('order', 'C')
    copy = kwargs.pop('copy', False)
    if kwargs:
        # Some unused kwargs remain
        leftover = ', '.join(kwargs.keys())
        raise TypeError('reshape() got unexpected keywords arguments: {}'
                        .format(leftover))
    return order, copy
def is_pydata_spmatrix(m) -> bool:
    """
    Check whether object is pydata/sparse matrix, avoiding importing the module.
    """
    # Only consult sys.modules; never import 'sparse' ourselves.
    sparse_mod = sys.modules.get('sparse')
    base_cls = getattr(sparse_mod, 'SparseArray', None)
    if base_cls is None:
        return False
    return isinstance(m, base_cls)
368###############################################################################
369# Wrappers for NumPy types that are deprecated
371# Numpy versions of these functions raise deprecation warnings, the
372# ones below do not.
def matrix(*args, **kwargs):
    """Build an np.matrix without triggering NumPy's deprecation warning."""
    arr = np.array(*args, **kwargs)
    return arr.view(np.matrix)
def asmatrix(data, dtype=None):
    """Return `data` as an np.matrix, reusing it when possible."""
    if isinstance(data, np.matrix):
        # Same (or unspecified) dtype: hand the matrix straight back.
        if dtype is None or data.dtype == dtype:
            return data
    return np.asarray(data, dtype=dtype).view(np.matrix)
383###############################################################################
386def _todata(s) -> np.ndarray:
387 """Access nonzero values, possibly after summing duplicates.
389 Parameters
390 ----------
391 s : sparse array
392 Input sparse array.
394 Returns
395 -------
396 data: ndarray
397 Nonzero values of the array, with shape (s.nnz,)
399 """
400 if isinstance(s, sp._data._data_matrix):
401 return s._deduped_data()
403 if isinstance(s, sp.dok_array):
404 return np.fromiter(s.values(), dtype=s.dtype, count=s.nnz)
406 if isinstance(s, sp.lil_array):
407 data = np.empty(s.nnz, dtype=s.dtype)
408 sp._csparsetools.lil_flatten_to_array(s.data, data)
409 return data
411 return s.tocoo()._deduped_data()