Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/scipy/fft/_pocketfft/helper.py: 20%
107 statements
coverage.py v7.3.2, created at 2023-12-12 06:31 +0000
1from numbers import Number
2import operator
3import os
4import threading
5import contextlib
7import numpy as np
8# good_size is exposed (and used) from this import
9from .pypocketfft import good_size
# Thread-local storage: set_workers/get_workers keep the per-thread default
# worker count in its ``default_workers`` attribute.
_config = threading.local()
# os.cpu_count() returns None when the number of CPUs cannot be determined;
# fall back to 1 so arithmetic on this value (e.g. in ``_workers``) cannot fail.
_cpu_count = os.cpu_count() or 1
def _iterable_of_int(x, name=None):
    """Normalize ``x`` into a list of Python ints.

    A single number is first wrapped into a one-element sequence. Every
    element is then converted with ``operator.index``, so only integer-like
    values are accepted (floats such as ``1.5`` or even ``2.0`` are rejected).

    Parameters
    ----------
    x : value, or sequence of values, convertible to int
    name : str, optional
        Name of the argument being converted, only used in the error message

    Returns
    -------
    y : ``List[int]``

    Raises
    ------
    ValueError
        If ``x`` cannot be iterated or any element is not integer-like.
    """
    items = (x,) if isinstance(x, Number) else x
    try:
        converted = [operator.index(item) for item in items]
    except TypeError as err:
        label = name or "value"
        raise ValueError("{} must be a scalar or iterable of integers"
                         .format(label)) from err
    return converted
def _init_nd_shape_and_axes(x, shape, axes):
    """Handles shape and axes arguments for nd transforms.

    Parameters
    ----------
    x : array
        Input the transform will be applied to; only ``x.ndim`` and
        ``x.shape`` are read.
    shape : int, sequence of int, or None
        Requested length along each transformed axis. An entry of ``-1``
        keeps the input's length along the corresponding axis.
    axes : int, sequence of int, or None
        Axes to transform. Negative values count from the end.

    Returns
    -------
    shape : list of int
        Length along each axis in ``axes``.
    axes : list of int or range
        Non-negative axis indices, in the same order as ``shape``.

    Raises
    ------
    ValueError
        If an axis is out of range or repeated, if ``shape`` and ``axes`` are
        both given with different lengths, if ``shape`` needs more axes than
        ``x`` has, or if any resulting length is less than 1.
    """
    noshape = shape is None
    noaxes = axes is None

    if not noaxes:
        axes = _iterable_of_int(axes, 'axes')
        # Normalize negative axes so they can be range-checked and deduplicated
        axes = [a + x.ndim if a < 0 else a for a in axes]

        if any(a >= x.ndim or a < 0 for a in axes):
            raise ValueError("axes exceeds dimensionality of input")
        if len(set(axes)) != len(axes):
            raise ValueError("all axes must be unique")

    if not noshape:
        shape = _iterable_of_int(shape, 'shape')

        # Test whether axes were *given* rather than whether the list is
        # truthy: an explicitly empty ``axes`` must still match ``shape`` in
        # length, otherwise ``shape`` would be silently dropped by ``zip``.
        if not noaxes and len(axes) != len(shape):
            raise ValueError("when given, axes and shape arguments"
                             " have to be of the same length")
        if noaxes:
            if len(shape) > x.ndim:
                raise ValueError("shape requires more axes than are present")
            # Without explicit axes, ``shape`` applies to the trailing axes
            axes = range(x.ndim - len(shape), x.ndim)

        shape = [x.shape[a] if s == -1 else s for s, a in zip(shape, axes)]
    elif noaxes:
        shape = list(x.shape)
        axes = range(x.ndim)
    else:
        shape = [x.shape[a] for a in axes]

    if any(s < 1 for s in shape):
        raise ValueError(
            "invalid number of data points ({0}) specified".format(shape))

    return shape, axes
def _asfarray(x):
    """
    Convert to array with floating or complex dtype.

    float16 values are also promoted to float32.

    Inputs whose dtype kind is neither floating nor complex are converted
    to float64. Floating/complex inputs keep their dtype but are returned in
    native byte order and with aligned memory. They are copied only when one
    of those requirements forces it.
    """
    if not hasattr(x, "dtype"):
        x = np.asarray(x)

    if x.dtype == np.float16:
        return np.asarray(x, np.float32)
    elif x.dtype.kind not in 'fc':
        return np.asarray(x, np.float64)

    # Require native byte order
    dtype = x.dtype.newbyteorder('=')
    # Always align input: unaligned data is forced into a fresh, aligned copy.
    # Otherwise copy only if the byte-order conversion needs it. np.asarray
    # means "copy if needed" on every NumPy version, whereas
    # np.array(..., copy=False) raises on NumPy >= 2.0 whenever a copy
    # (e.g. a byteswap) cannot be avoided.
    if not x.flags['ALIGNED']:
        return np.array(x, dtype=dtype, copy=True)
    return np.asarray(x, dtype=dtype)
def _datacopied(arr, original):
    """
    Strict check for `arr` not sharing any data with `original`,
    under the assumption that arr = asarray(original)

    Returns False when ``arr`` is ``original`` itself, or when ``original``
    is a non-ndarray object exposing ``__array__`` (its buffer may be handed
    through to ``arr``). Otherwise it returns True only when ``arr`` has no
    ``base`` object, i.e. it is not a view onto another array.
    """
    return (arr is not original
            and (isinstance(original, np.ndarray)
                 or not hasattr(original, '__array__'))
            and arr.base is None)
def _fix_shape(x, shape, axes):
    """Internal auxiliary function for _raw_fft, _raw_fftnd.

    Truncate and/or zero-pad ``x`` so that its length along ``axes[i]``
    becomes ``shape[i]``.

    Returns
    -------
    (ndarray, bool)
        If no axis needs padding, a sliced view of ``x`` and False.
        Otherwise a new zero-filled array of dtype ``x.dtype`` holding the
        overlapping part of ``x``, and True.
    """
    slicer = [slice(None)] * x.ndim
    padded_shape = list(x.shape)
    needs_padding = False
    for length, ax in zip(shape, axes):
        available = x.shape[ax]
        # Read at most what the input actually holds along this axis
        slicer[ax] = slice(0, min(length, available))
        needs_padding = needs_padding or length > available
        padded_shape[ax] = length
    slicer = tuple(slicer)

    if not needs_padding:
        return x[slicer], False

    out = np.zeros(padded_shape, x.dtype)
    out[slicer] = x[slicer]
    return out, True
def _fix_shape_1d(x, n, axis):
    """Truncate or zero-pad ``x`` to length ``n`` along a single ``axis``.

    This is a one-axis wrapper around `_fix_shape` and returns the same
    ``(array, copied)`` pair.

    Raises
    ------
    ValueError
        If ``n`` is less than 1.
    """
    requested_too_small = n < 1
    if requested_too_small:
        raise ValueError(
            "invalid number of data points ({0}) specified".format(n))
    return _fix_shape(x, shape=(n,), axes=(axis,))
_NORM_MAP = {None: 0, 'backward': 0, 'ortho': 1, 'forward': 2}


def _normalization(norm, forward):
    """Translate a ``norm`` argument into a pypocketfft normalization mode.

    Parameters
    ----------
    norm : {None, "backward", "ortho", "forward"}
        Normalization keyword; ``None`` maps to the same mode as "backward".
    forward : bool
        Truthy for a forward transform, falsy for the inverse.

    Returns
    -------
    int
        The mode (0, 1 or 2) looked up in ``_NORM_MAP``; for inverse
        transforms it is mirrored as ``2 - mode``.

    Raises
    ------
    ValueError
        If ``norm`` is not one of the accepted values.
    """
    try:
        mode = _NORM_MAP[norm]
    except KeyError:
        raise ValueError(
            f'Invalid norm value {norm!r}, should '
            'be "backward", "ortho" or "forward"') from None
    return mode if forward else 2 - mode
def _workers(workers):
    """Resolve a ``workers`` argument into the number of workers to use.

    Parameters
    ----------
    workers : int or None
        ``None`` selects the current default, as returned by `get_workers`.
        Negative values count back from the CPU count: ``-1`` means all
        CPUs, ``-2`` all but one, and so on.

    Returns
    -------
    int
        The resolved number of workers (``workers`` itself when positive).

    Raises
    ------
    ValueError
        If ``workers`` is zero, or is negative with a magnitude larger than
        the CPU count.
    """
    if workers is None:
        return get_workers()

    # os.cpu_count() may report None; treat an unknown CPU count as 1 rather
    # than failing with a TypeError on the arithmetic below.
    cpu_count = _cpu_count or 1
    if workers < 0:
        if workers >= -cpu_count:
            workers += 1 + cpu_count
        else:
            raise ValueError("workers value out of range; got {}, must not be"
                             " less than {}".format(workers, -cpu_count))
    elif workers == 0:
        raise ValueError("workers must not be zero")

    return workers
@contextlib.contextmanager
def set_workers(workers):
    """Context manager for the default number of workers used in `scipy.fft`

    Inside the ``with`` block, `get_workers` (and ``_workers(None)``) reports
    the new default for the current thread. The previous default is restored
    on exit, even if the block raises.

    Parameters
    ----------
    workers : int
        The default number of workers to use. It must be integer-like, and it
        is resolved by ``_workers``, so negative values count back from the
        CPU count.

    Examples
    --------
    >>> import numpy as np
    >>> from scipy import fft, signal
    >>> rng = np.random.default_rng()
    >>> x = rng.standard_normal((128, 64))
    >>> with fft.set_workers(4):
    ...     y = signal.fftconvolve(x, x)

    """
    saved = get_workers()
    # Validate before touching shared state, so a bad value changes nothing
    resolved = _workers(operator.index(workers))
    _config.default_workers = resolved
    try:
        yield
    finally:
        _config.default_workers = saved
def get_workers():
    """Returns the default number of workers within the current context

    The value is kept per thread by `set_workers`; outside any
    `set_workers` block it is 1.

    Examples
    --------
    >>> from scipy import fft
    >>> fft.get_workers()
    1
    >>> with fft.set_workers(4):
    ...     fft.get_workers()
    4

    """
    try:
        return _config.default_workers
    except AttributeError:
        # Nothing has been set on this thread yet
        return 1