Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.9/dist-packages/scipy/_lib/_array_api.py: 19%

153 statements  

« prev     ^ index     » next       coverage.py v7.3.1, created at 2023-09-23 06:43 +0000

1"""Utility functions to use Python Array API compatible libraries. 

2 

3For the context about the Array API see: 

4https://data-apis.org/array-api/latest/purpose_and_scope.html 

5 

6The SciPy use case of the Array API is described on the following page: 

7https://data-apis.org/array-api/latest/use_cases.html#use-case-scipy 

8""" 

9from __future__ import annotations 

10 

11import os 

12import warnings 

13 

14import numpy as np 

15from numpy.testing import assert_ 

16import scipy._lib.array_api_compat.array_api_compat as array_api_compat 

17from scipy._lib.array_api_compat.array_api_compat import size 

18import scipy._lib.array_api_compat.array_api_compat.numpy as array_api_compat_numpy 

19 

__all__ = ['array_namespace', 'as_xparray', 'size']


# To enable array API and strict array-like input validation.
# NOTE: when the variable is set, its raw string value is stored, so any
# non-empty value (including "0") counts as "enabled" in truthiness checks.
SCIPY_ARRAY_API: str | bool = os.environ.get("SCIPY_ARRAY_API", False)
# To control the default device - for use in the test suite only.
SCIPY_DEVICE = os.environ.get("SCIPY_DEVICE", "cpu")

# Runtime configuration. `array_namespace` reads the "SCIPY_ARRAY_API" entry
# on every call, so the switch can be flipped after import.
_GLOBAL_CONFIG = dict(
    SCIPY_ARRAY_API=SCIPY_ARRAY_API,
    SCIPY_DEVICE=SCIPY_DEVICE,
)

32 

33 

def compliance_scipy(arrays):
    """Raise exceptions on known-bad subclasses.

    The following inputs are not supported and raise a ``TypeError``:

    - `numpy.ma.MaskedArray`
    - `numpy.matrix`
    - Any array-like which is not Array API compatible or coercible by numpy
    - object arrays

    Parameters
    ----------
    arrays : list
        Candidate arrays. The list is modified in place: every element that
        is not an Array API object is replaced by its ``np.asanyarray``
        coercion, so it must be a mutable sequence.

    Returns
    -------
    arrays : list
        The same list object that was passed in.

    Raises
    ------
    TypeError
        If any element falls into one of the unsupported categories above.
        When NumPy coercion itself fails, the original error is chained as
        ``__cause__``.
    """
    for i, array in enumerate(arrays):
        # Reject the known-bad NumPy subclasses before any other test.
        if isinstance(array, np.ma.MaskedArray):
            raise TypeError("'numpy.ma.MaskedArray' are not supported")
        elif isinstance(array, np.matrix):
            raise TypeError("'numpy.matrix' are not supported")
        elif not array_api_compat.is_array_api_obj(array):
            # Plain array-likes (lists, scalars, ...) are coerced with NumPy.
            try:
                array = np.asanyarray(array)
            except TypeError as exc:
                raise TypeError("Array is not Array API compatible or "
                                "coercible by numpy") from exc
            if array.dtype is np.dtype('O'):
                raise TypeError("An argument was coerced to an object array, "
                                "but object arrays are not supported.")
            # Assigning to the current index does not disturb enumeration.
            arrays[i] = array
        elif array.dtype is np.dtype('O'):
            raise TypeError('object arrays are not supported')
    return arrays

62 

63 

64def _check_finite(array, xp): 

65 """Check for NaNs or Infs.""" 

66 msg = "array must not contain infs or NaNs" 

67 try: 

68 if not xp.all(xp.isfinite(array)): 

69 raise ValueError(msg) 

70 except TypeError: 

71 raise ValueError(msg) 

72 

73 

def array_namespace(*arrays):
    """Return the array API compatible namespace shared by `arrays`.

    Parameters
    ----------
    *arrays : sequence of array_like
        Arrays used to infer the common namespace. ``None`` entries are
        ignored.

    Returns
    -------
    namespace : module
        Common namespace.

    Notes
    -----
    Thin wrapper around `array_api_compat.array_namespace`.

    1. The global switch SCIPY_ARRAY_API is checked first. It is read from
       ``_GLOBAL_CONFIG['SCIPY_ARRAY_API']`` on every call, so it can be
       changed at runtime.
    2. `compliance_scipy` raises exceptions on known-bad subclasses. See its
       definition for more details.

    When the global switch is off, the vendored NumPy compat namespace is
    returned and no compliance check is performed. This is a convenience to
    ease adoption. Otherwise, arrays must comply with the new rules.
    """
    if _GLOBAL_CONFIG["SCIPY_ARRAY_API"]:
        candidates = compliance_scipy(
            [candidate for candidate in arrays if candidate is not None]
        )
        return array_api_compat.array_namespace(*candidates)
    # here we could wrap the namespace if needed
    return array_api_compat_numpy

109 

110 

def as_xparray(
    array, dtype=None, order=None, copy=None, *, xp=None, check_finite=False
):
    """SciPy-specific replacement for `np.asarray` with `order` and `check_finite`.

    Memory layout parameter `order` is not exposed in the Array API standard.
    `order` is only enforced if the input array implementation is NumPy
    based; otherwise `order` is silently ignored.

    `check_finite` is also not a keyword in the array API standard. It is
    included here for convenience, so SciPy functions do not need a separate
    function call.

    Parameters
    ----------
    array : array_like
        Input to convert.
    dtype : dtype, optional
        Requested dtype.
    order : str, optional
        NumPy-only memory layout request, forwarded to ``np.array`` or
        ``np.asarray``.
    copy : bool, optional
        For NumPy-backed namespaces, ``True`` forces a copy (``np.array``).
        Any other value behaves like ``np.asarray``. Other namespaces receive
        `copy` unchanged through ``xp.asarray``.
    xp : namespace, optional
        Target namespace; inferred from `array` when omitted.
    check_finite : bool, optional
        If true, the result is validated with `_check_finite`.
    """
    if xp is None:
        xp = array_namespace(array)

    numpy_backed = xp.__name__ in {
        "numpy", "scipy._lib.array_api_compat.array_api_compat.numpy"
    }
    if numpy_backed:
        # Go through NumPy itself, since only it understands `order`.
        np_builder = np.array if copy is True else np.asarray
        converted = np_builder(array, order=order, dtype=dtype)
        # `converted` is an ndarray; hand it back through the caller's
        # namespace so the container matches the input's namespace.
        array = xp.asarray(converted)
    else:
        try:
            array = xp.asarray(array, dtype=dtype, copy=copy)
        except TypeError:
            fallback_xp = array_namespace(xp.asarray(3))
            array = fallback_xp.asarray(array, dtype=dtype, copy=copy)

    if check_finite:
        _check_finite(array, xp)

    return array

147 

148 

def atleast_nd(x, *, ndim, xp=None):
    """Prepend length-1 axes to `x` until it has at least `ndim` dimensions.

    Inputs that already have `ndim` or more dimensions come back as
    ``xp.asarray(x)`` with no axes added.
    """
    if xp is None:
        xp = array_namespace(x)
    result = xp.asarray(x)
    # Loop form of "expand once, then recurse". It keeps the same
    # asarray/expand_dims call sequence as the recursive version.
    while result.ndim < ndim:
        result = xp.asarray(xp.expand_dims(result, axis=0))
    return result

158 

159 

def copy(x, *, xp=None):
    """
    Return a copy of an array.

    Parameters
    ----------
    x : array
        Array to copy.
    xp : array_namespace, optional
        Namespace of `x`; inferred with `array_namespace` when omitted.

    Returns
    -------
    copy : array
        Copied array.

    Notes
    -----
    This copy function does not offer all the semantics of `np.copy`, i.e. the
    `subok` and `order` keywords are not used.
    """
    # The copy is delegated to `as_xparray` instead of calling
    # ``xp.asarray(x, copy=True)`` directly, because the latter fails when xp
    # is numpy. NOTE(review): presumably because NumPy 1.x ``asarray`` has no
    # ``copy`` keyword -- confirm.
    namespace = array_namespace(x) if xp is None else xp
    return as_xparray(x, copy=True, xp=namespace)

185 

186 

def is_numpy(xp):
    """Return True if `xp` is NumPy or SciPy's vendored NumPy compat namespace.

    Both module names are accepted, matching the NumPy detection in
    `as_xparray`. Callers that pass plain ``numpy`` as `xp` are therefore
    recognized too.
    """
    return xp.__name__ in (
        'numpy',
        'scipy._lib.array_api_compat.array_api_compat.numpy',
    )

189 

190 

def is_cupy(xp):
    """Return True if `xp` is CuPy or SciPy's vendored CuPy compat namespace.

    Both module names are accepted, so plain ``cupy`` passed as `xp` is
    recognized the same way as the compat wrapper.
    """
    return xp.__name__ in (
        'cupy',
        'scipy._lib.array_api_compat.array_api_compat.cupy',
    )

193 

194 

def is_torch(xp):
    """Return True if `xp` is PyTorch or SciPy's vendored torch compat namespace.

    Both module names are accepted, so plain ``torch`` passed as `xp` is
    recognized the same way as the compat wrapper.
    """
    return xp.__name__ in (
        'torch',
        'scipy._lib.array_api_compat.array_api_compat.torch',
    )

197 

198 

199def _strict_check(actual, desired, xp, 

200 check_namespace=True, check_dtype=True, check_shape=True): 

201 if check_namespace: 

202 _assert_matching_namespace(actual, desired) 

203 

204 desired = xp.asarray(desired) 

205 

206 if check_dtype: 

207 assert_(actual.dtype == desired.dtype, 

208 "dtypes do not match.\n" 

209 f"Actual: {actual.dtype}\n" 

210 f"Desired: {desired.dtype}") 

211 

212 if check_shape: 

213 assert_(actual.shape == desired.shape, 

214 "Shapes do not match.\n" 

215 f"Actual: {actual.shape}\n" 

216 f"Desired: {desired.shape}") 

217 _check_scalar(actual, desired, xp) 

218 

219 desired = xp.broadcast_to(desired, actual.shape) 

220 return desired 

221 

222 

def _assert_matching_namespace(actual, desired):
    """Assert that each array in `actual` shares the namespace of `desired`.

    `actual` may be a single array or a tuple of arrays.
    """
    expected_space = array_namespace(desired)
    if not isinstance(actual, tuple):
        actual = (actual,)
    for candidate in actual:
        found_space = array_namespace(candidate)
        assert_(found_space == expected_space,
                "Namespaces do not match.\n"
                f"Actual: {found_space.__name__}\n"
                f"Desired: {expected_space.__name__}")

232 

233 

234def _check_scalar(actual, desired, xp): 

235 # Shape check alone is sufficient unless desired.shape == (). Also, 

236 # only NumPy distinguishes between scalars and arrays. 

237 if desired.shape != () or not is_numpy(xp): 

238 return 

239 # We want to follow the conventions of the `xp` library. Libraries like 

240 # NumPy, for which `np.asarray(0)[()]` returns a scalar, tend to return 

241 # a scalar even when a 0D array might be more appropriate: 

242 # import numpy as np 

243 # np.mean([1, 2, 3]) # scalar, not 0d array 

244 # np.asarray(0)*2 # scalar, not 0d array 

245 # np.sin(np.asarray(0)) # scalar, not 0d array 

246 # Libraries like CuPy, for which `cp.asarray(0)[()]` returns a 0D array, 

247 # tend to return a 0D array in scenarios like those above. 

248 # Therefore, regardless of whether the developer provides a scalar or 0D 

249 # array for `desired`, we would typically want the type of `actual` to be 

250 # the type of `desired[()]`. If the developer wants to override this 

251 # behavior, they can set `check_shape=False`. 

252 desired = desired[()] 

253 assert_((xp.isscalar(actual) and xp.isscalar(desired) 

254 or (not xp.isscalar(actual) and not xp.isscalar(desired))), 

255 "Types do not match:\n" 

256 f"Actual: {type(actual)}\n" 

257 f"Desired: {type(desired)}") 

258 

259 

def xp_assert_equal(actual, desired, check_namespace=True, check_dtype=True,
                    check_shape=True, err_msg='', xp=None):
    """Assert exact equality of `actual` and `desired`.

    `_strict_check` runs first: it checks namespaces, dtypes and shapes, then
    broadcasts `desired`. The comparison goes to CuPy's or PyTorch's testing
    helpers when `xp` is one of those, and to
    ``np.testing.assert_array_equal`` otherwise.
    """
    if xp is None:
        xp = array_namespace(actual)
    desired = _strict_check(actual, desired, xp,
                            check_namespace=check_namespace,
                            check_dtype=check_dtype,
                            check_shape=check_shape)

    if is_cupy(xp):
        return xp.testing.assert_array_equal(actual, desired, err_msg=err_msg)
    if is_torch(xp):
        # PyTorch recommends using `rtol=0, atol=0` like this
        # to test for exact equality
        torch_msg = None if err_msg == '' else err_msg
        return xp.testing.assert_close(actual, desired, rtol=0, atol=0,
                                       equal_nan=True, check_dtype=False,
                                       msg=torch_msg)
    return np.testing.assert_array_equal(actual, desired, err_msg=err_msg)

275 

276 

def xp_assert_close(actual, desired, rtol=1e-07, atol=0, check_namespace=True,
                    check_dtype=True, check_shape=True, err_msg='', xp=None):
    """Assert that `actual` and `desired` agree within `rtol` / `atol`.

    `_strict_check` runs first: it checks namespaces, dtypes and shapes, then
    broadcasts `desired`. The tolerance comparison goes to CuPy's or
    PyTorch's testing helpers when `xp` is one of those, and to
    ``np.testing.assert_allclose`` otherwise.
    """
    if xp is None:
        xp = array_namespace(actual)
    desired = _strict_check(actual, desired, xp,
                            check_namespace=check_namespace,
                            check_dtype=check_dtype,
                            check_shape=check_shape)
    tolerances = dict(rtol=rtol, atol=atol)

    if is_cupy(xp):
        return xp.testing.assert_allclose(actual, desired, err_msg=err_msg,
                                          **tolerances)
    if is_torch(xp):
        # An empty message is passed to PyTorch as None (presumably so that
        # torch keeps its default failure message).
        torch_msg = None if err_msg == '' else err_msg
        return xp.testing.assert_close(actual, desired, equal_nan=True,
                                       check_dtype=False, msg=torch_msg,
                                       **tolerances)
    return np.testing.assert_allclose(actual, desired, err_msg=err_msg,
                                      **tolerances)

292 

293 

def xp_assert_less(actual, desired, check_namespace=True, check_dtype=True,
                   check_shape=True, err_msg='', verbose=True, xp=None):
    """Assert that `actual` is element-wise strictly less than `desired`.

    `_strict_check` runs first. CuPy inputs use CuPy's
    ``assert_array_less``. Everything else, including PyTorch tensors after
    they are moved to the CPU, uses ``np.testing.assert_array_less``.
    """
    if xp is None:
        xp = array_namespace(actual)
    desired = _strict_check(actual, desired, xp,
                            check_namespace=check_namespace,
                            check_dtype=check_dtype,
                            check_shape=check_shape)

    if is_cupy(xp):
        return xp.testing.assert_array_less(actual, desired,
                                            err_msg=err_msg, verbose=verbose)
    if is_torch(xp):
        # Tensors off the CPU are moved over before the NumPy comparison
        # below (presumably because NumPy cannot read device memory).
        actual = actual if actual.device.type == 'cpu' else actual.cpu()
        desired = desired if desired.device.type == 'cpu' else desired.cpu()
    return np.testing.assert_array_less(actual, desired,
                                        err_msg=err_msg, verbose=verbose)

310 

311 

def cov(x, *, xp=None):
    """Estimate the covariance matrix of the rows of `x`.

    `x` is first copied and promoted to at least 2-D. Each row is treated as
    a variable and each column as an observation. The result is normalized by
    ``N - 1``, where ``N`` is the number of columns, and complex inputs use
    the conjugate transpose. Length-1 axes of the result are squeezed away, so
    a single variable yields a 0-d result.

    Warns
    -----
    RuntimeWarning
        If ``N - 1 <= 0``. The divisor then becomes ``0.0``.
    """
    if xp is None:
        xp = array_namespace(x)

    # Work on a copy: the centering step below is done in place.
    data = copy(x, xp=xp)
    work_dtype = xp.result_type(data, xp.float64)
    data = xp.asarray(atleast_nd(data, ndim=2, xp=xp), dtype=work_dtype)

    row_means = xp.mean(data, axis=1)
    divisor = data.shape[1] - 1
    if divisor <= 0:
        warnings.warn("Degrees of freedom <= 0 for slice",
                      RuntimeWarning, stacklevel=2)
        divisor = 0.0

    data -= row_means[:, None]
    data_t = data.T
    if xp.isdtype(data_t.dtype, 'complex floating'):
        data_t = xp.conj(data_t)
    c = data @ data_t
    c /= divisor
    unit_axes = tuple(i for i, n in enumerate(c.shape) if n == 1)
    return xp.squeeze(c, axis=unit_axes)

338 

339 

def xp_unsupported_param_msg(param):
    """Build the error message for a parameter only NumPy inputs may set.

    `param` is embedded using its ``repr``.
    """
    quoted = repr(param)
    return f'Providing {quoted} is only supported for numpy arrays.'