1"""Compatibility fixes for older version of python, numpy and scipy
2
3If you add content to this file, please give the version of the package
4at which the fix is no longer needed.
5"""
6# Authors: Emmanuelle Gouillart <emmanuelle.gouillart@normalesup.org>
7# Gael Varoquaux <gael.varoquaux@normalesup.org>
8# Fabian Pedregosa <fpedregosa@acm.org>
9# Lars Buitinck
10#
11# License: BSD 3 clause
12
13import sys
14from importlib import resources
15
16import numpy as np
17import scipy
18import scipy.sparse.linalg
19import scipy.stats
20import threadpoolctl
21
22import sklearn
23
24from ..externals._packaging.version import parse as parse_version
25from .deprecation import deprecated
26
np_version = parse_version(np.__version__)
np_base_version = parse_version(np_version.base_version)
sp_version = parse_version(scipy.__version__)
sp_base_version = parse_version(sp_version.base_version)

# TODO: We can consider removing the containers and importing
# directly from SciPy when sparse matrices will be deprecated.
CSR_CONTAINERS = [scipy.sparse.csr_matrix]
CSC_CONTAINERS = [scipy.sparse.csc_matrix]
COO_CONTAINERS = [scipy.sparse.coo_matrix]
LIL_CONTAINERS = [scipy.sparse.lil_matrix]
DOK_CONTAINERS = [scipy.sparse.dok_matrix]
BSR_CONTAINERS = [scipy.sparse.bsr_matrix]
DIA_CONTAINERS = [scipy.sparse.dia_matrix]

# Reuse the `sp_version` parsed above instead of re-parsing
# `scipy.__version__` a second time, consistent with the other version
# checks in this module.
if sp_version >= parse_version("1.8"):
    # Sparse Arrays have been added in SciPy 1.8
    # TODO: When SciPy 1.8 is the minimum supported version,
    # those lists can be created directly without this condition.
    # See: https://github.com/scikit-learn/scikit-learn/issues/27090
    CSR_CONTAINERS.append(scipy.sparse.csr_array)
    CSC_CONTAINERS.append(scipy.sparse.csc_array)
    COO_CONTAINERS.append(scipy.sparse.coo_array)
    LIL_CONTAINERS.append(scipy.sparse.lil_array)
    DOK_CONTAINERS.append(scipy.sparse.dok_array)
    BSR_CONTAINERS.append(scipy.sparse.bsr_array)
    DIA_CONTAINERS.append(scipy.sparse.dia_array)
54
55try:
56 from scipy.optimize._linesearch import line_search_wolfe1, line_search_wolfe2
57except ImportError: # SciPy < 1.8
58 from scipy.optimize.linesearch import line_search_wolfe2, line_search_wolfe1 # type: ignore # noqa
59
60
61def _object_dtype_isnan(X):
62 return X != X
63
64
# Rename the `method` kwarg to `interpolation` for NumPy < 1.22, because
# `interpolation` kwarg was deprecated in favor of `method` in NumPy >= 1.22.
def _percentile(a, q, *, method="linear", **kwargs):
    """Wrapper of np.percentile exposing the modern ``method`` kwarg name.

    Only selected when NumPy < 1.22 (see the version check below), where
    the quantile strategy is still chosen via the ``interpolation`` kwarg.
    """
    return np.percentile(a, q, interpolation=method, **kwargs)


if np_version < parse_version("1.22"):
    percentile = _percentile
else:  # >= 1.22
    # Modern NumPy already accepts `method`; use it directly.
    from numpy import percentile  # type: ignore  # noqa
75
76
# compatibility fix for threadpoolctl >= 3.0.0
# since version 3 it's possible to setup a global threadpool controller to avoid
# looping through all loaded shared libraries each time.
# the global controller is created during the first call to threadpoolctl.
def _get_threadpool_controller():
    """Return the process-wide ThreadpoolController, creating it lazily.

    Returns None when the installed threadpoolctl is older than 3.0 and
    therefore does not expose the ``ThreadpoolController`` class; callers
    then fall back to the legacy module-level threadpoolctl functions.
    """
    if not hasattr(threadpoolctl, "ThreadpoolController"):
        return None

    controller = getattr(sklearn, "_sklearn_threadpool_controller", None)
    if controller is None:
        # Cache the controller on the sklearn module so it is created only
        # once and shared across the whole process.
        controller = threadpoolctl.ThreadpoolController()
        sklearn._sklearn_threadpool_controller = controller
    return controller
89
90
def threadpool_limits(limits=None, user_api=None):
    # Dispatch to the cached controller (threadpoolctl >= 3.0) when
    # available, otherwise to the legacy module-level function.
    controller = _get_threadpool_controller()
    if controller is None:
        return threadpoolctl.threadpool_limits(limits=limits, user_api=user_api)
    return controller.limit(limits=limits, user_api=user_api)


threadpool_limits.__doc__ = threadpoolctl.threadpool_limits.__doc__
100
101
def threadpool_info():
    # Dispatch to the cached controller (threadpoolctl >= 3.0) when
    # available, otherwise to the legacy module-level function.
    controller = _get_threadpool_controller()
    if controller is None:
        return threadpoolctl.threadpool_info()
    return controller.info()


threadpool_info.__doc__ = threadpoolctl.threadpool_info.__doc__
111
112
@deprecated(
    "The function `delayed` has been moved from `sklearn.utils.fixes` to "
    "`sklearn.utils.parallel`. This import path will be removed in 1.5."
)
def delayed(function):
    """Deprecated alias for :func:`sklearn.utils.parallel.delayed`."""
    from sklearn.utils.parallel import delayed as _delayed

    return _delayed(function)
121
122
# TODO: Remove when SciPy 1.11 is the minimum supported version
def _mode(a, axis=0):
    """Compute the modal value along ``axis`` across SciPy API changes.

    SciPy >= 1.9 needs ``keepdims`` passed explicitly, and SciPy >= 1.11
    (including 1.11 pre-releases, hence the "1.10.999" bound) changed the
    returned shape for ``axis=None`` with ``keepdims=True``, see
    https://github.com/scipy/scipy/pull/17561 — hence the ravel below.
    """
    if sp_version < parse_version("1.9.0"):
        return scipy.stats.mode(a, axis=axis)
    result = scipy.stats.mode(a, axis=axis, keepdims=True)
    if axis is None and sp_version >= parse_version("1.10.999"):
        result = np.ravel(result)
    return result
134
135
# TODO: Remove when Scipy 1.12 is the minimum supported version
if sp_base_version >= parse_version("1.12.0"):
    _sparse_linalg_cg = scipy.sparse.linalg.cg
else:

    def _sparse_linalg_cg(A, b, **kwargs):
        """Call scipy.sparse.linalg.cg using the SciPy >= 1.12 kwarg names.

        Translates the new ``rtol`` kwarg to the legacy ``tol`` one, and
        forces ``atol="legacy"`` unless the caller provided an explicit
        ``atol``.
        """
        if "rtol" in kwargs:
            kwargs["tol"] = kwargs.pop("rtol")
        kwargs.setdefault("atol", "legacy")
        return scipy.sparse.linalg.cg(A, b, **kwargs)
147
148
# TODO: Fuse the modern implementations of _sparse_min_max and _sparse_nan_min_max
# into the public min_max_axis function when Scipy 1.11 is the minimum supported
# version and delete the backport in the else branch below.
if sp_base_version >= parse_version("1.11.0"):

    def _sparse_min_max(X, axis):
        """Return (min, max) of sparse X, over all entries or along `axis`.

        When `axis` is given, the sparse per-axis results are densified to
        flat 1d ndarrays.
        """
        the_min = X.min(axis=axis)
        the_max = X.max(axis=axis)

        if axis is not None:
            the_min = the_min.toarray().ravel()
            the_max = the_max.toarray().ravel()

        return the_min, the_max

    def _sparse_nan_min_max(X, axis):
        """Same as _sparse_min_max but using nanmin/nanmax to skip NaNs."""
        the_min = X.nanmin(axis=axis)
        the_max = X.nanmax(axis=axis)

        if axis is not None:
            the_min = the_min.toarray().ravel()
            the_max = the_max.toarray().ravel()

        return the_min, the_max

else:
    # This code is mostly taken from scipy 0.14 and extended to handle nans, see
    # https://github.com/scikit-learn/scikit-learn/pull/11196
    def _minor_reduce(X, ufunc):
        """Apply `ufunc` along the minor axis of a compressed sparse `X`.

        Returns the indices of the non-empty major-axis slices and the
        reduced value for each of them (empty slices are skipped).
        """
        # Major-axis slices with at least one stored element.
        major_index = np.flatnonzero(np.diff(X.indptr))

        # reduceat tries casts X.indptr to intp, which errors
        # if it is int64 on a 32 bit system.
        # Reinitializing prevents this where possible, see #13737
        X = type(X)((X.data, X.indices, X.indptr), shape=X.shape)
        value = ufunc.reduceat(X.data, X.indptr[major_index])
        return major_index, value

    def _min_or_max_axis(X, axis, min_or_max):
        """Reduce sparse `X` along `axis` with the binary ufunc `min_or_max`."""
        N = X.shape[axis]
        if N == 0:
            raise ValueError("zero-size array to reduction operation")
        M = X.shape[1 - axis]
        # Convert so that the reduction axis becomes the minor (stored) axis.
        mat = X.tocsc() if axis == 0 else X.tocsr()
        mat.sum_duplicates()
        major_index, value = _minor_reduce(mat, min_or_max)
        # Slices that also contain implicit zeros must include 0 in the
        # reduction.
        not_full = np.diff(mat.indptr)[major_index] < N
        value[not_full] = min_or_max(value[not_full], 0)
        # Keep only nonzero results so they can be stored sparsely; the
        # missing entries of the COO result below default to 0.
        mask = value != 0
        major_index = np.compress(mask, major_index)
        value = np.compress(mask, value)

        if axis == 0:
            res = scipy.sparse.coo_matrix(
                (value, (np.zeros(len(value)), major_index)),
                dtype=X.dtype,
                shape=(1, M),
            )
        else:
            res = scipy.sparse.coo_matrix(
                (value, (major_index, np.zeros(len(value)))),
                dtype=X.dtype,
                shape=(M, 1),
            )
        return res.A.ravel()

    def _sparse_min_or_max(X, axis, min_or_max):
        """Reduce sparse `X` with `min_or_max`, globally or along `axis`."""
        if axis is None:
            if 0 in X.shape:
                raise ValueError("zero-size array to reduction operation")
            zero = X.dtype.type(0)
            if X.nnz == 0:
                # Only implicit zeros.
                return zero
            m = min_or_max.reduce(X.data.ravel())
            if X.nnz != np.prod(X.shape):
                # At least one implicit zero entry takes part in the
                # reduction.
                m = min_or_max(zero, m)
            return m
        if axis < 0:
            # Normalize negative axis the same way ndarray reductions do.
            axis += 2
        if (axis == 0) or (axis == 1):
            return _min_or_max_axis(X, axis, min_or_max)
        else:
            raise ValueError("invalid axis, use 0 for rows, or 1 for columns")

    def _sparse_min_max(X, axis):
        """Return (min, max) of sparse X, over all entries or along `axis`."""
        return (
            _sparse_min_or_max(X, axis, np.minimum),
            _sparse_min_or_max(X, axis, np.maximum),
        )

    def _sparse_nan_min_max(X, axis):
        """Same as _sparse_min_max but using fmin/fmax, which skip NaNs."""
        return (
            _sparse_min_or_max(X, axis, np.fmin),
            _sparse_min_or_max(X, axis, np.fmax),
        )
244
245
246###############################################################################
247# Backport of Python 3.9's importlib.resources
248# TODO: Remove when Python 3.9 is the minimum supported version
249
250
251def _open_text(data_module, data_file_name):
252 if sys.version_info >= (3, 9):
253 return resources.files(data_module).joinpath(data_file_name).open("r")
254 else:
255 return resources.open_text(data_module, data_file_name)
256
257
258def _open_binary(data_module, data_file_name):
259 if sys.version_info >= (3, 9):
260 return resources.files(data_module).joinpath(data_file_name).open("rb")
261 else:
262 return resources.open_binary(data_module, data_file_name)
263
264
265def _read_text(descr_module, descr_file_name):
266 if sys.version_info >= (3, 9):
267 return resources.files(descr_module).joinpath(descr_file_name).read_text()
268 else:
269 return resources.read_text(descr_module, descr_file_name)
270
271
272def _path(data_module, data_file_name):
273 if sys.version_info >= (3, 9):
274 return resources.as_file(resources.files(data_module).joinpath(data_file_name))
275 else:
276 return resources.path(data_module, data_file_name)
277
278
279def _is_resource(data_module, data_file_name):
280 if sys.version_info >= (3, 9):
281 return resources.files(data_module).joinpath(data_file_name).is_file()
282 else:
283 return resources.is_resource(data_module, data_file_name)
284
285
286def _contents(data_module):
287 if sys.version_info >= (3, 9):
288 return (
289 resource.name
290 for resource in resources.files(data_module).iterdir()
291 if resource.is_file()
292 )
293 else:
294 return resources.contents(data_module)
295
296
297# For +1.25 NumPy versions exceptions and warnings are being moved
298# to a dedicated submodule.
299if np_version >= parse_version("1.25.0"):
300 from numpy.exceptions import ComplexWarning, VisibleDeprecationWarning
301else:
302 from numpy import ComplexWarning, VisibleDeprecationWarning # type: ignore # noqa
303
304
305# TODO: Remove when Scipy 1.6 is the minimum supported version
306try:
307 from scipy.integrate import trapezoid # type: ignore # noqa
308except ImportError:
309 from scipy.integrate import trapz as trapezoid # type: ignore # noqa
310
311
# TODO: Remove when Pandas > 2.2 is the minimum supported version
def pd_fillna(pd, frame):
    """Fill missing values of `frame` with np.nan across pandas versions.

    For pandas >= 2.2, opt in to the future "no silent downcasting"
    behavior and re-infer object dtypes; for older pandas a plain fillna
    is enough.
    """
    base_version = parse_version(parse_version(pd.__version__).base_version)
    if base_version >= parse_version("2.2"):
        with pd.option_context("future.no_silent_downcasting", True):
            return frame.fillna(value=np.nan).infer_objects(copy=False)
    return frame.fillna(value=np.nan)
321
322
# TODO: remove when SciPy 1.12 is the minimum supported version
def _preserve_dia_indices_dtype(
    sparse_container, original_container_format, requested_sparse_format
):
    """Preserve indices dtype for SciPy < 1.12 when converting from DIA to CSR/COO.

    For SciPy < 1.12, DIA arrays indices are upcasted to `np.int64` that is
    inconsistent with DIA matrices. We downcast the indices dtype to `np.int32` to
    be consistent with DIA matrices.

    The converted indices arrays are assigned back inplace to the sparse container.

    Parameters
    ----------
    sparse_container : sparse container
        Sparse container to be checked.
    original_container_format : str
        Format of the container before conversion; only `"dia_array"`
        triggers the downcast.
    requested_sparse_format : str or bool
        The type of format of `sparse_container`.

    Notes
    -----
    See https://github.com/scipy/scipy/issues/19245 for more details.
    """
    if original_container_format == "dia_array" and requested_sparse_format in (
        "csr",
        "coo",
    ):
        if requested_sparse_format == "csr":
            # CSR indices are bounded by nnz and the number of columns; the
            # contents check confirms int32 actually suffices.
            index_dtype = _smallest_admissible_index_dtype(
                arrays=(sparse_container.indptr, sparse_container.indices),
                maxval=max(sparse_container.nnz, sparse_container.shape[1]),
                check_contents=True,
            )
            sparse_container.indices = sparse_container.indices.astype(
                index_dtype, copy=False
            )
            sparse_container.indptr = sparse_container.indptr.astype(
                index_dtype, copy=False
            )
        else:  # requested_sparse_format == "coo"
            # COO row/col values are bounded by the shape alone.
            index_dtype = _smallest_admissible_index_dtype(
                maxval=max(sparse_container.shape)
            )
            sparse_container.row = sparse_container.row.astype(index_dtype, copy=False)
            sparse_container.col = sparse_container.col.astype(index_dtype, copy=False)
368
369
370# TODO: remove when SciPy 1.12 is the minimum supported version
371def _smallest_admissible_index_dtype(arrays=(), maxval=None, check_contents=False):
372 """Based on input (integer) arrays `a`, determine a suitable index data
373 type that can hold the data in the arrays.
374
375 This function returns `np.int64` if it either required by `maxval` or based on the
376 largest precision of the dtype of the arrays passed as argument, or by the their
377 contents (when `check_contents is True`). If none of the condition requires
378 `np.int64` then this function returns `np.int32`.
379
380 Parameters
381 ----------
382 arrays : ndarray or tuple of ndarrays, default=()
383 Input arrays whose types/contents to check.
384
385 maxval : float, default=None
386 Maximum value needed.
387
388 check_contents : bool, default=False
389 Whether to check the values in the arrays and not just their types.
390 By default, check only the types.
391
392 Returns
393 -------
394 dtype : {np.int32, np.int64}
395 Suitable index data type (int32 or int64).
396 """
397
398 int32min = np.int32(np.iinfo(np.int32).min)
399 int32max = np.int32(np.iinfo(np.int32).max)
400
401 if maxval is not None:
402 if maxval > np.iinfo(np.int64).max:
403 raise ValueError(
404 f"maxval={maxval} is to large to be represented as np.int64."
405 )
406 if maxval > int32max:
407 return np.int64
408
409 if isinstance(arrays, np.ndarray):
410 arrays = (arrays,)
411
412 for arr in arrays:
413 if not isinstance(arr, np.ndarray):
414 raise TypeError(
415 f"Arrays should be of type np.ndarray, got {type(arr)} instead."
416 )
417 if not np.issubdtype(arr.dtype, np.integer):
418 raise ValueError(
419 f"Array dtype {arr.dtype} is not supported for index dtype. We expect "
420 "integral values."
421 )
422 if not np.can_cast(arr.dtype, np.int32):
423 if not check_contents:
424 # when `check_contents` is False, we stay on the safe side and return
425 # np.int64.
426 return np.int64
427 if arr.size == 0:
428 # a bigger type not needed yet, let's look at the next array
429 continue
430 else:
431 maxval = arr.max()
432 minval = arr.min()
433 if minval < int32min or maxval > int32max:
434 # a big index type is actually needed
435 return np.int64
436
437 return np.int32
438
439
440# TODO: Remove when Scipy 1.12 is the minimum supported version
441if sp_version < parse_version("1.12"):
442 from ..externals._scipy.sparse.csgraph import laplacian # type: ignore # noqa
443else:
444 from scipy.sparse.csgraph import laplacian # type: ignore # noqa # pragma: no cover