Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/scipy/fft/_pocketfft/helper.py: 20%

107 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-12-12 06:31 +0000

1from numbers import Number 

2import operator 

3import os 

4import threading 

5import contextlib 

6 

7import numpy as np 

8# good_size is exposed (and used) from this import 

9from .pypocketfft import good_size 

10 

11_config = threading.local() 

12_cpu_count = os.cpu_count() 

13 

14 

15def _iterable_of_int(x, name=None): 

16 """Convert ``x`` to an iterable sequence of int 

17 

18 Parameters 

19 ---------- 

20 x : value, or sequence of values, convertible to int 

21 name : str, optional 

22 Name of the argument being converted, only used in the error message 

23 

24 Returns 

25 ------- 

26 y : ``List[int]`` 

27 """ 

28 if isinstance(x, Number): 

29 x = (x,) 

30 

31 try: 

32 x = [operator.index(a) for a in x] 

33 except TypeError as e: 

34 name = name or "value" 

35 raise ValueError("{} must be a scalar or iterable of integers" 

36 .format(name)) from e 

37 

38 return x 

39 

40 

41def _init_nd_shape_and_axes(x, shape, axes): 

42 """Handles shape and axes arguments for nd transforms""" 

43 noshape = shape is None 

44 noaxes = axes is None 

45 

46 if not noaxes: 

47 axes = _iterable_of_int(axes, 'axes') 

48 axes = [a + x.ndim if a < 0 else a for a in axes] 

49 

50 if any(a >= x.ndim or a < 0 for a in axes): 

51 raise ValueError("axes exceeds dimensionality of input") 

52 if len(set(axes)) != len(axes): 

53 raise ValueError("all axes must be unique") 

54 

55 if not noshape: 

56 shape = _iterable_of_int(shape, 'shape') 

57 

58 if axes and len(axes) != len(shape): 

59 raise ValueError("when given, axes and shape arguments" 

60 " have to be of the same length") 

61 if noaxes: 

62 if len(shape) > x.ndim: 

63 raise ValueError("shape requires more axes than are present") 

64 axes = range(x.ndim - len(shape), x.ndim) 

65 

66 shape = [x.shape[a] if s == -1 else s for s, a in zip(shape, axes)] 

67 elif noaxes: 

68 shape = list(x.shape) 

69 axes = range(x.ndim) 

70 else: 

71 shape = [x.shape[a] for a in axes] 

72 

73 if any(s < 1 for s in shape): 

74 raise ValueError( 

75 "invalid number of data points ({0}) specified".format(shape)) 

76 

77 return shape, axes 

78 

79 

80def _asfarray(x): 

81 """ 

82 Convert to array with floating or complex dtype. 

83 

84 float16 values are also promoted to float32. 

85 """ 

86 if not hasattr(x, "dtype"): 

87 x = np.asarray(x) 

88 

89 if x.dtype == np.float16: 

90 return np.asarray(x, np.float32) 

91 elif x.dtype.kind not in 'fc': 

92 return np.asarray(x, np.float64) 

93 

94 # Require native byte order 

95 dtype = x.dtype.newbyteorder('=') 

96 # Always align input 

97 copy = not x.flags['ALIGNED'] 

98 return np.array(x, dtype=dtype, copy=copy) 

99 

100def _datacopied(arr, original): 

101 """ 

102 Strict check for `arr` not sharing any data with `original`, 

103 under the assumption that arr = asarray(original) 

104 """ 

105 if arr is original: 

106 return False 

107 if not isinstance(original, np.ndarray) and hasattr(original, '__array__'): 

108 return False 

109 return arr.base is None 

110 

111 

112def _fix_shape(x, shape, axes): 

113 """Internal auxiliary function for _raw_fft, _raw_fftnd.""" 

114 must_copy = False 

115 

116 # Build an nd slice with the dimensions to be read from x 

117 index = [slice(None)]*x.ndim 

118 for n, ax in zip(shape, axes): 

119 if x.shape[ax] >= n: 

120 index[ax] = slice(0, n) 

121 else: 

122 index[ax] = slice(0, x.shape[ax]) 

123 must_copy = True 

124 

125 index = tuple(index) 

126 

127 if not must_copy: 

128 return x[index], False 

129 

130 s = list(x.shape) 

131 for n, axis in zip(shape, axes): 

132 s[axis] = n 

133 

134 z = np.zeros(s, x.dtype) 

135 z[index] = x[index] 

136 return z, True 

137 

138 

139def _fix_shape_1d(x, n, axis): 

140 if n < 1: 

141 raise ValueError( 

142 "invalid number of data points ({0}) specified".format(n)) 

143 

144 return _fix_shape(x, (n,), (axis,)) 

145 

146 

147_NORM_MAP = {None: 0, 'backward': 0, 'ortho': 1, 'forward': 2} 

148 

149 

150def _normalization(norm, forward): 

151 """Returns the pypocketfft normalization mode from the norm argument""" 

152 try: 

153 inorm = _NORM_MAP[norm] 

154 return inorm if forward else (2 - inorm) 

155 except KeyError: 

156 raise ValueError( 

157 f'Invalid norm value {norm!r}, should ' 

158 'be "backward", "ortho" or "forward"') from None 

159 

160 

161def _workers(workers): 

162 if workers is None: 

163 return getattr(_config, 'default_workers', 1) 

164 

165 if workers < 0: 

166 if workers >= -_cpu_count: 

167 workers += 1 + _cpu_count 

168 else: 

169 raise ValueError("workers value out of range; got {}, must not be" 

170 " less than {}".format(workers, -_cpu_count)) 

171 elif workers == 0: 

172 raise ValueError("workers must not be zero") 

173 

174 return workers 

175 

176 

@contextlib.contextmanager
def set_workers(workers):
    """Context manager for the default number of workers used in `scipy.fft`

    The value in effect on entry is put back on exit, including when the
    body raises.

    Parameters
    ----------
    workers : int
        The default number of workers to use

    Examples
    --------
    >>> import numpy as np
    >>> from scipy import fft, signal
    >>> rng = np.random.default_rng()
    >>> x = rng.standard_normal((128, 64))
    >>> with fft.set_workers(4):
    ...     y = signal.fftconvolve(x, x)

    """
    previous = get_workers()
    # Validate before touching the stored default
    requested = _workers(operator.index(workers))
    _config.default_workers = requested
    try:
        yield
    finally:
        _config.default_workers = previous

202 

203 

def get_workers():
    """Returns the default number of workers within the current context

    This is the ``default_workers`` value stored on ``_config``, or 1 when
    none has been set.

    Examples
    --------
    >>> from scipy import fft
    >>> fft.get_workers()
    1
    >>> with fft.set_workers(4):
    ...     fft.get_workers()
    4
    """
    try:
        return _config.default_workers
    except AttributeError:
        # Nothing has been stored on _config yet
        return 1