Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.9/dist-packages/scipy/_lib/_array_api.py: 4%

167 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-04-03 06:39 +0000

1"""Utility functions to use Python Array API compatible libraries. 

2 

3For the context about the Array API see: 

4https://data-apis.org/array-api/latest/purpose_and_scope.html 

5 

6The SciPy use case of the Array API is described on the following page: 

7https://data-apis.org/array-api/latest/use_cases.html#use-case-scipy 

8""" 

9from __future__ import annotations 

10 

11import os 

12import warnings 

13 

14import numpy as np 

15 

16from scipy._lib import array_api_compat 

17from scipy._lib.array_api_compat import ( 

18 is_array_api_obj, 

19 size, 

20 numpy as np_compat, 

21) 

22 

# Names exported by ``from <this module> import *``; ``size`` is re-exported
# from ``array_api_compat`` (imported above).
__all__ = ['array_namespace', '_asarray', 'size']

24 

25 

26# To enable array API and strict array-like input validation 

27SCIPY_ARRAY_API: str | bool = os.environ.get("SCIPY_ARRAY_API", False) 

28# To control the default device - for use in the test suite only 

29SCIPY_DEVICE = os.environ.get("SCIPY_DEVICE", "cpu") 

30 

31_GLOBAL_CONFIG = { 

32 "SCIPY_ARRAY_API": SCIPY_ARRAY_API, 

33 "SCIPY_DEVICE": SCIPY_DEVICE, 

34} 

35 

36 

37def compliance_scipy(arrays): 

38 """Raise exceptions on known-bad subclasses. 

39 

40 The following subclasses are not supported and raise and error: 

41 - `numpy.ma.MaskedArray` 

42 - `numpy.matrix` 

43 - NumPy arrays which do not have a boolean or numerical dtype 

44 - Any array-like which is neither array API compatible nor coercible by NumPy 

45 - Any array-like which is coerced by NumPy to an unsupported dtype 

46 """ 

47 for i in range(len(arrays)): 

48 array = arrays[i] 

49 if isinstance(array, np.ma.MaskedArray): 

50 raise TypeError("Inputs of type `numpy.ma.MaskedArray` are not supported.") 

51 elif isinstance(array, np.matrix): 

52 raise TypeError("Inputs of type `numpy.matrix` are not supported.") 

53 if isinstance(array, (np.ndarray, np.generic)): 

54 dtype = array.dtype 

55 if not (np.issubdtype(dtype, np.number) or np.issubdtype(dtype, np.bool_)): 

56 raise TypeError(f"An argument has dtype `{dtype!r}`; " 

57 f"only boolean and numerical dtypes are supported.") 

58 elif not is_array_api_obj(array): 

59 try: 

60 array = np.asanyarray(array) 

61 except TypeError: 

62 raise TypeError("An argument is neither array API compatible nor " 

63 "coercible by NumPy.") 

64 dtype = array.dtype 

65 if not (np.issubdtype(dtype, np.number) or np.issubdtype(dtype, np.bool_)): 

66 message = ( 

67 f"An argument was coerced to an unsupported dtype `{dtype!r}`; " 

68 f"only boolean and numerical dtypes are supported." 

69 ) 

70 raise TypeError(message) 

71 arrays[i] = array 

72 return arrays 

73 

74 

75def _check_finite(array, xp): 

76 """Check for NaNs or Infs.""" 

77 msg = "array must not contain infs or NaNs" 

78 try: 

79 if not xp.all(xp.isfinite(array)): 

80 raise ValueError(msg) 

81 except TypeError: 

82 raise ValueError(msg) 

83 

84 

def array_namespace(*arrays):
    """Get the array API compatible namespace shared by the given arrays.

    Parameters
    ----------
    *arrays : sequence of array_like
        Arrays used to infer the common namespace. ``None`` entries are
        ignored.

    Returns
    -------
    namespace : module
        Common namespace.

    Notes
    -----
    Thin wrapper around `array_api_compat.array_namespace`.

    1. The global switch SCIPY_ARRAY_API is consulted first, via
       ``_GLOBAL_CONFIG['SCIPY_ARRAY_API']`` so it can be changed dynamically.
    2. `compliance_scipy` raises exceptions on known-bad subclasses; see its
       definition for details.

    While the switch is falsy, the compat `numpy` namespace is returned
    unconditionally and no compliance check is done -- a convenience that
    eases adoption. Otherwise, arrays must comply with the new rules.
    """
    if not _GLOBAL_CONFIG["SCIPY_ARRAY_API"]:
        # The namespace could be wrapped here if ever needed.
        return np_compat

    validated = compliance_scipy([a for a in arrays if a is not None])
    return array_api_compat.array_namespace(*validated)

120 

121 

122def _asarray( 

123 array, dtype=None, order=None, copy=None, *, xp=None, check_finite=False 

124): 

125 """SciPy-specific replacement for `np.asarray` with `order` and `check_finite`. 

126 

127 Memory layout parameter `order` is not exposed in the Array API standard. 

128 `order` is only enforced if the input array implementation 

129 is NumPy based, otherwise `order` is just silently ignored. 

130 

131 `check_finite` is also not a keyword in the array API standard; included 

132 here for convenience rather than that having to be a separate function 

133 call inside SciPy functions. 

134 """ 

135 if xp is None: 

136 xp = array_namespace(array) 

137 if xp.__name__ in {"numpy", "scipy._lib.array_api_compat.numpy"}: 

138 # Use NumPy API to support order 

139 if copy is True: 

140 array = np.array(array, order=order, dtype=dtype) 

141 else: 

142 array = np.asarray(array, order=order, dtype=dtype) 

143 

144 # At this point array is a NumPy ndarray. We convert it to an array 

145 # container that is consistent with the input's namespace. 

146 array = xp.asarray(array) 

147 else: 

148 try: 

149 array = xp.asarray(array, dtype=dtype, copy=copy) 

150 except TypeError: 

151 coerced_xp = array_namespace(xp.asarray(3)) 

152 array = coerced_xp.asarray(array, dtype=dtype, copy=copy) 

153 

154 if check_finite: 

155 _check_finite(array, xp) 

156 

157 return array 

158 

159 

def atleast_nd(x, *, ndim, xp=None):
    """Prepend length-1 axes to `x` until it has at least `ndim` dimensions.

    Inputs that already have `ndim` or more dimensions are returned after
    ``xp.asarray`` without reshaping.
    """
    if xp is None:
        xp = array_namespace(x)
    x = xp.asarray(x)
    while x.ndim < ndim:
        x = xp.expand_dims(x, axis=0)
    return x

169 

170 

def copy(x, *, xp=None):
    """
    Return a copy of an array.

    Parameters
    ----------
    x : array
        Array to copy.
    xp : array_namespace, optional
        Namespace of `x`; inferred with `array_namespace` when omitted.

    Returns
    -------
    copy : array
        Copied array

    Notes
    -----
    Unlike `np.copy`, the `subok` and `order` keywords are not supported.
    """
    # Note (kept from the original): xp.asarray fails if xp is numpy, hence
    # the detour through `_asarray`.
    namespace = xp if xp is not None else array_namespace(x)
    return _asarray(x, copy=True, xp=namespace)

196 

197 

def is_numpy(xp):
    """Return True if `xp` is NumPy or SciPy's vendored compat wrapper of it."""
    name = xp.__name__
    return name == 'numpy' or name == 'scipy._lib.array_api_compat.numpy'

200 

201 

def is_cupy(xp):
    """Return True if `xp` is CuPy or SciPy's vendored compat wrapper of it."""
    name = xp.__name__
    return name == 'cupy' or name == 'scipy._lib.array_api_compat.cupy'

204 

205 

def is_torch(xp):
    """Return True if `xp` is PyTorch or SciPy's vendored compat wrapper of it."""
    name = xp.__name__
    return name == 'torch' or name == 'scipy._lib.array_api_compat.torch'

208 

209 

210def _strict_check(actual, desired, xp, 

211 check_namespace=True, check_dtype=True, check_shape=True): 

212 __tracebackhide__ = True # Hide traceback for py.test 

213 if check_namespace: 

214 _assert_matching_namespace(actual, desired) 

215 

216 desired = xp.asarray(desired) 

217 

218 if check_dtype: 

219 _msg = f"dtypes do not match.\nActual: {actual.dtype}\nDesired: {desired.dtype}" 

220 assert actual.dtype == desired.dtype, _msg 

221 

222 if check_shape: 

223 _msg = f"Shapes do not match.\nActual: {actual.shape}\nDesired: {desired.shape}" 

224 assert actual.shape == desired.shape, _msg 

225 _check_scalar(actual, desired, xp) 

226 

227 desired = xp.broadcast_to(desired, actual.shape) 

228 return desired 

229 

230 

def _assert_matching_namespace(actual, desired):
    """Assert that `actual` -- an array or a tuple of arrays -- shares `desired`'s namespace."""
    __tracebackhide__ = True  # Hide traceback for py.test
    candidates = actual if isinstance(actual, tuple) else (actual,)
    expected_space = array_namespace(desired)
    for candidate in candidates:
        found_space = array_namespace(candidate)
        # The message is built eagerly, as before, so it is formatted even
        # when the namespaces match.
        _msg = ("Namespaces do not match.\n"
                f"Actual: {found_space.__name__}\n"
                f"Desired: {expected_space.__name__}")
        assert found_space == expected_space, _msg

241 

242 

243def _check_scalar(actual, desired, xp): 

244 __tracebackhide__ = True # Hide traceback for py.test 

245 # Shape check alone is sufficient unless desired.shape == (). Also, 

246 # only NumPy distinguishes between scalars and arrays. 

247 if desired.shape != () or not is_numpy(xp): 

248 return 

249 # We want to follow the conventions of the `xp` library. Libraries like 

250 # NumPy, for which `np.asarray(0)[()]` returns a scalar, tend to return 

251 # a scalar even when a 0D array might be more appropriate: 

252 # import numpy as np 

253 # np.mean([1, 2, 3]) # scalar, not 0d array 

254 # np.asarray(0)*2 # scalar, not 0d array 

255 # np.sin(np.asarray(0)) # scalar, not 0d array 

256 # Libraries like CuPy, for which `cp.asarray(0)[()]` returns a 0D array, 

257 # tend to return a 0D array in scenarios like those above. 

258 # Therefore, regardless of whether the developer provides a scalar or 0D 

259 # array for `desired`, we would typically want the type of `actual` to be 

260 # the type of `desired[()]`. If the developer wants to override this 

261 # behavior, they can set `check_shape=False`. 

262 desired = desired[()] 

263 _msg = f"Types do not match:\n Actual: {type(actual)}\n Desired: {type(desired)}" 

264 assert (xp.isscalar(actual) and xp.isscalar(desired) 

265 or (not xp.isscalar(actual) and not xp.isscalar(desired))), _msg 

266 

267 

def xp_assert_equal(actual, desired, check_namespace=True, check_dtype=True,
                    check_shape=True, err_msg='', xp=None):
    """Assert that `actual` and `desired` are exactly equal.

    `desired` first goes through `_strict_check` (namespace, dtype and shape
    checks, each switchable) and is broadcast to ``actual.shape``. The
    element-wise comparison is then delegated by backend:

    - CuPy: ``xp.testing.assert_array_equal``;
    - PyTorch: ``xp.testing.assert_close`` with ``rtol=0, atol=0`` and
      ``equal_nan=True``;
    - anything else: ``numpy.testing.assert_array_equal``.

    `xp` defaults to ``array_namespace(actual)``.
    """
    __tracebackhide__ = True  # Hide traceback for py.test
    namespace = array_namespace(actual) if xp is None else xp
    desired = _strict_check(actual, desired, namespace,
                            check_namespace=check_namespace,
                            check_dtype=check_dtype, check_shape=check_shape)

    if is_cupy(namespace):
        return namespace.testing.assert_array_equal(actual, desired, err_msg=err_msg)

    if is_torch(namespace):
        # PyTorch recommends ``rtol=0, atol=0`` to test for exact equality.
        # dtype comparison is left to `_strict_check` (hence check_dtype=False).
        torch_msg = err_msg if err_msg != '' else None
        return namespace.testing.assert_close(actual, desired, rtol=0, atol=0,
                                              equal_nan=True, check_dtype=False,
                                              msg=torch_msg)

    return np.testing.assert_array_equal(actual, desired, err_msg=err_msg)

284 

285 

def xp_assert_close(actual, desired, rtol=1e-07, atol=0, check_namespace=True,
                    check_dtype=True, check_shape=True, err_msg='', xp=None):
    """Assert that `actual` and `desired` agree within `rtol`/`atol`.

    `desired` first goes through `_strict_check` (namespace, dtype and shape
    checks, each switchable) and is broadcast to ``actual.shape``. The
    tolerance comparison is then delegated by backend:

    - CuPy: ``xp.testing.assert_allclose``;
    - PyTorch: ``xp.testing.assert_close`` with ``equal_nan=True``;
    - anything else: ``numpy.testing.assert_allclose``.

    `xp` defaults to ``array_namespace(actual)``.
    """
    __tracebackhide__ = True  # Hide traceback for py.test
    namespace = array_namespace(actual) if xp is None else xp
    desired = _strict_check(actual, desired, namespace,
                            check_namespace=check_namespace,
                            check_dtype=check_dtype, check_shape=check_shape)
    tolerances = dict(rtol=rtol, atol=atol)

    if is_cupy(namespace):
        return namespace.testing.assert_allclose(actual, desired,
                                                 err_msg=err_msg, **tolerances)

    if is_torch(namespace):
        # torch expects ``msg=None`` rather than an empty string.
        torch_msg = err_msg if err_msg != '' else None
        return namespace.testing.assert_close(actual, desired, equal_nan=True,
                                              check_dtype=False, msg=torch_msg,
                                              **tolerances)

    return np.testing.assert_allclose(actual, desired, err_msg=err_msg,
                                      **tolerances)

302 

303 

def xp_assert_less(actual, desired, check_namespace=True, check_dtype=True,
                   check_shape=True, err_msg='', verbose=True, xp=None):
    """Assert that every element of `actual` is strictly less than `desired`.

    `desired` first goes through `_strict_check` (namespace, dtype and shape
    checks, each switchable) and is broadcast to ``actual.shape``. CuPy
    inputs use ``xp.testing.assert_array_less``. Everything else, including
    PyTorch tensors (moved to the CPU first if needed), uses
    ``numpy.testing.assert_array_less``.

    `xp` defaults to ``array_namespace(actual)``.
    """
    __tracebackhide__ = True  # Hide traceback for py.test
    namespace = array_namespace(actual) if xp is None else xp
    desired = _strict_check(actual, desired, namespace,
                            check_namespace=check_namespace,
                            check_dtype=check_dtype, check_shape=check_shape)

    if is_cupy(namespace):
        return namespace.testing.assert_array_less(actual, desired,
                                                   err_msg=err_msg, verbose=verbose)

    if is_torch(namespace):
        # Move any non-CPU tensors to the CPU before handing them to NumPy.
        actual = actual.cpu() if actual.device.type != 'cpu' else actual
        desired = desired.cpu() if desired.device.type != 'cpu' else desired

    return np.testing.assert_array_less(actual, desired,
                                        err_msg=err_msg, verbose=verbose)

321 

322 

def cov(x, *, xp=None):
    """Estimate a covariance matrix, treating rows of `x` as variables.

    `x` is copied, promoted to ``xp.result_type(x, xp.float64)`` and made at
    least 2-D (a 1-D input becomes a single variable). Each row is centred on
    its mean, and the result is ``X @ conj(X).T`` divided by ``N - 1``, where
    ``N`` is the number of columns. If ``N - 1 <= 0``, a ``RuntimeWarning`` is
    emitted and the divisor is ``0.0``. Length-1 axes of the result are
    squeezed away, so a single variable yields a 0-D result.
    """
    if xp is None:
        xp = array_namespace(x)

    # Work on a copy: the in-place centring below must not touch `x`.
    data = copy(x, xp=xp)
    target_dtype = xp.result_type(data, xp.float64)
    data = xp.asarray(atleast_nd(data, ndim=2, xp=xp), dtype=target_dtype)

    row_means = xp.mean(data, axis=1)
    dof = data.shape[1] - 1
    if dof <= 0:
        warnings.warn("Degrees of freedom <= 0 for slice",
                      RuntimeWarning, stacklevel=2)
        dof = 0.0

    data -= row_means[:, None]
    data_t = data.T
    if xp.isdtype(data_t.dtype, 'complex floating'):
        data_t = xp.conj(data_t)
    result = data @ data_t
    result /= dof

    singleton_axes = tuple(i for i, n in enumerate(result.shape) if n == 1)
    return xp.squeeze(result, axis=singleton_axes)

349 

350 

def xp_unsupported_param_msg(param):
    """Build the error message for a parameter that only NumPy inputs support."""
    message = f'Providing {param!r} is only supported for numpy arrays.'
    return message

353 

354 

def is_complex(x, xp):
    """Return whether ``xp.isdtype`` reports ``x.dtype`` as complex floating."""
    kind = 'complex floating'
    return xp.isdtype(x.dtype, kind)