Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.9/dist-packages/scipy/_lib/array_api_compat/common/_aliases.py: 25%

215 statements  

« prev     ^ index     » next       coverage.py v7.4.1, created at 2024-02-14 06:37 +0000

1""" 

2These are functions that are just aliases of existing functions in NumPy. 

3""" 

4 

5from __future__ import annotations 

6 

7from typing import TYPE_CHECKING 

8if TYPE_CHECKING: 

9 from typing import Optional, Sequence, Tuple, Union, List 

10 from ._typing import ndarray, Device, Dtype, NestedSequence, SupportsBufferProtocol 

11 

12from typing import NamedTuple 

13from types import ModuleType 

14import inspect 

15 

16from ._helpers import _check_device, _is_numpy_array, array_namespace 

17 

18# These functions are modified from the NumPy versions. 

19 

def arange(
    start: Union[int, float],
    /,
    stop: Optional[Union[int, float]] = None,
    step: Union[int, float] = 1,
    *,
    xp,
    dtype: Optional[Dtype] = None,
    device: Optional[Device] = None,
    **kwargs
) -> ndarray:
    """Array API ``arange`` built on ``xp.arange``.

    ``device`` is only handed to ``_check_device``; it is never forwarded to
    ``xp.arange``. ``dtype`` and any extra keyword arguments are passed
    through unchanged.
    """
    _check_device(xp, device)
    # dtype travels with the other pass-through keywords.
    kwargs['dtype'] = dtype
    return xp.arange(start, stop=stop, step=step, **kwargs)

33 

def empty(
    shape: Union[int, Tuple[int, ...]],
    xp,
    *,
    dtype: Optional[Dtype] = None,
    device: Optional[Device] = None,
    **kwargs
) -> ndarray:
    """Array API ``empty`` built on ``xp.empty``.

    ``device`` is only handed to ``_check_device`` and is not forwarded;
    ``dtype`` and extra keyword arguments go straight to ``xp.empty``.
    """
    _check_device(xp, device)
    kwargs['dtype'] = dtype
    return xp.empty(shape, **kwargs)

44 

def empty_like(
    x: ndarray,
    /,
    xp,
    *,
    dtype: Optional[Dtype] = None,
    device: Optional[Device] = None,
    **kwargs
) -> ndarray:
    """Array API ``empty_like`` built on ``xp.empty_like``.

    ``device`` is only handed to ``_check_device`` and is not forwarded;
    ``dtype`` and extra keyword arguments go straight to ``xp.empty_like``.
    """
    _check_device(xp, device)
    kwargs['dtype'] = dtype
    return xp.empty_like(x, **kwargs)

51 

def eye(
    n_rows: int,
    n_cols: Optional[int] = None,
    /,
    *,
    xp,
    k: int = 0,
    dtype: Optional[Dtype] = None,
    device: Optional[Device] = None,
    **kwargs,
) -> ndarray:
    """Array API ``eye`` built on ``xp.eye``.

    The array API's ``n_cols`` is forwarded as NumPy's ``M`` keyword. ``device``
    is only handed to ``_check_device`` and is not forwarded.
    """
    _check_device(xp, device)
    # M is passed explicitly (not merged into kwargs) so a caller-supplied
    # M keyword still collides and raises rather than being overwritten.
    column_count = n_cols
    return xp.eye(n_rows, M=column_count, k=k, dtype=dtype, **kwargs)

65 

def full(
    shape: Union[int, Tuple[int, ...]],
    fill_value: Union[int, float],
    xp,
    *,
    dtype: Optional[Dtype] = None,
    device: Optional[Device] = None,
    **kwargs,
) -> ndarray:
    """Array API ``full`` built on ``xp.full``.

    ``device`` is only handed to ``_check_device`` and is not forwarded;
    ``dtype`` and extra keyword arguments go straight to ``xp.full``.
    """
    _check_device(xp, device)
    kwargs['dtype'] = dtype
    return xp.full(shape, fill_value, **kwargs)

77 

def full_like(
    x: ndarray,
    /,
    fill_value: Union[int, float],
    *,
    xp,
    dtype: Optional[Dtype] = None,
    device: Optional[Device] = None,
    **kwargs,
) -> ndarray:
    """Array API ``full_like`` built on ``xp.full_like``.

    ``device`` is only handed to ``_check_device`` and is not forwarded;
    ``dtype`` and extra keyword arguments go straight to ``xp.full_like``.
    """
    _check_device(xp, device)
    kwargs['dtype'] = dtype
    return xp.full_like(x, fill_value, **kwargs)

90 

def linspace(
    start: Union[int, float],
    stop: Union[int, float],
    /,
    num: int,
    *,
    xp,
    dtype: Optional[Dtype] = None,
    device: Optional[Device] = None,
    endpoint: bool = True,
    **kwargs,
) -> ndarray:
    """Array API ``linspace`` built on ``xp.linspace``.

    ``device`` is only handed to ``_check_device`` and is not forwarded;
    ``dtype``, ``endpoint`` and extra keyword arguments go to ``xp.linspace``.
    """
    _check_device(xp, device)
    kwargs['dtype'] = dtype
    kwargs['endpoint'] = endpoint
    return xp.linspace(start, stop, num, **kwargs)

105 

def ones(
    shape: Union[int, Tuple[int, ...]],
    xp,
    *,
    dtype: Optional[Dtype] = None,
    device: Optional[Device] = None,
    **kwargs,
) -> ndarray:
    """Array API ``ones`` built on ``xp.ones``.

    ``device`` is only handed to ``_check_device`` and is not forwarded;
    ``dtype`` and extra keyword arguments go straight to ``xp.ones``.
    """
    _check_device(xp, device)
    kwargs['dtype'] = dtype
    return xp.ones(shape, **kwargs)

116 

def ones_like(
    x: ndarray,
    /,
    xp,
    *,
    dtype: Optional[Dtype] = None,
    device: Optional[Device] = None,
    **kwargs,
) -> ndarray:
    """Array API ``ones_like`` built on ``xp.ones_like``.

    ``device`` is only handed to ``_check_device`` and is not forwarded;
    ``dtype`` and extra keyword arguments go straight to ``xp.ones_like``.
    """
    _check_device(xp, device)
    kwargs['dtype'] = dtype
    return xp.ones_like(x, **kwargs)

123 

def zeros(
    shape: Union[int, Tuple[int, ...]],
    xp,
    *,
    dtype: Optional[Dtype] = None,
    device: Optional[Device] = None,
    **kwargs,
) -> ndarray:
    """Array API ``zeros`` built on ``xp.zeros``.

    ``device`` is only handed to ``_check_device`` and is not forwarded;
    ``dtype`` and extra keyword arguments go straight to ``xp.zeros``.
    """
    _check_device(xp, device)
    kwargs['dtype'] = dtype
    return xp.zeros(shape, **kwargs)

134 

def zeros_like(
    x: ndarray,
    /,
    xp,
    *,
    dtype: Optional[Dtype] = None,
    device: Optional[Device] = None,
    **kwargs,
) -> ndarray:
    """Array API ``zeros_like`` built on ``xp.zeros_like``.

    ``device`` is only handed to ``_check_device`` and is not forwarded;
    ``dtype`` and extra keyword arguments go straight to ``xp.zeros_like``.
    """
    _check_device(xp, device)
    kwargs['dtype'] = dtype
    return xp.zeros_like(x, **kwargs)

141 

142# np.unique() is split into four functions in the array API: 

143# unique_all, unique_counts, unique_inverse, and unique_values (this is done 

144# to remove polymorphic return types). 

145 

146# The functions here return namedtuples (np.unique() returns a normal 

147# tuple). 

class UniqueAllResult(NamedTuple):
    """Result of ``unique_all``.

    Holds the four arrays produced by ``xp.unique(..., return_index=True,
    return_inverse=True, return_counts=True)``, as a named tuple rather than
    the plain tuple ``xp.unique`` returns.
    """
    values: ndarray  # the unique values
    indices: ndarray  # from return_index=True
    inverse_indices: ndarray  # from return_inverse=True; unique_all reshapes it to x.shape
    counts: ndarray  # from return_counts=True

153 

154 

class UniqueCountsResult(NamedTuple):
    """Result of ``unique_counts``: unique values and their counts."""
    values: ndarray  # the unique values
    counts: ndarray  # from return_counts=True

158 

159 

class UniqueInverseResult(NamedTuple):
    """Result of ``unique_inverse``: unique values and inverse indices."""
    values: ndarray  # the unique values
    inverse_indices: ndarray  # from return_inverse=True; unique_inverse reshapes it to x.shape

163 

164 

165def _unique_kwargs(xp): 

166 # Older versions of NumPy and CuPy do not have equal_nan. Rather than 

167 # trying to parse version numbers, just check if equal_nan is in the 

168 # signature. 

169 s = inspect.signature(xp.unique) 

170 if 'equal_nan' in s.parameters: 

171 return {'equal_nan': False} 

172 return {} 

173 

def unique_all(x: ndarray, /, xp) -> UniqueAllResult:
    """Array API ``unique_all``: values, indices, inverse indices and counts.

    Calls ``xp.unique`` once with every ``return_*`` flag enabled and wraps
    the four results in a ``UniqueAllResult``.
    """
    extra = _unique_kwargs(xp)
    values, indices, inverse_indices, counts = xp.unique(
        x,
        return_counts=True,
        return_index=True,
        return_inverse=True,
        **extra,
    )
    # np.unique() flattens inverse indices, but the array API wants them in
    # x's shape. See https://github.com/numpy/numpy/issues/20638
    shaped_inverse = inverse_indices.reshape(x.shape)
    return UniqueAllResult(values, indices, shaped_inverse, counts)

192 

193 

def unique_counts(x: ndarray, /, xp) -> UniqueCountsResult:
    """Array API ``unique_counts``: unique values together with their counts."""
    extra = _unique_kwargs(xp)
    pair = xp.unique(
        x,
        return_counts=True,
        return_index=False,
        return_inverse=False,
        **extra
    )
    # xp.unique returns a plain tuple here; re-pack it as a named tuple.
    return UniqueCountsResult(*pair)

205 

206 

def unique_inverse(x: ndarray, /, xp) -> UniqueInverseResult:
    """Array API ``unique_inverse``: unique values and inverse indices.

    The inverse indices are reshaped to ``x.shape`` before being returned.
    """
    extra = _unique_kwargs(xp)
    values, flat_inverse = xp.unique(
        x,
        return_counts=False,
        return_index=False,
        return_inverse=True,
        **extra,
    )
    # xp.unique() flattens inverse indices, but they need to share x's shape.
    # See https://github.com/numpy/numpy/issues/20638
    return UniqueInverseResult(values, flat_inverse.reshape(x.shape))

220 

221 

def unique_values(x: ndarray, /, xp) -> ndarray:
    """Array API ``unique_values``: only the unique values of ``x``."""
    extra = _unique_kwargs(xp)
    flags = dict(return_counts=False, return_index=False, return_inverse=False)
    return xp.unique(x, **flags, **extra)

231 

def astype(x: ndarray, dtype: Dtype, /, *, copy: bool = True) -> ndarray:
    """Array API ``astype``.

    With ``copy`` false and ``dtype`` already equal to ``x.dtype``, ``x``
    itself is returned. Otherwise the call goes to ``x.astype(dtype=...,
    copy=...)``.
    """
    if copy:
        return x.astype(dtype=dtype, copy=copy)
    # Only reached when no copy was requested: skip conversion if unneeded.
    if dtype == x.dtype:
        return x
    return x.astype(dtype=dtype, copy=copy)

236 

237# These functions have different keyword argument names 

238 

def std(
    x: ndarray,
    /,
    xp,
    *,
    axis: Optional[Union[int, Tuple[int, ...]]] = None,
    correction: Union[int, float] = 0.0,
    keepdims: bool = False,
    **kwargs,
) -> ndarray:
    """Array API ``std``: the spec's ``correction`` is NumPy's ``ddof``."""
    degrees_of_freedom_delta = correction
    return xp.std(
        x,
        axis=axis,
        ddof=degrees_of_freedom_delta,
        keepdims=keepdims,
        **kwargs,
    )

250 

def var(
    x: ndarray,
    /,
    xp,
    *,
    axis: Optional[Union[int, Tuple[int, ...]]] = None,
    correction: Union[int, float] = 0.0,
    keepdims: bool = False,
    **kwargs,
) -> ndarray:
    """Array API ``var``: the spec's ``correction`` is NumPy's ``ddof``."""
    degrees_of_freedom_delta = correction
    return xp.var(
        x,
        axis=axis,
        ddof=degrees_of_freedom_delta,
        keepdims=keepdims,
        **kwargs,
    )

262 

263# Unlike transpose(), the axes argument to permute_dims() is required. 

def permute_dims(x: ndarray, /, axes: Tuple[int, ...], xp) -> ndarray:
    """Array API ``permute_dims``: ``xp.transpose`` with a mandatory ``axes``."""
    permutation = axes
    return xp.transpose(x, permutation)

266 

267# Creation functions add the device keyword (which does nothing for NumPy) 

268 

269# asarray also adds the copy keyword 

def _asarray(
    obj: Union[
        ndarray,
        bool,
        int,
        float,
        NestedSequence[bool | int | float],
        SupportsBufferProtocol,
    ],
    /,
    *,
    dtype: Optional[Dtype] = None,
    device: Optional[Device] = None,
    copy: "Optional[Union[bool, np._CopyMode]]" = None,
    namespace = None,
    **kwargs,
) -> ndarray:
    """
    Array API compatibility wrapper for asarray().

    See the corresponding documentation in NumPy/CuPy and/or the array API
    specification for more details.

    Parameters
    ----------
    namespace : None, module, or str
        ``None`` infers the namespace from ``obj`` via ``array_namespace``.
        A module object is used directly. The strings ``'numpy'`` and
        ``'cupy'`` import that library.
    copy : bool, np._CopyMode, or None
        ``True`` (or ``np._CopyMode.ALWAYS`` for NumPy input, when NumPy has
        ``_CopyMode``) always copies. ``False`` (or ``np._CopyMode.IF_NEEDED``)
        is rejected. ``None`` copies only when ``dtype`` differs from
        ``obj.dtype``.

    Raises
    ------
    ValueError
        If ``namespace`` is None and ``array_namespace`` raises ValueError,
        or if ``namespace`` is not recognized.
    NotImplementedError
        If ``copy`` requests "never copy".
    """
    if namespace is None:
        # Infer xp from obj itself; _use_compat=False presumably selects the
        # bare library namespace rather than the compat wrapper (see _helpers).
        try:
            xp = array_namespace(obj, _use_compat=False)
        except ValueError:
            # TODO: What about lists of arrays?
            raise ValueError("A namespace must be specified for asarray() with non-array input")
    elif isinstance(namespace, ModuleType):
        xp = namespace
    elif namespace == 'numpy':
        import numpy as xp
    elif namespace == 'cupy':
        import cupy as xp
    else:
        raise ValueError("Unrecognized namespace argument to asarray()")

    _check_device(xp, device)
    # Collect the copy values meaning "never copy" (COPY_FALSE) and "always
    # copy" (COPY_TRUE). np._CopyMode members are recognized only for NumPy
    # input, and only when the installed NumPy defines _CopyMode.
    if _is_numpy_array(obj):
        import numpy as np
        if hasattr(np, '_CopyMode'):
            # Not present in older NumPys
            COPY_FALSE = (False, np._CopyMode.IF_NEEDED)
            COPY_TRUE = (True, np._CopyMode.ALWAYS)
        else:
            COPY_FALSE = (False,)
            COPY_TRUE = (True,)
    else:
        COPY_FALSE = (False,)
        COPY_TRUE = (True,)
    if copy in COPY_FALSE:
        # copy=False is not yet implemented in xp.asarray
        raise NotImplementedError("copy=False is not yet implemented")
    if isinstance(obj, xp.ndarray):
        # A dtype change cannot be done without copying, so force a copy.
        if dtype is not None and obj.dtype != dtype:
            copy = True
        if copy in COPY_TRUE:
            return xp.array(obj, copy=True, dtype=dtype)
        # copy is None and no dtype change: hand back obj itself.
        # NOTE(review): **kwargs are not applied on this path or the copy
        # path above; only the xp.asarray fallback below receives them.
        return obj

    return xp.asarray(obj, dtype=dtype, **kwargs)

333 

334# np.reshape calls the keyword argument 'newshape' instead of 'shape' 

def reshape(x: ndarray,
            /,
            shape: Tuple[int, ...],
            xp, copy: Optional[bool] = None,
            **kwargs) -> ndarray:
    """Array API ``reshape`` (``np.reshape`` calls this argument ``newshape``).

    ``copy=False``: reshape a view of ``x`` in place by assigning its
    ``shape``; extra keyword arguments are not used on this path.
    ``copy=True``: reshape a copy of ``x``.
    ``copy=None``: defer entirely to ``xp.reshape``.
    """
    if copy is False:
        view = x.view()
        view.shape = shape
        return view
    source = x.copy() if copy is True else x
    return xp.reshape(source, shape, **kwargs)

347 

348# The descending keyword is new in sort and argsort, and 'kind' replaced with 

349# 'stable' 

def argsort(
    x: ndarray,
    /,
    xp,
    *,
    axis: int = -1,
    descending: bool = False,
    stable: bool = True,
    **kwargs,
) -> ndarray:
    """Array API ``argsort`` with ``descending`` and ``stable`` keywords.

    ``stable=True`` is forwarded as ``kind="stable"``. numpy.sort defaults to
    kind='quicksort' while cupy.sort defaults to kind=None, so the keyword is
    only set when requested. For ``descending=True``, the input is flipped,
    argsorted, flipped back, and mapped via ``(n - 1) - i``. This keeps equal
    elements in their original relative order, which simply reversing an
    ascending argsort would not.
    """
    if stable:
        kwargs['kind'] = "stable"
    if not descending:
        return xp.argsort(x, axis=axis, **kwargs)

    reversed_input = xp.flip(x, axis=axis)
    order = xp.flip(xp.argsort(reversed_input, axis=axis, **kwargs), axis=axis)
    # flip()/argsort() above have already validated axis, so it is safe to
    # normalise it here.
    axis_length = x.shape[axis if axis >= 0 else x.ndim + axis]
    return (axis_length - 1) - order

374 

def sort(
    x: ndarray,
    /,
    xp,
    *,
    axis: int = -1,
    descending: bool = False,
    stable: bool = True,
    **kwargs,
) -> ndarray:
    """Array API ``sort`` with ``descending`` and ``stable`` keywords.

    ``stable=True`` becomes ``kind="stable"``. It is set only on request,
    since numpy.sort defaults to kind='quicksort' while cupy.sort defaults
    to kind=None. A descending sort is the ascending result flipped along
    ``axis``.
    """
    if stable:
        kwargs['kind'] = "stable"
    ascending = xp.sort(x, axis=axis, **kwargs)
    return xp.flip(ascending, axis=axis) if descending else ascending

388 

389# nonzero should error for zero-dimensional arrays 

def nonzero(x: ndarray, /, xp, **kwargs) -> Tuple[ndarray, ...]:
    """Array API ``nonzero``: like ``xp.nonzero`` but rejects 0-d input.

    Raises
    ------
    ValueError
        If ``x`` has zero dimensions.
    """
    if not x.ndim:
        raise ValueError("nonzero() does not support zero-dimensional arrays")
    return xp.nonzero(x, **kwargs)

394 

395# sum() and prod() should always upcast when dtype=None 

def sum(
    x: ndarray,
    /,
    xp,
    *,
    axis: Optional[Union[int, Tuple[int, ...]]] = None,
    dtype: Optional[Dtype] = None,
    keepdims: bool = False,
    **kwargs,
) -> ndarray:
    """Array API ``sum``: upcast float32/complex64 when no dtype is given.

    ``xp.sum`` already upcasts integers but not floats or complexes. So with
    ``dtype=None``, float32 input is summed with ``dtype=xp.float64`` and
    complex64 input with ``dtype=xp.complex128``. Every other input keeps
    ``dtype=None``.
    """
    if dtype is None:
        dtype = (
            xp.float64 if x.dtype == xp.float32
            else xp.complex128 if x.dtype == xp.complex64
            else None
        )
    return xp.sum(x, axis=axis, dtype=dtype, keepdims=keepdims, **kwargs)

413 

def prod(
    x: ndarray,
    /,
    xp,
    *,
    axis: Optional[Union[int, Tuple[int, ...]]] = None,
    dtype: Optional[Dtype] = None,
    keepdims: bool = False,
    **kwargs,
) -> ndarray:
    """Array API ``prod``: upcast float32/complex64 when no dtype is given.

    With ``dtype=None``, float32 input uses ``dtype=xp.float64`` and complex64
    input uses ``dtype=xp.complex128``. Every other input keeps
    ``dtype=None``, so ``xp.prod`` picks the result type.
    """
    if dtype is None:
        dtype = (
            xp.float64 if x.dtype == xp.float32
            else xp.complex128 if x.dtype == xp.complex64
            else None
        )
    return xp.prod(x, dtype=dtype, axis=axis, keepdims=keepdims, **kwargs)

430 

431# ceil, floor, and trunc return integers for integer inputs 

432 

def ceil(x: ndarray, /, xp, **kwargs) -> ndarray:
    """Array API ``ceil``: integer input is returned as-is (same object)."""
    if not xp.issubdtype(x.dtype, xp.integer):
        return xp.ceil(x, **kwargs)
    return x

437 

def floor(x: ndarray, /, xp, **kwargs) -> ndarray:
    """Array API ``floor``: integer input is returned as-is (same object)."""
    if not xp.issubdtype(x.dtype, xp.integer):
        return xp.floor(x, **kwargs)
    return x

442 

def trunc(x: ndarray, /, xp, **kwargs) -> ndarray:
    """Array API ``trunc``: integer input is returned as-is (same object)."""
    if not xp.issubdtype(x.dtype, xp.integer):
        return xp.trunc(x, **kwargs)
    return x

447 

448# linear algebra functions 

449 

def matmul(x1: ndarray, x2: ndarray, /, xp, **kwargs) -> ndarray:
    """Array API ``matmul``: forwards directly to ``xp.matmul``."""
    product = xp.matmul(x1, x2, **kwargs)
    return product

452 

453# Unlike transpose, matrix_transpose only transposes the last two axes. 

def matrix_transpose(x: ndarray, /, xp) -> ndarray:
    """Swap the last two axes of ``x``; any leading axes are left in place.

    Raises
    ------
    ValueError
        If ``x`` has fewer than two dimensions.
    """
    if x.ndim >= 2:
        return xp.swapaxes(x, -1, -2)
    raise ValueError("x must be at least 2-dimensional for matrix_transpose")

458 

def tensordot(
    x1: ndarray,
    x2: ndarray,
    /,
    xp,
    *,
    axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2,
    **kwargs,
) -> ndarray:
    """Array API ``tensordot``: ``axes`` is keyword-only and defaults to 2."""
    contraction = axes
    return xp.tensordot(x1, x2, axes=contraction, **kwargs)

468 

def vecdot(x1: ndarray, x2: ndarray, /, xp, *, axis: int = -1) -> ndarray:
    """Vector dot product of ``x1`` and ``x2`` along ``axis``.

    Per the array API definition, the first argument is conjugated. For
    vectors ``a`` (from ``x1``) and ``b`` (from ``x2``) the result is
    ``sum(conj(a_i) * b_i)``. Conjugation is applied only when ``x1`` has a
    complex dtype, so real, integer and boolean inputs keep their previous
    results and result dtypes.

    Parameters
    ----------
    x1, x2 : ndarray
        Input arrays; they are broadcast against each other.
    xp : module
        Array namespace providing ``broadcast_tensors`` or
        ``broadcast_arrays``, plus ``moveaxis`` and ``conj``.
    axis : int, default -1
        Axis, in the broadcast shape, along which the products are summed.

    Returns
    -------
    ndarray
        Array with the broadcast shape minus ``axis``.

    Raises
    ------
    ValueError
        If ``x1`` and ``x2`` differ in size along ``axis``. The lower-rank
        shape is left-padded with ones before comparing.
    """
    ndim = max(x1.ndim, x2.ndim)
    # Left-pad with ones so ``axis`` indexes the broadcast rank for both.
    x1_shape = (1,)*(ndim - x1.ndim) + tuple(x1.shape)
    x2_shape = (1,)*(ndim - x2.ndim) + tuple(x2.shape)
    if x1_shape[axis] != x2_shape[axis]:
        raise ValueError("x1 and x2 must have the same size along the given axis")

    # Namespaces exposing ``broadcast_tensors`` (presumably PyTorch) use it;
    # otherwise fall back to ``broadcast_arrays``.
    if hasattr(xp, 'broadcast_tensors'):
        _broadcast = xp.broadcast_tensors
    else:
        _broadcast = xp.broadcast_arrays

    x1_, x2_ = _broadcast(x1, x2)
    x1_ = xp.moveaxis(x1_, axis, -1)
    x2_ = xp.moveaxis(x2_, axis, -1)

    # Conjugate only complex input. NumPy/CuPy dtypes report kind 'c';
    # torch dtypes expose ``is_complex``. Skipping the call for other dtypes
    # avoids e.g. xp.conj promoting bool input to an integer type.
    x1_dtype = x1_.dtype
    if getattr(x1_dtype, 'kind', None) == 'c' or getattr(x1_dtype, 'is_complex', False):
        x1_ = xp.conj(x1_)

    # Each vector pair becomes a (1, n) @ (n, 1) matmul; drop the two
    # resulting singleton dimensions.
    res = x1_[..., None, :] @ x2_[..., None]
    return res[..., 0, 0]

487 

488# isdtype is a new function in the 2022.12 array API specification. 

489 

def isdtype(
    dtype: Dtype, kind: Union[Dtype, str, Tuple[Union[Dtype, str], ...]], xp,
    *, _tuple=True, # Disallow nested tuples
) -> bool:
    """
    Returns a boolean indicating whether a provided dtype is of a specified data type ``kind``.

    ``kind`` may be a dtype (compared with ``==``), one of the spec's kind
    strings ('bool', 'signed integer', 'unsigned integer', 'integral',
    'real floating', 'complex floating', 'numeric'), or a tuple of those
    (any match wins; nested tuples are not expanded).

    Note that outside of this function, this compat library does not yet fully
    support complex numbers.

    See
    https://data-apis.org/array-api/latest/API_specification/generated/array_api.isdtype.html
    for more details
    """
    if isinstance(kind, tuple) and _tuple:
        return any(isdtype(dtype, k, xp, _tuple=False) for k in kind)

    if not isinstance(kind, str):
        # Permissive fallback: things like isdtype(np.float64, float) or
        # isdtype(np.int64, 'l')-style comparisons are allowed even though
        # the spec does not require them (numpy.array_api is stricter).
        return dtype == kind

    if kind == 'bool':
        return dtype == xp.bool_

    # Map each kind string to the name of xp's abstract scalar type; the
    # attribute is looked up only after the kind is recognized.
    abstract_type_names = {
        'signed integer': 'signedinteger',
        'unsigned integer': 'unsignedinteger',
        'integral': 'integer',
        'real floating': 'floating',
        'complex floating': 'complexfloating',
        'numeric': 'number',
    }
    type_name = abstract_type_names.get(kind)
    if type_name is None:
        raise ValueError(f"Unrecognized data type kind: {kind!r}")
    return xp.issubdtype(dtype, getattr(xp, type_name))

529 

# Names exported by ``from ... import *``. The underscore-prefixed helpers
# defined in this module (``_unique_kwargs``, ``_asarray``) are not listed.
__all__ = ['arange', 'empty', 'empty_like', 'eye', 'full', 'full_like',
           'linspace', 'ones', 'ones_like', 'zeros', 'zeros_like',
           'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult',
           'unique_all', 'unique_counts', 'unique_inverse', 'unique_values',
           'astype', 'std', 'var', 'permute_dims', 'reshape', 'argsort',
           'sort', 'nonzero', 'sum', 'prod', 'ceil', 'floor', 'trunc',
           'matmul', 'matrix_transpose', 'tensordot', 'vecdot', 'isdtype']
536 'matmul', 'matrix_transpose', 'tensordot', 'vecdot', 'isdtype']