1"""Tools to support array_api."""
2import itertools
3import math
4from functools import wraps
5
6import numpy
7import scipy.special as special
8
9from .._config import get_config
10from .fixes import parse_version
11
12
def yield_namespace_device_dtype_combinations():
    """Generate the (namespace, device, dtype) triples used in tests.

    Iterate over this generator to check that an estimator works with every
    supported combination.

    Yields
    ------
    array_namespace : str
        Name of the Array API namespace.

    device : str or None
        Name of the device on which the arrays should be allocated. None
        means that the namespace default should be used.

    dtype : str or None
        Name of the data type of the arrays. None means that the namespace
        default should be used.
    """
    namespaces_with_defaults = (
        # Exercises the array_api_compat wrapper when array_api_dispatch is
        # enabled: the arrays used in the tests are regular numpy arrays
        # without any "device" attribute.
        "numpy",
        # Stricter NumPy-based Array API implementation. The
        # numpy.array_api.Array instances always have a dummy "device"
        # attribute.
        "numpy.array_api",
        "cupy",
        "cupy.array_api",
    )
    for namespace_name in namespaces_with_defaults:
        yield namespace_name, None, None

    # torch is the only namespace for which explicit devices and dtypes are
    # enumerated.
    for torch_device in ("cpu", "cuda"):
        for torch_dtype in ("float64", "float32"):
            yield "torch", torch_device, torch_dtype
    # Only the float32 combination is produced for the "mps" device.
    yield "torch", "mps", "float32"
51
52
def _check_array_api_dispatch(array_api_dispatch):
    """Validate the requirements for Array API dispatching.

    When `array_api_dispatch` is truthy, make sure that array_api_compat can
    be imported and that the installed NumPy is recent enough.

    array_api_compat follows NEP29, which has a higher minimum NumPy version
    than scikit-learn.

    Parameters
    ----------
    array_api_dispatch : bool
        Whether Array API dispatching was requested. Nothing is checked when
        this is falsy.

    Raises
    ------
    ImportError
        If array_api_compat is missing or NumPy is older than 1.21.
    """
    if not array_api_dispatch:
        return

    try:
        import array_api_compat  # noqa
    except ImportError:
        raise ImportError(
            "array_api_compat is required to dispatch arrays using the API"
            " specification"
        )

    min_numpy_version = "1.21"
    installed_version = parse_version(numpy.__version__)
    required_version = parse_version(min_numpy_version)
    if installed_version < required_version:
        raise ImportError(
            f"NumPy must be {min_numpy_version} or newer to dispatch array using"
            " the API specification"
        )
75
76
def device(x):
    """Return the hardware device holding the data of `x`.

    Parameters
    ----------
    x : array
        Array instance from NumPy or an array API compatible library.

    Returns
    -------
    out : device
        The string "cpu" for NumPy arrays and NumPy scalars, otherwise the
        `device` attribute of `x` (see the "Device Support" section of the
        array API spec).
    """
    numpy_containers = (numpy.ndarray, numpy.generic)
    return "cpu" if isinstance(x, numpy_containers) else x.device
93
94
def size(x):
    """Count the elements of `x`.

    Parameters
    ----------
    x : array
        Array instance from NumPy or an array API compatible library.

    Returns
    -------
    out : int
        Product of the entries of `x.shape` (1 for an empty shape).
    """
    dimensions = x.shape
    return math.prod(dimensions)
109
110
def _is_numpy_namespace(xp):
    """Tell whether the namespace `xp` is one of the NumPy-backed ones.

    The decision is based solely on `xp.__name__`.
    """
    numpy_backed_names = {"numpy", "array_api_compat.numpy", "numpy.array_api"}
    return xp.__name__ in numpy_backed_names
114
115
def _union1d(a, b, xp):
    """Return the unique values that appear in `a` or in `b`.

    NumPy-backed namespaces delegate to `numpy.union1d`. Other namespaces
    require both inputs to be 1-dimensional and combine `xp.unique_values`
    with `xp.concat`.
    """
    if not _is_numpy_namespace(xp):
        assert a.ndim == b.ndim == 1
        unique_a = xp.unique_values(a)
        unique_b = xp.unique_values(b)
        return xp.unique_values(xp.concat([unique_a, unique_b]))
    return xp.asarray(numpy.union1d(a, b))
121
122
def isdtype(dtype, kind, *, xp):
    """Tell whether `dtype` is of the requested `kind`.

    `kind` is either a single kind or a tuple of kinds; with a tuple the
    result is True as soon as one of its entries matches.

    Included in the v2022.12 of the Array API spec.
    https://data-apis.org/array-api/latest/API_specification/generated/array_api.isdtype.html
    """
    if not isinstance(kind, tuple):
        return _isdtype_single(dtype, kind, xp=xp)

    for single_kind in kind:
        if _isdtype_single(dtype, single_kind, xp=xp):
            return True
    return False
133
134
def _isdtype_single(dtype, kind, *, xp):
    """Check `dtype` against a single `kind`.

    A string `kind` names a dtype category of the Array API spec; any other
    `kind` is compared to `dtype` for equality.

    Raises
    ------
    ValueError
        If `kind` is a string that is not a recognized category.
    """
    if not isinstance(kind, str):
        return dtype == kind

    if kind == "bool":
        return dtype == xp.bool

    if kind == "signed integer":
        signed_dtypes = {xp.int8, xp.int16, xp.int32, xp.int64}
        return dtype in signed_dtypes

    if kind == "unsigned integer":
        unsigned_dtypes = {xp.uint8, xp.uint16, xp.uint32, xp.uint64}
        return dtype in unsigned_dtypes

    if kind == "integral":
        integral_kinds = ("signed integer", "unsigned integer")
        return any(_isdtype_single(dtype, k, xp=xp) for k in integral_kinds)

    if kind == "real floating":
        return dtype in supported_float_dtypes(xp)

    if kind == "complex floating":
        # Some namespaces do not have complex dtypes, such as cupy.array_api
        # and numpy.array_api, hence the attribute probing.
        complex_dtypes = {
            getattr(xp, dtype_name)
            for dtype_name in ("complex64", "complex128")
            if hasattr(xp, dtype_name)
        }
        return dtype in complex_dtypes

    if kind == "numeric":
        numeric_kinds = ("integral", "real floating", "complex floating")
        return any(_isdtype_single(dtype, k, xp=xp) for k in numeric_kinds)

    raise ValueError(f"Unrecognized data type kind: {kind!r}")
168
169
def supported_float_dtypes(xp):
    """Return the floating point dtypes available in the namespace `xp`.

    The result is `(xp.float64, xp.float32)`, followed by `xp.float16` when
    the namespace defines it.

    Note: float16 is not officially part of the Array API spec at the
    time of writing but scikit-learn estimators and functions can choose
    to accept it when xp.float16 is defined.

    https://data-apis.org/array-api/latest/API_specification/data_types.html
    """
    float_dtypes = (xp.float64, xp.float32)
    if hasattr(xp, "float16"):
        float_dtypes += (xp.float16,)
    return float_dtypes
183
184
class _ArrayAPIWrapper:
    """sklearn specific Array API compatibility wrapper

    This wrapper makes it possible for scikit-learn maintainers to
    deal with discrepancies between different implementations of the
    Python Array API standard and its evolution over time.

    Attribute access that is not resolved on the wrapper itself is forwarded
    to the wrapped namespace.

    The Python Array API standard specification:
    https://data-apis.org/array-api/latest/

    Documentation of the NumPy implementation:
    https://numpy.org/neps/nep-0047-array-api-standard.html

    Parameters
    ----------
    array_namespace : module
        The Array API namespace to wrap.
    """

    def __init__(self, array_namespace):
        self._namespace = array_namespace

    def __getattr__(self, name):
        # __getattr__ only runs when regular lookup fails. If `_namespace`
        # itself is missing (instance created without __init__, e.g. while
        # being copied or unpickled), forwarding would look it up again and
        # recurse without end, so fail with a regular AttributeError instead.
        if name == "_namespace":
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            )
        return getattr(self._namespace, name)

    def __eq__(self, other):
        # Compare the wrapped namespaces. Objects that do not wrap a namespace
        # are simply "not equal" rather than an error.
        try:
            other_namespace = other._namespace
        except AttributeError:
            return NotImplemented
        return self._namespace == other_namespace

    def __hash__(self):
        # Defining __eq__ removes the default __hash__; keep hashing
        # consistent with equality by hashing the wrapped namespace.
        return hash(self._namespace)

    def isdtype(self, dtype, kind):
        """Tell whether `dtype` is of kind `kind` in the wrapped namespace."""
        return isdtype(dtype, kind, xp=self._namespace)
210
211
def _check_device_cpu(device):  # noqa
    """Raise a ValueError unless `device` is "cpu" or None."""
    supported_devices = {"cpu", None}
    if device in supported_devices:
        return
    raise ValueError(f"Unsupported device for NumPy: {device!r}")
215
216
def _accept_device_cpu(func):
    """Decorate `func` so that it tolerates a CPU-only `device` keyword.

    The returned callable removes `device` from the keyword arguments,
    validates it with `_check_device_cpu` and then calls `func` with the
    remaining arguments.
    """

    @wraps(func)
    def cpu_only_func(*args, **kwargs):
        requested_device = kwargs.pop("device", None)
        _check_device_cpu(requested_device)
        return func(*args, **kwargs)

    return cpu_only_func
224
225
class _NumPyAPIWrapper:
    """Array API compat wrapper for any numpy version

    NumPy < 1.22 does not expose the numpy.array_api namespace. This
    wrapper makes it possible to write code that uses the standard
    Array API while working with any version of NumPy supported by
    scikit-learn.

    Names that are not defined on the wrapper are looked up in the top level
    `numpy` namespace.

    See the `get_namespace()` public function for more details.
    """

    # Creation functions in spec:
    # https://data-apis.org/array-api/latest/API_specification/creation_functions.html
    _CREATION_FUNCS = {
        "arange",
        "empty",
        "empty_like",
        "eye",
        "full",
        "full_like",
        "linspace",
        "ones",
        "ones_like",
        "zeros",
        "zeros_like",
    }
    # Data types in spec
    # https://data-apis.org/array-api/latest/API_specification/data_types.html
    _DTYPES = {
        "int8",
        "int16",
        "int32",
        "int64",
        "uint8",
        "uint16",
        "uint32",
        "uint64",
        # XXX: float16 is not part of the Array API spec but exposed by
        # some namespaces.
        "float16",
        "float32",
        "float64",
        "complex64",
        "complex128",
    }

    def __getattr__(self, name):
        numpy_attr = getattr(numpy, name)

        if name in self._CREATION_FUNCS:
            # Creation functions accept a `device` keyword that must
            # designate the CPU.
            return _accept_device_cpu(numpy_attr)

        if name in self._DTYPES:
            # Expose dtype objects rather than NumPy scalar types.
            return numpy.dtype(numpy_attr)

        return numpy_attr

    @property
    def bool(self):
        return numpy.bool_

    def astype(self, x, dtype, *, copy=True, casting="unsafe"):
        """Cast `x` to `dtype` through the array's own `astype` method."""
        # astype is not defined in the top level NumPy namespace
        return x.astype(dtype, copy=copy, casting=casting)

    def asarray(self, x, *, dtype=None, device=None, copy=None):  # noqa
        """Convert `x` to a NumPy array after validating `device`."""
        _check_device_cpu(device)
        # Only an explicit copy=True forces a copy; any other value falls
        # back to numpy.asarray.
        if copy is not True:
            return numpy.asarray(x, dtype=dtype)
        return numpy.array(x, copy=True, dtype=dtype)

    def unique_inverse(self, x):
        """Return the result of `numpy.unique(x, return_inverse=True)`."""
        values_and_inverse = numpy.unique(x, return_inverse=True)
        return values_and_inverse

    def unique_counts(self, x):
        """Return the result of `numpy.unique(x, return_counts=True)`."""
        values_and_counts = numpy.unique(x, return_counts=True)
        return values_and_counts

    def unique_values(self, x):
        """Return the unique values of `x` as given by `numpy.unique`."""
        values = numpy.unique(x)
        return values

    def concat(self, arrays, *, axis=None):
        """Join `arrays` with `numpy.concatenate` along `axis`."""
        return numpy.concatenate(arrays, axis=axis)

    def reshape(self, x, shape, *, copy=None):
        """Gives a new shape to an array without changing its data.

        The Array API specification requires shape to be a tuple.
        https://data-apis.org/array-api/latest/API_specification/generated/array_api.reshape.html

        Raises
        ------
        TypeError
            If `shape` is not a tuple.
        """
        if not isinstance(shape, tuple):
            raise TypeError(
                f"shape must be a tuple, got {shape!r} of type {type(shape)}"
            )

        # Only an explicit copy=True triggers a copy before reshaping.
        source = x.copy() if copy is True else x
        return numpy.reshape(source, shape)

    def isdtype(self, dtype, kind):
        """Tell whether `dtype` is of kind `kind` for this namespace."""
        return isdtype(dtype, kind, xp=self)
329
330
# Module-level wrapper instance created once at import time; `get_namespace`
# returns it when Array API dispatch is disabled.
_NUMPY_API_WRAPPER_INSTANCE = _NumPyAPIWrapper()
332
333
def get_namespace(*arrays):
    """Get namespace of arrays.

    Introspect the `arrays` arguments and return their common Array API
    compatible namespace object. NumPy 1.22 and later can construct such
    containers using the `numpy.array_api` namespace for instance.

    See: https://numpy.org/neps/nep-0047-array-api-standard.html

    Namespace support is not enabled by default. To enable it call:

      sklearn.set_config(array_api_dispatch=True)

    or:

      with sklearn.config_context(array_api_dispatch=True):
          # your code here

    Otherwise an instance of the `_NumPyAPIWrapper` compatibility wrapper is
    always returned, irrespective of the fact that arrays implement the
    `__array_namespace__` protocol or not.

    When dispatch is enabled the lookup is delegated to
    `array_api_compat.get_namespace`, and the `numpy.array_api` and
    `cupy.array_api` namespaces are additionally wrapped in
    `_ArrayAPIWrapper`.

    Parameters
    ----------
    *arrays : array objects
        Array objects.

    Returns
    -------
    namespace : module
        Namespace shared by array objects.

    is_array_api_compliant : bool
        True when the namespace was resolved through array_api_compat.
        Always False when array_api_dispatch=False.
    """
    dispatch_enabled = get_config()["array_api_dispatch"]
    if not dispatch_enabled:
        return _NUMPY_API_WRAPPER_INSTANCE, False

    _check_array_api_dispatch(dispatch_enabled)

    # array-api-compat is a required dependency of scikit-learn only when
    # configuring `array_api_dispatch=True`. Its import should therefore be
    # protected by _check_array_api_dispatch to display an informative error
    # message in case it is missing.
    import array_api_compat

    namespace = array_api_compat.get_namespace(*arrays)

    # These namespaces need additional wrapping to smooth out small differences
    # between implementations
    needs_wrapping = namespace.__name__ in {"numpy.array_api", "cupy.array_api"}
    if needs_wrapping:
        namespace = _ArrayAPIWrapper(namespace)

    return namespace, True
397
398
def _expit(X):
    """Apply the logistic function `1 / (1 + exp(-X))` to `X`.

    NumPy-backed namespaces use `scipy.special.expit`; other namespaces
    evaluate the formula with `xp.exp`.
    """
    xp, _ = get_namespace(X)
    if not _is_numpy_namespace(xp):
        return 1.0 / (1.0 + xp.exp(-X))

    X_np = numpy.asarray(X)
    return xp.asarray(special.expit(X_np))
405
406
def _add_to_diagonal(array, value, xp):
    """Add `value` to the diagonal entries of `array`.

    Parameters
    ----------
    array : array
        2D array whose diagonal is updated.
        NOTE(review): both code paths iterate over `array.shape[0]` diagonal
        entries, which presumably assumes a square array - confirm with
        callers.

    value : scalar or array
        Either a scalar added to every diagonal entry, or a 1D array whose
        i-th entry is added to `array[i, i]`. It is converted to the dtype
        of `array`.

    xp : module
        Array namespace of `array`.

    Returns
    -------
    array : array
        The updated array, for every namespace. Non-NumPy namespaces are
        updated in place and get the same `array` object back.
    """
    # Workaround for the lack of support for xp.reshape(a, shape, copy=False) in
    # numpy.array_api: https://github.com/numpy/numpy/issues/23410
    value = xp.asarray(value, dtype=array.dtype)
    if _is_numpy_namespace(xp):
        array_np = numpy.asarray(array)
        # Step through the flattened array to land on the diagonal entries.
        array_np.flat[:: array.shape[0] + 1] += value
        return xp.asarray(array_np)

    n_rows = array.shape[0]
    if value.ndim == 1:
        for i in range(n_rows):
            array[i, i] += value[i]
    else:
        # scalar value
        for i in range(n_rows):
            array[i, i] += value
    # Return the array here as well so that every namespace behaves the same
    # for callers that use the return value.
    return array
422
423
def _weighted_sum(sample_score, sample_weight, normalize=False, xp=None):
    """Return the (optionally weighted and normalized) sum of `sample_score`.

    Parameters
    ----------
    sample_score : array
        Per-sample scores.

    sample_weight : array or None
        Per-sample weights. None means that the scores are not weighted.

    normalize : bool, default=False
        If True, the scores are divided by the sum of the weights (or by the
        number of samples when there are no weights) unless that scale is 0.

    xp : module, default=None
        Array namespace of `sample_score`. Inferred with `get_namespace` when
        None.

    Returns
    -------
    out : float
        Python scalar float.
    """
    # XXX: this function accepts Array API input but returns a Python scalar
    # float. The call to float() is convenient because it removes the need to
    # move back results from device to host memory (e.g. calling `.cpu()` on a
    # torch tensor). However, this might interact in unexpected ways (break?)
    # with lazy Array API implementations. See:
    # https://github.com/data-apis/array-api/issues/642
    if xp is None:
        xp, _ = get_namespace(sample_score)

    has_weights = sample_weight is not None

    if normalize and _is_numpy_namespace(xp):
        scores_np = numpy.asarray(sample_score)
        weights_np = numpy.asarray(sample_weight) if has_weights else None
        return float(numpy.average(scores_np, weights=weights_np))

    if not xp.isdtype(sample_score.dtype, "real floating"):
        # We move to cpu device ahead of time since certain devices may not support
        # float64, but we want the same precision for all devices and namespaces.
        cpu_scores = xp.asarray(sample_score, device="cpu")
        sample_score = xp.astype(cpu_scores, xp.float64)

    if has_weights:
        sample_weight = xp.asarray(sample_weight, dtype=sample_score.dtype)
        if not xp.isdtype(sample_weight.dtype, "real floating"):
            sample_weight = xp.astype(sample_weight, xp.float64)

    if normalize:
        scale = xp.sum(sample_weight) if has_weights else sample_score.shape[0]
        if scale != 0:
            sample_score = sample_score / scale

    if has_weights:
        return float(sample_score @ sample_weight)
    return float(xp.sum(sample_score))
463
464
def _nanmin(X, axis=None):
    """Return the minimum of `X` along `axis`, ignoring NaN values.

    NumPy-backed namespaces delegate to `numpy.nanmin`. For other namespaces
    NaN entries are replaced by +inf before calling `xp.min`, and slices made
    only of NaN values yield NaN.

    Parameters
    ----------
    X : array
        Input array.

    axis : int or None, default=None
        Axis along which to reduce; forwarded to the reduction functions.

    Returns
    -------
    out : array
        The NaN-ignoring minimum.
    """
    # TODO: refactor once nan-aware reductions are standardized:
    # https://github.com/data-apis/array-api/issues/621
    xp, _ = get_namespace(X)
    if _is_numpy_namespace(xp):
        return xp.asarray(numpy.nanmin(X, axis=axis))

    # Both fill values are created on the device of X so that xp.where never
    # has to mix arrays living on different devices.
    x_device = device(X)
    mask = xp.isnan(X)
    X = xp.min(xp.where(mask, xp.asarray(+xp.inf, device=x_device), X), axis=axis)
    # Replace Infs from all NaN slices with NaN again
    mask = xp.all(mask, axis=axis)
    if xp.any(mask):
        X = xp.where(mask, xp.asarray(xp.nan, device=x_device), X)
    return X
480
481
def _nanmax(X, axis=None):
    """Return the maximum of `X` along `axis`, ignoring NaN values.

    NumPy-backed namespaces delegate to `numpy.nanmax`. For other namespaces
    NaN entries are replaced by -inf before calling `xp.max`, and slices made
    only of NaN values yield NaN.

    Parameters
    ----------
    X : array
        Input array.

    axis : int or None, default=None
        Axis along which to reduce; forwarded to the reduction functions.

    Returns
    -------
    out : array
        The NaN-ignoring maximum.
    """
    # TODO: refactor once nan-aware reductions are standardized:
    # https://github.com/data-apis/array-api/issues/621
    xp, _ = get_namespace(X)
    if _is_numpy_namespace(xp):
        return xp.asarray(numpy.nanmax(X, axis=axis))

    # Both fill values are created on the device of X so that xp.where never
    # has to mix arrays living on different devices.
    x_device = device(X)
    mask = xp.isnan(X)
    X = xp.max(xp.where(mask, xp.asarray(-xp.inf, device=x_device), X), axis=axis)
    # Replace Infs from all NaN slices with NaN again
    mask = xp.all(mask, axis=axis)
    if xp.any(mask):
        X = xp.where(mask, xp.asarray(xp.nan, device=x_device), X)
    return X
497
498
def _asarray_with_order(array, dtype=None, order=None, copy=None, *, xp=None):
    """Helper to support the order kwarg only for NumPy-backed arrays

    Memory layout parameter `order` is not exposed in the Array API standard,
    however some input validation code in scikit-learn needs to work both
    for classes and functions that will leverage Array API only operations
    and for code that inherently relies on NumPy backed data containers with
    specific memory layout constraints (e.g. our own Cython code). The
    purpose of this helper is to make it possible to share code for data
    container validation without memory copies for both downstream use cases:
    the `order` parameter is only enforced if the input array implementation
    is NumPy based, otherwise `order` is just silently ignored.

    When `xp` is None the namespace is inferred from `array` with
    `get_namespace`.
    """
    if xp is None:
        xp, _ = get_namespace(array)

    if not _is_numpy_namespace(xp):
        return xp.asarray(array, dtype=dtype, copy=copy)

    # Use NumPy API to support order: numpy.array when a copy is explicitly
    # requested, numpy.asarray otherwise.
    to_ndarray = numpy.array if copy is True else numpy.asarray
    array_np = to_ndarray(array, order=order, dtype=dtype)

    # At this point array_np is a NumPy ndarray. We convert it to an array
    # container that is consistent with the input's namespace.
    return xp.asarray(array_np)
526
527
def _convert_to_numpy(array, xp):
    """Convert `array` into a NumPy ndarray on the CPU.

    The conversion is selected from `xp.__name__`: torch, cupy.array_api and
    cupy arrays use their own export methods, anything else goes through
    `numpy.asarray`.
    """
    namespace_name = xp.__name__

    if namespace_name in {"array_api_compat.torch", "torch"}:
        return array.cpu().numpy()
    if namespace_name == "cupy.array_api":
        return array._array.get()
    if namespace_name in {"array_api_compat.cupy", "cupy"}:  # pragma: nocover
        return array.get()
    return numpy.asarray(array)
540
541
def _estimator_with_converted_arrays(estimator, converter):
    """Create a new estimator with all array attributes converted.

    The converter is called on all NumPy arrays and arrays that support the
    `DLPack interface <https://dmlc.github.io/dlpack/latest/>`__. Every other
    attribute is set on the new estimator unchanged.

    Parameters
    ----------
    estimator : Estimator
        Estimator to convert.

    converter : callable
        Callable that takes an array attribute and returns the converted array.

    Returns
    -------
    new_estimator : Estimator
        Clone of `estimator` carrying the converted attributes.
    """
    from sklearn.base import clone

    new_estimator = clone(estimator)
    for name, value in vars(estimator).items():
        is_array_like = hasattr(value, "__dlpack__") or isinstance(
            value, numpy.ndarray
        )
        setattr(new_estimator, name, converter(value) if is_array_like else value)
    return new_estimator
569
570
def _atol_for_type(dtype):
    """Return the absolute tolerance for a given dtype.

    The tolerance is 100 times the `eps` reported by `numpy.finfo(dtype)`.
    """
    machine_eps = numpy.finfo(dtype).eps
    return machine_eps * 100