Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/scikit_learn-1.4.dev0-py3.8-linux-x86_64.egg/sklearn/utils/_array_api.py: 39%

219 statements

1"""Tools to support array_api.""" 

2import itertools 

3import math 

4from functools import wraps 

5 

6import numpy 

7import scipy.special as special 

8 

9from .._config import get_config 

10from .fixes import parse_version 

11 

12 

def yield_namespace_device_dtype_combinations():
    """Yield supported namespace, device, dtype tuples for testing.

    Use this to test that an estimator works with all combinations.

    Returns
    -------
    array_namespace : str
        The name of the Array API namespace.

    device : str
        The name of the device on which to allocate the arrays. Can be None to
        indicate that the default value should be used.

    dtype : str
        The name of the data type to use for arrays. Can be None to indicate
        that the default value should be used.
    """
    for array_namespace in [
        # The following is used to test the array_api_compat wrapper when
        # array_api_dispatch is enabled: in particular, the arrays used in the
        # tests are regular numpy arrays without any "device" attribute.
        "numpy",
        # Stricter NumPy-based Array API implementation. The
        # numpy.array_api.Array instances always have a dummy "device" attribute.
        "numpy.array_api",
        "cupy",
        "cupy.array_api",
        "torch",
    ]:
        if array_namespace == "torch":
            for device, dtype in itertools.product(
                ("cpu", "cuda"), ("float64", "float32")
            ):
                yield array_namespace, device, dtype
            yield array_namespace, "mps", "float32"
        else:
            yield array_namespace, None, None

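# Usage sketch (illustrative): the generator above is typically consumed in a
# parametrized test loop. Namespaces other than "torch" are yielded with
# device=None and dtype=None.
#
#   for namespace, device, dtype in yield_namespace_device_dtype_combinations():
#       print(namespace, device, dtype)
#   # "numpy" None None, ..., "torch" "cpu" "float64", ..., "torch" "mps" "float32"
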

def _check_array_api_dispatch(array_api_dispatch):
    """Check that array_api_compat is installed and NumPy version is compatible.

    array_api_compat follows NEP29, which has a higher minimum NumPy version than
    scikit-learn.
    """
    if array_api_dispatch:
        try:
            import array_api_compat  # noqa
        except ImportError:
            raise ImportError(
                "array_api_compat is required to dispatch arrays using the API"
                " specification"
            )

        numpy_version = parse_version(numpy.__version__)
        min_numpy_version = "1.21"
        if numpy_version < parse_version(min_numpy_version):
            raise ImportError(
                f"NumPy must be {min_numpy_version} or newer to dispatch array using"
                " the API specification"
            )


def device(x):
    """Hardware device the array data resides on.

    Parameters
    ----------
    x : array
        Array instance from NumPy or an array API compatible library.

    Returns
    -------
    out : device
        `device` object (see the "Device Support" section of the array API spec).
    """
    if isinstance(x, (numpy.ndarray, numpy.generic)):
        return "cpu"
    return x.device


def size(x):
    """Return the total number of elements of x.

    Parameters
    ----------
    x : array
        Array instance from NumPy or an array API compatible library.

    Returns
    -------
    out : int
        Total number of elements.
    """
    return math.prod(x.shape)


def _is_numpy_namespace(xp):
    """Return True if xp is backed by NumPy."""
    return xp.__name__ in {"numpy", "array_api_compat.numpy", "numpy.array_api"}


def _union1d(a, b, xp):
    if _is_numpy_namespace(xp):
        return xp.asarray(numpy.union1d(a, b))
    assert a.ndim == b.ndim == 1
    return xp.unique_values(xp.concat([xp.unique_values(a), xp.unique_values(b)]))


def isdtype(dtype, kind, *, xp):
    """Returns a boolean indicating whether a provided dtype is of type "kind".

    Included in the v2022.12 of the Array API spec.
    https://data-apis.org/array-api/latest/API_specification/generated/array_api.isdtype.html
    """
    if isinstance(kind, tuple):
        return any(_isdtype_single(dtype, k, xp=xp) for k in kind)
    else:
        return _isdtype_single(dtype, kind, xp=xp)


def _isdtype_single(dtype, kind, *, xp):
    if isinstance(kind, str):
        if kind == "bool":
            return dtype == xp.bool
        elif kind == "signed integer":
            return dtype in {xp.int8, xp.int16, xp.int32, xp.int64}
        elif kind == "unsigned integer":
            return dtype in {xp.uint8, xp.uint16, xp.uint32, xp.uint64}
        elif kind == "integral":
            return any(
                _isdtype_single(dtype, k, xp=xp)
                for k in ("signed integer", "unsigned integer")
            )
        elif kind == "real floating":
            return dtype in supported_float_dtypes(xp)
        elif kind == "complex floating":
            # Some namespaces do not have complex, such as cupy.array_api
            # and numpy.array_api
            complex_dtypes = set()
            if hasattr(xp, "complex64"):
                complex_dtypes.add(xp.complex64)
            if hasattr(xp, "complex128"):
                complex_dtypes.add(xp.complex128)
            return dtype in complex_dtypes
        elif kind == "numeric":
            return any(
                _isdtype_single(dtype, k, xp=xp)
                for k in ("integral", "real floating", "complex floating")
            )
        else:
            raise ValueError(f"Unrecognized data type kind: {kind!r}")
    else:
        return dtype == kind


def supported_float_dtypes(xp):
    """Supported floating point types for the namespace

    Note: float16 is not officially part of the Array API spec at the
    time of writing but scikit-learn estimators and functions can choose
    to accept it when xp.float16 is defined.

    https://data-apis.org/array-api/latest/API_specification/data_types.html
    """
    if hasattr(xp, "float16"):
        return (xp.float64, xp.float32, xp.float16)
    else:
        return (xp.float64, xp.float32)

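# Usage sketch (illustrative): the plain NumPy module exposes float16, so all
# three floating point dtypes are reported for it.
#
#   supported_float_dtypes(numpy)
#   # -> (float64, float32, float16)
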

class _ArrayAPIWrapper:
    """sklearn specific Array API compatibility wrapper

    This wrapper makes it possible for scikit-learn maintainers to
    deal with discrepancies between different implementations of the
    Python Array API standard and its evolution over time.

    The Python Array API standard specification:
    https://data-apis.org/array-api/latest/

    Documentation of the NumPy implementation:
    https://numpy.org/neps/nep-0047-array-api-standard.html
    """

    def __init__(self, array_namespace):
        self._namespace = array_namespace

    def __getattr__(self, name):
        return getattr(self._namespace, name)

    def __eq__(self, other):
        return self._namespace == other._namespace

    def isdtype(self, dtype, kind):
        return isdtype(dtype, kind, xp=self._namespace)


def _check_device_cpu(device):  # noqa
    if device not in {"cpu", None}:
        raise ValueError(f"Unsupported device for NumPy: {device!r}")


def _accept_device_cpu(func):
    @wraps(func)
    def wrapped_func(*args, **kwargs):
        _check_device_cpu(kwargs.pop("device", None))
        return func(*args, **kwargs)

    return wrapped_func


class _NumPyAPIWrapper:
    """Array API compat wrapper for any numpy version

    NumPy < 1.22 does not expose the numpy.array_api namespace. This
    wrapper makes it possible to write code that uses the standard
    Array API while working with any version of NumPy supported by
    scikit-learn.

    See the `get_namespace()` public function for more details.
    """

    # Creation functions in spec:
    # https://data-apis.org/array-api/latest/API_specification/creation_functions.html
    _CREATION_FUNCS = {
        "arange",
        "empty",
        "empty_like",
        "eye",
        "full",
        "full_like",
        "linspace",
        "ones",
        "ones_like",
        "zeros",
        "zeros_like",
    }
    # Data types in spec
    # https://data-apis.org/array-api/latest/API_specification/data_types.html
    _DTYPES = {
        "int8",
        "int16",
        "int32",
        "int64",
        "uint8",
        "uint16",
        "uint32",
        "uint64",
        # XXX: float16 is not part of the Array API spec but exposed by
        # some namespaces.
        "float16",
        "float32",
        "float64",
        "complex64",
        "complex128",
    }

    def __getattr__(self, name):
        attr = getattr(numpy, name)

        # Support device kwargs and make sure they are on the CPU
        if name in self._CREATION_FUNCS:
            return _accept_device_cpu(attr)

        # Convert to dtype objects
        if name in self._DTYPES:
            return numpy.dtype(attr)
        return attr

    @property
    def bool(self):
        return numpy.bool_

    def astype(self, x, dtype, *, copy=True, casting="unsafe"):
        # astype is not defined in the top level NumPy namespace
        return x.astype(dtype, copy=copy, casting=casting)

    def asarray(self, x, *, dtype=None, device=None, copy=None):  # noqa
        _check_device_cpu(device)
        # Support copy in NumPy namespace
        if copy is True:
            return numpy.array(x, copy=True, dtype=dtype)
        else:
            return numpy.asarray(x, dtype=dtype)

    def unique_inverse(self, x):
        return numpy.unique(x, return_inverse=True)

    def unique_counts(self, x):
        return numpy.unique(x, return_counts=True)

    def unique_values(self, x):
        return numpy.unique(x)

    def concat(self, arrays, *, axis=None):
        return numpy.concatenate(arrays, axis=axis)

    def reshape(self, x, shape, *, copy=None):
        """Gives a new shape to an array without changing its data.

        The Array API specification requires shape to be a tuple.
        https://data-apis.org/array-api/latest/API_specification/generated/array_api.reshape.html
        """
        if not isinstance(shape, tuple):
            raise TypeError(
                f"shape must be a tuple, got {shape!r} of type {type(shape)}"
            )

        if copy is True:
            x = x.copy()
        return numpy.reshape(x, shape)

    def isdtype(self, dtype, kind):
        return isdtype(dtype, kind, xp=self)


_NUMPY_API_WRAPPER_INSTANCE = _NumPyAPIWrapper()

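# Usage sketch (illustrative): the singleton wrapper behaves like an Array API
# namespace backed by NumPy, e.g. for dtype introspection and conversions.
#
#   xp = _NUMPY_API_WRAPPER_INSTANCE
#   xp.isdtype(numpy.dtype("float32"), "real floating")      # True
#   xp.isdtype(numpy.dtype("int64"), ("integral", "bool"))   # True
#   x = xp.asarray([1, 2, 3])
#   xp.astype(x, xp.float64)                                 # array([1., 2., 3.])
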

def get_namespace(*arrays):
    """Get namespace of arrays.

    Introspect `arrays` arguments and return their common Array API
    compatible namespace object, if any. NumPy 1.22 and later can
    construct such containers using the `numpy.array_api` namespace
    for instance.

    See: https://numpy.org/neps/nep-0047-array-api-standard.html

    If `arrays` are regular numpy arrays, an instance of the
    `_NumPyAPIWrapper` compatibility wrapper is returned instead.

    Namespace support is not enabled by default. To enable it, call:

        sklearn.set_config(array_api_dispatch=True)

    or:

        with sklearn.config_context(array_api_dispatch=True):
            # your code here

    Otherwise an instance of the `_NumPyAPIWrapper` compatibility wrapper
    is always returned, irrespective of whether the arrays implement the
    `__array_namespace__` protocol or not.

    Parameters
    ----------
    *arrays : array objects
        Array objects.

    Returns
    -------
    namespace : module
        Namespace shared by array objects. If any of the `arrays` are not arrays,
        the namespace defaults to NumPy.

    is_array_api_compliant : bool
        True if the arrays are containers that implement the Array API spec.
        Always False when array_api_dispatch=False.
    """
    array_api_dispatch = get_config()["array_api_dispatch"]
    if not array_api_dispatch:
        return _NUMPY_API_WRAPPER_INSTANCE, False

    _check_array_api_dispatch(array_api_dispatch)

    # array-api-compat is a required dependency of scikit-learn only when
    # configuring `array_api_dispatch=True`. Its import should therefore be
    # protected by _check_array_api_dispatch to display an informative error
    # message in case it is missing.
    import array_api_compat

    namespace, is_array_api_compliant = array_api_compat.get_namespace(*arrays), True

    # These namespaces need additional wrapping to smooth out small differences
    # between implementations
    if namespace.__name__ in {"numpy.array_api", "cupy.array_api"}:
        namespace = _ArrayAPIWrapper(namespace)

    return namespace, is_array_api_compliant

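# Usage sketch (illustrative): without dispatch the NumPy wrapper is returned;
# with dispatch enabled, the namespace of the inputs is resolved through
# array_api_compat (which must then be installed).
#
#   import sklearn
#   X = numpy.asarray([[1.0, 2.0], [3.0, 4.0]])
#   xp, is_compliant = get_namespace(X)
#   # xp is _NUMPY_API_WRAPPER_INSTANCE, is_compliant is False
#   with sklearn.config_context(array_api_dispatch=True):
#       xp, is_compliant = get_namespace(X)
#       # xp is the array_api_compat.numpy namespace, is_compliant is True
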

def _expit(X):
    xp, _ = get_namespace(X)
    if _is_numpy_namespace(xp):
        return xp.asarray(special.expit(numpy.asarray(X)))

    return 1.0 / (1.0 + xp.exp(-X))


def _add_to_diagonal(array, value, xp):
    # Workaround for the lack of support for xp.reshape(a, shape, copy=False) in
    # numpy.array_api: https://github.com/numpy/numpy/issues/23410
    value = xp.asarray(value, dtype=array.dtype)
    if _is_numpy_namespace(xp):
        array_np = numpy.asarray(array)
        array_np.flat[:: array.shape[0] + 1] += value
        return xp.asarray(array_np)
    elif value.ndim == 1:
        for i in range(array.shape[0]):
            array[i, i] += value[i]
    else:
        # scalar value
        for i in range(array.shape[0]):
            array[i, i] += value


def _weighted_sum(sample_score, sample_weight, normalize=False, xp=None):
    # XXX: this function accepts Array API input but returns a Python scalar
    # float. The call to float() is convenient because it removes the need to
    # move back results from device to host memory (e.g. calling `.cpu()` on a
    # torch tensor). However, this might interact in unexpected ways (break?)
    # with lazy Array API implementations. See:
    # https://github.com/data-apis/array-api/issues/642
    if xp is None:
        xp, _ = get_namespace(sample_score)
    if normalize and _is_numpy_namespace(xp):
        sample_score_np = numpy.asarray(sample_score)
        if sample_weight is not None:
            sample_weight_np = numpy.asarray(sample_weight)
        else:
            sample_weight_np = None
        return float(numpy.average(sample_score_np, weights=sample_weight_np))

    if not xp.isdtype(sample_score.dtype, "real floating"):
        # We move to cpu device ahead of time since certain devices may not support
        # float64, but we want the same precision for all devices and namespaces.
        sample_score = xp.astype(xp.asarray(sample_score, device="cpu"), xp.float64)

    if sample_weight is not None:
        sample_weight = xp.asarray(sample_weight, dtype=sample_score.dtype)
        if not xp.isdtype(sample_weight.dtype, "real floating"):
            sample_weight = xp.astype(sample_weight, xp.float64)

    if normalize:
        if sample_weight is not None:
            scale = xp.sum(sample_weight)
        else:
            scale = sample_score.shape[0]
        if scale != 0:
            sample_score = sample_score / scale

    if sample_weight is not None:
        return float(sample_score @ sample_weight)
    else:
        return float(xp.sum(sample_score))

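# Worked example (illustrative) of the weighted sum and weighted average paths:
#
#   scores = numpy.asarray([1.0, 0.0, 1.0])
#   weights = numpy.asarray([2.0, 1.0, 1.0])
#   _weighted_sum(scores, weights)                   # 1*2 + 0*1 + 1*1 = 3.0
#   _weighted_sum(scores, weights, normalize=True)   # 3.0 / 4.0 = 0.75
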

def _nanmin(X, axis=None):
    # TODO: refactor once nan-aware reductions are standardized:
    # https://github.com/data-apis/array-api/issues/621
    xp, _ = get_namespace(X)
    if _is_numpy_namespace(xp):
        return xp.asarray(numpy.nanmin(X, axis=axis))

    else:
        mask = xp.isnan(X)
        X = xp.min(xp.where(mask, xp.asarray(+xp.inf, device=device(X)), X), axis=axis)
        # Replace Infs from all NaN slices with NaN again
        mask = xp.all(mask, axis=axis)
        if xp.any(mask):
            X = xp.where(mask, xp.asarray(xp.nan), X)
        return X


def _nanmax(X, axis=None):
    # TODO: refactor once nan-aware reductions are standardized:
    # https://github.com/data-apis/array-api/issues/621
    xp, _ = get_namespace(X)
    if _is_numpy_namespace(xp):
        return xp.asarray(numpy.nanmax(X, axis=axis))

    else:
        mask = xp.isnan(X)
        X = xp.max(xp.where(mask, xp.asarray(-xp.inf, device=device(X)), X), axis=axis)
        # Replace Infs from all NaN slices with NaN again
        mask = xp.all(mask, axis=axis)
        if xp.any(mask):
            X = xp.where(mask, xp.asarray(xp.nan), X)
        return X

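# Usage sketch (illustrative): NaNs are ignored unless a whole slice is NaN,
# in which case NaN is returned for that slice.
#
#   X = numpy.asarray([[1.0, numpy.nan], [numpy.nan, numpy.nan]])
#   _nanmin(X, axis=1)   # array([ 1., nan])
#   _nanmax(X, axis=1)   # array([ 1., nan])
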

def _asarray_with_order(array, dtype=None, order=None, copy=None, *, xp=None):
    """Helper to support the order kwarg only for NumPy-backed arrays

    Memory layout parameter `order` is not exposed in the Array API standard,
    however some input validation code in scikit-learn needs to work both
    for classes and functions that will leverage Array API only operations
    and for code that inherently relies on NumPy backed data containers with
    specific memory layout constraints (e.g. our own Cython code). The
    purpose of this helper is to make it possible to share code for data
    container validation without memory copies for both downstream use cases:
    the `order` parameter is only enforced if the input array implementation
    is NumPy based, otherwise `order` is just silently ignored.
    """
    if xp is None:
        xp, _ = get_namespace(array)
    if _is_numpy_namespace(xp):
        # Use NumPy API to support order
        if copy is True:
            array = numpy.array(array, order=order, dtype=dtype)
        else:
            array = numpy.asarray(array, order=order, dtype=dtype)

        # At this point array is a NumPy ndarray. We convert it to an array
        # container that is consistent with the input's namespace.
        return xp.asarray(array)
    else:
        return xp.asarray(array, dtype=dtype, copy=copy)


def _convert_to_numpy(array, xp):
    """Convert `array` into a NumPy ndarray on the CPU."""
    xp_name = xp.__name__

    if xp_name in {"array_api_compat.torch", "torch"}:
        return array.cpu().numpy()
    elif xp_name == "cupy.array_api":
        return array._array.get()
    elif xp_name in {"array_api_compat.cupy", "cupy"}:  # pragma: nocover
        return array.get()

    return numpy.asarray(array)

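# Usage sketch (illustrative, assumes torch is installed): device arrays are
# moved to host memory before conversion.
#
#   import torch
#   t = torch.tensor([1.0, 2.0], device="cpu")
#   _convert_to_numpy(t, xp=torch)   # array([1., 2.], dtype=float32)
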

def _estimator_with_converted_arrays(estimator, converter):
    """Create a new estimator, converting all attributes that are arrays.

    The converter is called on all NumPy arrays and arrays that support the
    `DLPack interface <https://dmlc.github.io/dlpack/latest/>`__.

    Parameters
    ----------
    estimator : Estimator
        Estimator to convert.

    converter : callable
        Callable that takes an array attribute and returns the converted array.

    Returns
    -------
    new_estimator : Estimator
        Converted estimator.
    """
    from sklearn.base import clone

    new_estimator = clone(estimator)
    for key, attribute in vars(estimator).items():
        if hasattr(attribute, "__dlpack__") or isinstance(attribute, numpy.ndarray):
            attribute = converter(attribute)
        setattr(new_estimator, key, attribute)
    return new_estimator

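# Usage sketch (illustrative, assumes torch is installed; KMeans is just an
# example estimator): convert the fitted NumPy attributes into torch tensors.
#
#   import torch
#   from sklearn.cluster import KMeans
#
#   X_fit = numpy.asarray([[0.0], [1.0], [10.0], [11.0]])
#   kmeans = KMeans(n_clusters=2, n_init="auto").fit(X_fit)
#   kmeans_torch = _estimator_with_converted_arrays(kmeans, torch.as_tensor)
#   # kmeans_torch.cluster_centers_ is now a torch.Tensor
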

def _atol_for_type(dtype):
    """Return the absolute tolerance for a given dtype."""
    return numpy.finfo(dtype).eps * 100
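
# Worked example (illustrative): for float64, eps is about 2.22e-16, so the
# tolerance is about 2.22e-14; for float32 it is about 1.19e-5.
#
#   _atol_for_type(numpy.float64)   # ~2.22e-14
#   _atol_for_type(numpy.float32)   # ~1.19e-05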