Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.9/dist-packages/scipy/_lib/array_api_compat/array_api_compat/common/_aliases.py: 25%

211 statements  

« prev     ^ index     » next       coverage.py v7.3.1, created at 2023-09-23 06:43 +0000

1""" 

2These are functions that are just aliases of existing functions in NumPy. 

3""" 

4 

5from __future__ import annotations 

6 

7from typing import TYPE_CHECKING 

8if TYPE_CHECKING: 

9 from typing import Optional, Sequence, Tuple, Union, List 

10 from ._typing import ndarray, Device, Dtype, NestedSequence, SupportsBufferProtocol 

11 

12from typing import NamedTuple 

13from types import ModuleType 

14import inspect 

15 

16from ._helpers import _check_device, _is_numpy_array, array_namespace 

17 

18# These functions are modified from the NumPy versions. 

19 

def arange(
    start: Union[int, float],
    /,
    stop: Optional[Union[int, float]] = None,
    step: Union[int, float] = 1,
    *,
    xp,
    dtype: Optional[Dtype] = None,
    device: Optional[Device] = None,
    **kwargs
) -> ndarray:
    """Array API ``arange`` built on ``xp.arange``.

    ``device`` is handed to ``_check_device`` and is not forwarded; ``start``,
    ``stop``, ``step``, ``dtype`` and any extra keyword arguments go straight
    through to ``xp.arange``.
    """
    _check_device(xp, device)
    return xp.arange(start, stop=stop, step=step, dtype=dtype, **kwargs)

33 

def empty(
    shape: Union[int, Tuple[int, ...]],
    xp,
    *,
    dtype: Optional[Dtype] = None,
    device: Optional[Device] = None,
    **kwargs
) -> ndarray:
    """Array API ``empty``: an uninitialized array of ``shape`` from ``xp``.

    ``device`` is only passed to ``_check_device``; ``dtype`` and extra keyword
    arguments are forwarded to ``xp.empty``.
    """
    _check_device(xp, device)
    return xp.empty(shape, dtype=dtype, **kwargs)

44 

def empty_like(
    x: ndarray,
    /,
    xp,
    *,
    dtype: Optional[Dtype] = None,
    device: Optional[Device] = None,
    **kwargs
) -> ndarray:
    """Array API ``empty_like``: an uninitialized array shaped like ``x``.

    ``device`` is only passed to ``_check_device``; ``dtype`` and extra keyword
    arguments are forwarded to ``xp.empty_like``.
    """
    _check_device(xp, device)
    return xp.empty_like(x, dtype=dtype, **kwargs)

51 

def eye(
    n_rows: int,
    n_cols: Optional[int] = None,
    /,
    *,
    xp,
    k: int = 0,
    dtype: Optional[Dtype] = None,
    device: Optional[Device] = None,
    **kwargs,
) -> ndarray:
    """Array API ``eye`` built on ``xp.eye``.

    The array API's ``n_cols`` is passed to ``xp.eye`` under its NumPy-style
    name ``M``. ``device`` is only passed to ``_check_device``.
    """
    _check_device(xp, device)
    return xp.eye(n_rows, M=n_cols, k=k, dtype=dtype, **kwargs)

65 

def full(
    shape: Union[int, Tuple[int, ...]],
    fill_value: Union[int, float],
    xp,
    *,
    dtype: Optional[Dtype] = None,
    device: Optional[Device] = None,
    **kwargs,
) -> ndarray:
    """Array API ``full``: an array of ``shape`` filled with ``fill_value``.

    ``device`` is only passed to ``_check_device``; everything else is
    forwarded to ``xp.full``.
    """
    _check_device(xp, device)
    return xp.full(shape, fill_value, dtype=dtype, **kwargs)

77 

def full_like(
    x: ndarray,
    /,
    fill_value: Union[int, float],
    *,
    xp,
    dtype: Optional[Dtype] = None,
    device: Optional[Device] = None,
    **kwargs,
) -> ndarray:
    """Array API ``full_like``: an array shaped like ``x`` filled with ``fill_value``.

    ``device`` is only passed to ``_check_device``; everything else is
    forwarded to ``xp.full_like``.
    """
    _check_device(xp, device)
    return xp.full_like(x, fill_value, dtype=dtype, **kwargs)

90 

def linspace(
    start: Union[int, float],
    stop: Union[int, float],
    /,
    num: int,
    *,
    xp,
    dtype: Optional[Dtype] = None,
    device: Optional[Device] = None,
    endpoint: bool = True,
    **kwargs,
) -> ndarray:
    """Array API ``linspace`` built on ``xp.linspace``.

    ``num`` and ``endpoint`` are forwarded unchanged; ``device`` is only
    passed to ``_check_device``.
    """
    _check_device(xp, device)
    return xp.linspace(start, stop, num, dtype=dtype, endpoint=endpoint, **kwargs)

105 

def ones(
    shape: Union[int, Tuple[int, ...]],
    xp,
    *,
    dtype: Optional[Dtype] = None,
    device: Optional[Device] = None,
    **kwargs,
) -> ndarray:
    """Array API ``ones``: an array of ``shape`` filled with ones.

    ``device`` is only passed to ``_check_device``; ``dtype`` and extra keyword
    arguments are forwarded to ``xp.ones``.
    """
    _check_device(xp, device)
    return xp.ones(shape, dtype=dtype, **kwargs)

116 

def ones_like(
    x: ndarray,
    /,
    xp,
    *,
    dtype: Optional[Dtype] = None,
    device: Optional[Device] = None,
    **kwargs,
) -> ndarray:
    """Array API ``ones_like``: an array of ones shaped like ``x``.

    ``device`` is only passed to ``_check_device``; ``dtype`` and extra keyword
    arguments are forwarded to ``xp.ones_like``.
    """
    _check_device(xp, device)
    return xp.ones_like(x, dtype=dtype, **kwargs)

123 

def zeros(
    shape: Union[int, Tuple[int, ...]],
    xp,
    *,
    dtype: Optional[Dtype] = None,
    device: Optional[Device] = None,
    **kwargs,
) -> ndarray:
    """Array API ``zeros``: an array of ``shape`` filled with zeros.

    ``device`` is only passed to ``_check_device``; ``dtype`` and extra keyword
    arguments are forwarded to ``xp.zeros``.
    """
    _check_device(xp, device)
    return xp.zeros(shape, dtype=dtype, **kwargs)

134 

def zeros_like(
    x: ndarray,
    /,
    xp,
    *,
    dtype: Optional[Dtype] = None,
    device: Optional[Device] = None,
    **kwargs,
) -> ndarray:
    """Array API ``zeros_like``: an array of zeros shaped like ``x``.

    ``device`` is only passed to ``_check_device``; ``dtype`` and extra keyword
    arguments are forwarded to ``xp.zeros_like``.
    """
    _check_device(xp, device)
    return xp.zeros_like(x, dtype=dtype, **kwargs)

141 

142# np.unique() is split into four functions in the array API: 

143# unique_all, unique_counts, unique_inverse, and unique_values (this is done 

144# to remove polymorphic return types). 

145 

146# The functions here return namedtuples (np.unique() returns a normal 

147# tuple). 

class UniqueAllResult(NamedTuple):
    """Named 4-tuple returned by ``unique_all``."""
    # Unique values, as returned by xp.unique.
    values: ndarray
    # Indices from xp.unique(..., return_index=True).
    indices: ndarray
    # Indices from xp.unique(..., return_inverse=True); unique_all reshapes
    # them to x.shape before building this tuple.
    inverse_indices: ndarray
    # Counts from xp.unique(..., return_counts=True).
    counts: ndarray

153 

154 

class UniqueCountsResult(NamedTuple):
    """Named 2-tuple returned by ``unique_counts``."""
    # Unique values, as returned by xp.unique.
    values: ndarray
    # Counts from xp.unique(..., return_counts=True).
    counts: ndarray

158 

159 

class UniqueInverseResult(NamedTuple):
    """Named 2-tuple returned by ``unique_inverse``."""
    # Unique values, as returned by xp.unique.
    values: ndarray
    # Indices from xp.unique(..., return_inverse=True); unique_inverse
    # reshapes them to x.shape before building this tuple.
    inverse_indices: ndarray

163 

164 

165def _unique_kwargs(xp): 

166 # Older versions of NumPy and CuPy do not have equal_nan. Rather than 

167 # trying to parse version numbers, just check if equal_nan is in the 

168 # signature. 

169 s = inspect.signature(xp.unique) 

170 if 'equal_nan' in s.parameters: 

171 return {'equal_nan': False} 

172 return {} 

173 

def unique_all(x: ndarray, /, xp) -> UniqueAllResult:
    """Array API ``unique_all`` built on ``xp.unique``.

    Requests values, first-occurrence indices, inverse indices and counts in
    one ``xp.unique`` call and packages them as a ``UniqueAllResult``.
    """
    values, indices, inverse, counts = xp.unique(
        x,
        return_counts=True,
        return_index=True,
        return_inverse=True,
        **_unique_kwargs(xp),
    )
    # np.unique() hands back flattened inverse indices, but the array API
    # requires them to share x's shape.
    # See https://github.com/numpy/numpy/issues/20638
    inverse = inverse.reshape(x.shape)
    return UniqueAllResult(values, indices, inverse, counts)

192 

193 

def unique_counts(x: ndarray, /, xp) -> UniqueCountsResult:
    """Array API ``unique_counts``: unique values plus how often each occurs."""
    extra = _unique_kwargs(xp)
    unique_and_counts = xp.unique(
        x,
        return_counts=True,
        return_index=False,
        return_inverse=False,
        **extra
    )
    return UniqueCountsResult(*unique_and_counts)

205 

206 

def unique_inverse(x: ndarray, /, xp) -> UniqueInverseResult:
    """Array API ``unique_inverse``: unique values plus inverse indices.

    The inverse indices are reshaped to ``x.shape`` to match the array API.
    """
    extra = _unique_kwargs(xp)
    uniq, inverse = xp.unique(
        x,
        return_counts=False,
        return_index=False,
        return_inverse=True,
        **extra,
    )
    # xp.unique() flattens the inverse indices; restore x's shape.
    # See https://github.com/numpy/numpy/issues/20638
    return UniqueInverseResult(uniq, inverse.reshape(x.shape))

220 

221 

def unique_values(x: ndarray, /, xp) -> ndarray:
    """Array API ``unique_values``: just the unique elements of ``x``."""
    extra = _unique_kwargs(xp)
    return xp.unique(
        x,
        return_counts=False,
        return_index=False,
        return_inverse=False,
        **extra,
    )

231 

def astype(x: ndarray, dtype: Dtype, /, *, copy: bool = True) -> ndarray:
    """Array API ``astype``.

    With ``copy`` false and ``dtype`` already equal to ``x.dtype``, ``x`` itself
    is returned. Otherwise ``x.astype(dtype=dtype, copy=copy)`` is used.
    """
    if copy or not dtype == x.dtype:
        return x.astype(dtype=dtype, copy=copy)
    return x

236 

237# These functions have different keyword argument names 

238 

def std(
    x: ndarray,
    /,
    xp,
    *,
    axis: Optional[Union[int, Tuple[int, ...]]] = None,
    correction: Union[int, float] = 0.0,
    keepdims: bool = False,
    **kwargs,
) -> ndarray:
    """Array API ``std``.

    The array API calls the degrees-of-freedom adjustment ``correction``,
    while ``xp.std`` calls it ``ddof``; the value is passed through unchanged.
    """
    return xp.std(x, axis=axis, keepdims=keepdims, ddof=correction, **kwargs)

250 

def var(
    x: ndarray,
    /,
    xp,
    *,
    axis: Optional[Union[int, Tuple[int, ...]]] = None,
    correction: Union[int, float] = 0.0,
    keepdims: bool = False,
    **kwargs,
) -> ndarray:
    """Array API ``var``.

    The array API calls the degrees-of-freedom adjustment ``correction``,
    while ``xp.var`` calls it ``ddof``; the value is passed through unchanged.
    """
    return xp.var(x, axis=axis, keepdims=keepdims, ddof=correction, **kwargs)

262 

263# Unlike transpose(), the axes argument to permute_dims() is required. 

def permute_dims(x: ndarray, /, axes: Tuple[int, ...], xp) -> ndarray:
    """Array API ``permute_dims``: ``xp.transpose`` with a mandatory ``axes``."""
    permuted = xp.transpose(x, axes)
    return permuted

266 

267# Creation functions add the device keyword (which does nothing for NumPy) 

268 

269# asarray also adds the copy keyword 

def _asarray(
    obj: Union[
        ndarray,
        bool,
        int,
        float,
        NestedSequence[bool | int | float],
        SupportsBufferProtocol,
    ],
    /,
    *,
    dtype: Optional[Dtype] = None,
    device: Optional[Device] = None,
    copy: "Optional[Union[bool, np._CopyMode]]" = None,
    namespace = None,
    **kwargs,
) -> ndarray:
    """
    Array API compatibility wrapper for asarray().

    See the corresponding documentation in NumPy/CuPy and/or the array API
    specification for more details.

    ``namespace`` selects the array library. ``None`` infers it from ``obj``
    via ``array_namespace``, a module object is used as-is, and the strings
    ``'numpy'`` / ``'cupy'`` import the corresponding package.

    Raises
    ------
    ValueError
        If ``namespace`` is None and no namespace can be inferred from
        ``obj``, or if ``namespace`` is not a recognized value.
    NotImplementedError
        If ``copy`` is ``False`` (or ``np._CopyMode.IF_NEEDED`` when ``obj``
        is a NumPy array), since a guaranteed no-copy conversion is not
        implemented.
    """
    if namespace is None:
        try:
            xp = array_namespace(obj, _use_compat=False)
        except ValueError as exc:
            # TODO: What about lists of arrays?
            # Chain explicitly so the traceback shows the lookup failure as
            # the cause, not as an unrelated error raised while handling it.
            raise ValueError("A namespace must be specified for asarray() with non-array input") from exc
    elif isinstance(namespace, ModuleType):
        xp = namespace
    elif namespace == 'numpy':
        import numpy as xp
    elif namespace == 'cupy':
        import cupy as xp
    else:
        raise ValueError("Unrecognized namespace argument to asarray()")

    _check_device(xp, device)
    if _is_numpy_array(obj):
        import numpy as np
        if hasattr(np, '_CopyMode'):
            # Not present in older NumPys
            COPY_FALSE = (False, np._CopyMode.IF_NEEDED)
            COPY_TRUE = (True, np._CopyMode.ALWAYS)
        else:
            COPY_FALSE = (False,)
            COPY_TRUE = (True,)
    else:
        COPY_FALSE = (False,)
        COPY_TRUE = (True,)
    if copy in COPY_FALSE:
        # copy=False is not yet implemented in xp.asarray
        raise NotImplementedError("copy=False is not yet implemented")
    if isinstance(obj, xp.ndarray):
        # A dtype change always needs a new array.
        if dtype is not None and obj.dtype != dtype:
            copy = True
        if copy in COPY_TRUE:
            return xp.array(obj, copy=True, dtype=dtype)
        # NOTE(review): extra **kwargs are not applied on this path.
        return obj

    return xp.asarray(obj, dtype=dtype, **kwargs)

333 

334# np.reshape calls the keyword argument 'newshape' instead of 'shape' 

def reshape(
    x: ndarray,
    /,
    shape: Tuple[int, ...],
    xp,
    copy: Optional[bool] = None,
    **kwargs,
) -> ndarray:
    """Array API ``reshape``. NumPy calls the target shape ``newshape``.

    * ``copy is False``: take ``x.view()`` and assign ``shape`` to it in place.
      Whether that raises when a copy would be needed is up to the array
      type's ``shape`` setter (assumed NumPy-like; confirm for other xp).
    * ``copy is True``: copy ``x`` first, then reshape the copy.
    * otherwise: defer to ``xp.reshape``.
    """
    if copy is False:
        reshaped_view = x.view()
        reshaped_view.shape = shape
        return reshaped_view
    source = x.copy() if copy is True else x
    return xp.reshape(source, shape, **kwargs)

347 

348# The descending keyword is new in sort and argsort, and 'kind' replaced with 

349# 'stable' 

def argsort(
    x: ndarray,
    /,
    xp,
    *,
    axis: int = -1,
    descending: bool = False,
    stable: bool = True,
    **kwargs,
) -> ndarray:
    """Array API ``argsort`` with ``descending`` and ``stable`` keywords.

    The array API's ``stable`` replaces NumPy's ``kind``. It is written into
    ``kwargs`` only when requested, because numpy.sort defaults to
    kind='quicksort' while cupy.sort defaults to kind=None.
    """
    if stable:
        kwargs['kind'] = "stable"
    if not descending:
        return xp.argsort(x, axis=axis, **kwargs)

    # There is no native descending sort. Flipping a plain ascending argsort
    # would reverse ties, so argsort the flipped input, flip that result
    # back, and map each index i to (n - 1) - i. Ties then keep their
    # original relative order.
    on_reversed = xp.argsort(xp.flip(x, axis=axis), axis=axis, **kwargs)
    flipped_back = xp.flip(on_reversed, axis=axis)
    # flip()/argsort() above have already validated axis.
    axis_len = x.shape[axis if axis >= 0 else x.ndim + axis]
    return (axis_len - 1) - flipped_back

374 

def sort(
    x: ndarray,
    /,
    xp,
    *,
    axis: int = -1,
    descending: bool = False,
    stable: bool = True,
    **kwargs,
) -> ndarray:
    """Array API ``sort`` with ``descending`` and ``stable`` keywords.

    ``stable`` is translated into ``kind="stable"`` only when requested,
    because numpy.sort defaults to kind='quicksort' while cupy.sort defaults
    to kind=None. A descending result is the ascending one flipped along
    ``axis``.
    """
    if stable:
        kwargs['kind'] = "stable"
    ascending = xp.sort(x, axis=axis, **kwargs)
    return xp.flip(ascending, axis=axis) if descending else ascending

388 

389# sum() and prod() should always upcast when dtype=None 

def sum(
    x: ndarray,
    /,
    xp,
    *,
    axis: Optional[Union[int, Tuple[int, ...]]] = None,
    dtype: Optional[Dtype] = None,
    keepdims: bool = False,
    **kwargs,
) -> ndarray:
    """Array API ``sum``, which upcasts when ``dtype`` is None.

    ``xp.sum`` already upcasts integers but not floats or complexes, so with
    no explicit ``dtype`` a float32 input accumulates as float64 and a
    complex64 input as complex128.
    """
    if dtype is None and x.dtype == xp.float32:
        dtype = xp.float64
    elif dtype is None and x.dtype == xp.complex64:
        dtype = xp.complex128
    return xp.sum(x, axis=axis, dtype=dtype, keepdims=keepdims, **kwargs)

407 

def prod(
    x: ndarray,
    /,
    xp,
    *,
    axis: Optional[Union[int, Tuple[int, ...]]] = None,
    dtype: Optional[Dtype] = None,
    keepdims: bool = False,
    **kwargs,
) -> ndarray:
    """Array API ``prod``, which upcasts when ``dtype`` is None.

    With no explicit ``dtype``, a float32 input accumulates as float64 and a
    complex64 input as complex128; other dtypes are left to ``xp.prod``.
    """
    if dtype is None and x.dtype == xp.float32:
        dtype = xp.float64
    elif dtype is None and x.dtype == xp.complex64:
        dtype = xp.complex128
    return xp.prod(x, dtype=dtype, axis=axis, keepdims=keepdims, **kwargs)

424 

425# ceil, floor, and trunc return integers for integer inputs 

426 

def ceil(x: ndarray, /, xp, **kwargs) -> ndarray:
    """Array API ``ceil``: integer inputs are returned unchanged (same object)."""
    integral = xp.issubdtype(x.dtype, xp.integer)
    return x if integral else xp.ceil(x, **kwargs)

431 

def floor(x: ndarray, /, xp, **kwargs) -> ndarray:
    """Array API ``floor``: integer inputs are returned unchanged (same object)."""
    integral = xp.issubdtype(x.dtype, xp.integer)
    return x if integral else xp.floor(x, **kwargs)

436 

def trunc(x: ndarray, /, xp, **kwargs) -> ndarray:
    """Array API ``trunc``: integer inputs are returned unchanged (same object)."""
    integral = xp.issubdtype(x.dtype, xp.integer)
    return x if integral else xp.trunc(x, **kwargs)

441 

442# linear algebra functions 

443 

def matmul(x1: ndarray, x2: ndarray, /, xp, **kwargs) -> ndarray:
    """Array API ``matmul``: a direct pass-through to ``xp.matmul``."""
    product = xp.matmul(x1, x2, **kwargs)
    return product

446 

447# Unlike transpose, matrix_transpose only transposes the last two axes. 

def matrix_transpose(x: ndarray, /, xp) -> ndarray:
    """Swap the last two axes of ``x``, unlike ``transpose``, which reverses all axes.

    Raises ValueError for inputs with fewer than two dimensions.
    """
    if x.ndim >= 2:
        return xp.swapaxes(x, -1, -2)
    raise ValueError("x must be at least 2-dimensional for matrix_transpose")

452 

def tensordot(
    x1: ndarray,
    x2: ndarray,
    /,
    xp,
    *,
    axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2,
    **kwargs,
) -> ndarray:
    """Array API ``tensordot``; ``axes`` is keyword-only and forwarded as-is."""
    contracted = xp.tensordot(x1, x2, axes=axes, **kwargs)
    return contracted

462 

def vecdot(x1: ndarray, x2: ndarray, /, xp, *, axis: int = -1) -> ndarray:
    """Array API ``vecdot``: the vector dot product along ``axis``.

    For vectors ``a`` (from ``x1``) and ``b`` (from ``x2``), the 2022.12 array
    API defines the result as ``sum(conj(a_i) * b_i)``. ``x1`` is therefore
    complex-conjugated, which does not change real-valued inputs. ``axis``
    refers to the broadcast shape of ``x1`` and ``x2``.

    Raises ValueError when ``x1`` and ``x2`` differ in size along ``axis``
    (after left-padding the lower-rank shape with ones).
    """
    ndim = max(x1.ndim, x2.ndim)
    x1_shape = (1,)*(ndim - x1.ndim) + tuple(x1.shape)
    x2_shape = (1,)*(ndim - x2.ndim) + tuple(x2.shape)
    if x1_shape[axis] != x2_shape[axis]:
        raise ValueError("x1 and x2 must have the same size along the given axis")

    # PyTorch-style namespaces call it broadcast_tensors; NumPy/CuPy use
    # broadcast_arrays.
    if hasattr(xp, 'broadcast_tensors'):
        _broadcast = xp.broadcast_tensors
    else:
        _broadcast = xp.broadcast_arrays

    x1_, x2_ = _broadcast(x1, x2)
    x1_ = xp.moveaxis(x1_, axis, -1)
    x2_ = xp.moveaxis(x2_, axis, -1)

    # Batched (1, n) @ (n, 1) products; conjugating x1 is required for
    # complex inputs by the spec's definition of the dot product.
    res = xp.conj(x1_[..., None, :]) @ x2_[..., None]
    return res[..., 0, 0]

481 

482# isdtype is a new function in the 2022.12 array API specification. 

483 

def isdtype(
    dtype: Dtype, kind: Union[Dtype, str, Tuple[Union[Dtype, str], ...]], xp,
    *, _tuple=True, # Disallow nested tuples
) -> bool:
    """
    Returns a boolean indicating whether a provided dtype is of a specified data type ``kind``.

    ``kind`` may be one of the spec's category strings ('bool',
    'signed integer', 'unsigned integer', 'integral', 'real floating',
    'complex floating', 'numeric'), a dtype compared by equality, or a tuple
    of those (true if any entry matches; tuples are not nested). Any other
    string raises ValueError.

    Note that outside of this function, this compat library does not yet fully
    support complex numbers.

    See
    https://data-apis.org/array-api/latest/API_specification/generated/array_api.isdtype.html
    for more details
    """
    if isinstance(kind, tuple) and _tuple:
        return any(isdtype(dtype, k, xp, _tuple=False) for k in kind)
    elif isinstance(kind, str):
        if kind == 'bool':
            return dtype == xp.bool_
        elif kind == 'signed integer':
            return xp.issubdtype(dtype, xp.signedinteger)
        elif kind == 'unsigned integer':
            return xp.issubdtype(dtype, xp.unsignedinteger)
        elif kind == 'integral':
            return xp.issubdtype(dtype, xp.integer)
        elif kind == 'real floating':
            return xp.issubdtype(dtype, xp.floating)
        elif kind == 'complex floating':
            return xp.issubdtype(dtype, xp.complexfloating)
        elif kind == 'numeric':
            return xp.issubdtype(dtype, xp.number)
        else:
            raise ValueError(f"Unrecognized data type kind: {kind!r}")
    else:
        # Any non-string kind (including a tuple nested inside the top-level
        # tuple) falls back to plain equality with ``dtype``. This allows
        # things the spec does not require, e.g.
        # isdtype(np.dtype('float64'), float). String kinds never reach this
        # branch, so NumPy type codes such as 'l' are rejected above with a
        # ValueError. Should we be more strict here to match the type
        # annotation? Note that the numpy.array_api implementation will be
        # very strict.
        return dtype == kind

523 

# Public API of this module.
__all__ = [
    # creation functions
    'arange',
    'empty',
    'empty_like',
    'eye',
    'full',
    'full_like',
    'linspace',
    'ones',
    'ones_like',
    'zeros',
    'zeros_like',
    # unique family
    'UniqueAllResult',
    'UniqueCountsResult',
    'UniqueInverseResult',
    'unique_all',
    'unique_counts',
    'unique_inverse',
    'unique_values',
    # renamed / adjusted keyword arguments
    'astype',
    'std',
    'var',
    'permute_dims',
    'reshape',
    'argsort',
    'sort',
    'sum',
    'prod',
    'ceil',
    'floor',
    'trunc',
    # linear algebra
    'matmul',
    'matrix_transpose',
    'tensordot',
    'vecdot',
    # dtype introspection
    'isdtype',
]