Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/scikit_learn-1.4.dev0-py3.8-linux-x86_64.egg/sklearn/utils/fixes.py: 32%

209 statements  

coverage.py v7.3.2, created at 2023-12-12 06:31 +0000

1"""Compatibility fixes for older version of python, numpy and scipy 

2 

3If you add content to this file, please give the version of the package 

4at which the fix is no longer needed. 

5""" 

6# Authors: Emmanuelle Gouillart <emmanuelle.gouillart@normalesup.org> 

7# Gael Varoquaux <gael.varoquaux@normalesup.org> 

8# Fabian Pedregosa <fpedregosa@acm.org> 

9# Lars Buitinck 

10# 

11# License: BSD 3 clause 

12 

13import sys 

14from importlib import resources 

15 

16import numpy as np 

17import scipy 

18import scipy.sparse.linalg 

19import scipy.stats 

20import threadpoolctl 

21 

22import sklearn 

23 

24from ..externals._packaging.version import parse as parse_version 

25from .deprecation import deprecated 

26 

27np_version = parse_version(np.__version__) 

28np_base_version = parse_version(np_version.base_version) 

29sp_version = parse_version(scipy.__version__) 

30sp_base_version = parse_version(sp_version.base_version) 

31 

32# TODO: We can consider removing the containers and importing 

33# directly from SciPy when sparse matrices will be deprecated. 

34CSR_CONTAINERS = [scipy.sparse.csr_matrix] 

35CSC_CONTAINERS = [scipy.sparse.csc_matrix] 

36COO_CONTAINERS = [scipy.sparse.coo_matrix] 

37LIL_CONTAINERS = [scipy.sparse.lil_matrix] 

38DOK_CONTAINERS = [scipy.sparse.dok_matrix] 

39BSR_CONTAINERS = [scipy.sparse.bsr_matrix] 

40DIA_CONTAINERS = [scipy.sparse.dia_matrix] 

41 

42if parse_version(scipy.__version__) >= parse_version("1.8"): 

43 # Sparse Arrays have been added in SciPy 1.8 

44 # TODO: When SciPy 1.8 is the minimum supported version, 

45 # those list can be created directly without this condition. 

46 # See: https://github.com/scikit-learn/scikit-learn/issues/27090 

47 CSR_CONTAINERS.append(scipy.sparse.csr_array) 

48 CSC_CONTAINERS.append(scipy.sparse.csc_array) 

49 COO_CONTAINERS.append(scipy.sparse.coo_array) 

50 LIL_CONTAINERS.append(scipy.sparse.lil_array) 

51 DOK_CONTAINERS.append(scipy.sparse.dok_array) 

52 BSR_CONTAINERS.append(scipy.sparse.bsr_array) 

53 DIA_CONTAINERS.append(scipy.sparse.dia_array) 
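

# Illustrative sketch (not part of the original module): the container lists
# above let callers parametrize over both the spmatrix classes and, when
# available, the sparse-array classes. `_demo_csr_containers` is hypothetical.
def _demo_csr_containers():
    data = np.eye(2)
    # Build the same data with every available CSR container class.
    return [csr_container(data) for csr_container in CSR_CONTAINERS]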


try:
    from scipy.optimize._linesearch import line_search_wolfe1, line_search_wolfe2
except ImportError:  # SciPy < 1.8
    from scipy.optimize.linesearch import line_search_wolfe2, line_search_wolfe1  # type: ignore # noqa


def _object_dtype_isnan(X):
    # Relies on NaN being the only value that does not compare equal to itself.
    return X != X
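

# Illustrative sketch (not part of the original module): NaN detection for
# object-dtype arrays, where `np.isnan` would raise a TypeError.
def _demo_object_dtype_isnan():
    X = np.array([["a", np.nan], [np.nan, "b"]], dtype=object)
    return _object_dtype_isnan(X)  # array([[False, True], [True, False]])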


# Rename the `method` kwarg to `interpolation` for NumPy < 1.22, because the
# `interpolation` kwarg was deprecated in favor of `method` in NumPy >= 1.22.
def _percentile(a, q, *, method="linear", **kwargs):
    return np.percentile(a, q, interpolation=method, **kwargs)


if np_version < parse_version("1.22"):
    percentile = _percentile
else:  # >= 1.22
    from numpy import percentile  # type: ignore # noqa
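

# Illustrative sketch (not part of the original module): callers can use the
# modern `method` keyword regardless of the installed NumPy version.
def _demo_percentile():
    return percentile(np.arange(10), 50, method="nearest")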


# Compatibility fix for threadpoolctl >= 3.0.0.
# Since version 3 it is possible to set up a global threadpool controller to
# avoid looping through all loaded shared libraries each time.
# The global controller is created during the first call to threadpoolctl.
def _get_threadpool_controller():
    if not hasattr(threadpoolctl, "ThreadpoolController"):
        return None

    if not hasattr(sklearn, "_sklearn_threadpool_controller"):
        sklearn._sklearn_threadpool_controller = threadpoolctl.ThreadpoolController()

    return sklearn._sklearn_threadpool_controller


def threadpool_limits(limits=None, user_api=None):
    controller = _get_threadpool_controller()
    if controller is not None:
        return controller.limit(limits=limits, user_api=user_api)
    else:
        return threadpoolctl.threadpool_limits(limits=limits, user_api=user_api)


threadpool_limits.__doc__ = threadpoolctl.threadpool_limits.__doc__


def threadpool_info():
    controller = _get_threadpool_controller()
    if controller is not None:
        return controller.info()
    else:
        return threadpoolctl.threadpool_info()


threadpool_info.__doc__ = threadpoolctl.threadpool_info.__doc__
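

# Illustrative sketch (not part of the original module): both code paths
# support the same context-manager usage for capping BLAS/OpenMP threads.
def _demo_threadpool_limits():
    with threadpool_limits(limits=1, user_api="blas"):
        # BLAS calls issued here are limited to a single thread.
        np.dot(np.ones((4, 4)), np.ones((4, 4)))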


@deprecated(
    "The function `delayed` has been moved from `sklearn.utils.fixes` to "
    "`sklearn.utils.parallel`. This import path will be removed in 1.5."
)
def delayed(function):
    from sklearn.utils.parallel import delayed

    return delayed(function)


# TODO: Remove when SciPy 1.11 is the minimum supported version
def _mode(a, axis=0):
    if sp_version >= parse_version("1.9.0"):
        mode = scipy.stats.mode(a, axis=axis, keepdims=True)
        if sp_version >= parse_version("1.10.999"):
            # scipy.stats.mode changed the returned array shape with axis=None
            # and keepdims=True, see https://github.com/scipy/scipy/pull/17561
            if axis is None:
                mode = np.ravel(mode)
        return mode
    return scipy.stats.mode(a, axis=axis)
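

# Illustrative sketch (not part of the original module): `_mode` preserves the
# legacy keepdims behaviour whichever SciPy version is installed.
def _demo_mode():
    a = np.array([[1, 2, 2], [3, 3, 4]])
    return _mode(a, axis=0)  # mode/count arrays with the legacy shape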


# TODO: Remove when SciPy 1.12 is the minimum supported version
if sp_base_version >= parse_version("1.12.0"):
    _sparse_linalg_cg = scipy.sparse.linalg.cg
else:

    def _sparse_linalg_cg(A, b, **kwargs):
        if "rtol" in kwargs:
            kwargs["tol"] = kwargs.pop("rtol")
        if "atol" not in kwargs:
            kwargs["atol"] = "legacy"
        return scipy.sparse.linalg.cg(A, b, **kwargs)
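

# Illustrative sketch (not part of the original module): the wrapper accepts
# the modern `rtol` keyword and translates it to `tol` for older SciPy.
def _demo_sparse_cg():
    A = scipy.sparse.identity(3, format="csr")
    b = np.ones(3)
    x, info = _sparse_linalg_cg(A, b, rtol=1e-8)
    return x, info  # info == 0 on successful convergence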


# TODO: Fuse the modern implementations of _sparse_min_max and _sparse_nan_min_max
# into the public min_max_axis function when SciPy 1.11 is the minimum supported
# version and delete the backport in the else branch below.
if sp_base_version >= parse_version("1.11.0"):

    def _sparse_min_max(X, axis):
        the_min = X.min(axis=axis)
        the_max = X.max(axis=axis)

        if axis is not None:
            the_min = the_min.toarray().ravel()
            the_max = the_max.toarray().ravel()

        return the_min, the_max

    def _sparse_nan_min_max(X, axis):
        the_min = X.nanmin(axis=axis)
        the_max = X.nanmax(axis=axis)

        if axis is not None:
            the_min = the_min.toarray().ravel()
            the_max = the_max.toarray().ravel()

        return the_min, the_max

else:
    # This code is mostly taken from scipy 0.14 and extended to handle nans, see
    # https://github.com/scikit-learn/scikit-learn/pull/11196
    def _minor_reduce(X, ufunc):
        major_index = np.flatnonzero(np.diff(X.indptr))

        # reduceat tries to cast X.indptr to intp, which errors
        # if it is int64 on a 32 bit system.
        # Reinitializing prevents this where possible, see #13737
        X = type(X)((X.data, X.indices, X.indptr), shape=X.shape)
        value = ufunc.reduceat(X.data, X.indptr[major_index])
        return major_index, value

    def _min_or_max_axis(X, axis, min_or_max):
        N = X.shape[axis]
        if N == 0:
            raise ValueError("zero-size array to reduction operation")
        M = X.shape[1 - axis]
        mat = X.tocsc() if axis == 0 else X.tocsr()
        mat.sum_duplicates()
        major_index, value = _minor_reduce(mat, min_or_max)
        not_full = np.diff(mat.indptr)[major_index] < N
        value[not_full] = min_or_max(value[not_full], 0)
        mask = value != 0
        major_index = np.compress(mask, major_index)
        value = np.compress(mask, value)

        if axis == 0:
            res = scipy.sparse.coo_matrix(
                (value, (np.zeros(len(value)), major_index)),
                dtype=X.dtype,
                shape=(1, M),
            )
        else:
            res = scipy.sparse.coo_matrix(
                (value, (major_index, np.zeros(len(value)))),
                dtype=X.dtype,
                shape=(M, 1),
            )
        return res.A.ravel()

    def _sparse_min_or_max(X, axis, min_or_max):
        if axis is None:
            if 0 in X.shape:
                raise ValueError("zero-size array to reduction operation")
            zero = X.dtype.type(0)
            if X.nnz == 0:
                return zero
            m = min_or_max.reduce(X.data.ravel())
            if X.nnz != np.prod(X.shape):
                m = min_or_max(zero, m)
            return m
        if axis < 0:
            axis += 2
        if (axis == 0) or (axis == 1):
            return _min_or_max_axis(X, axis, min_or_max)
        else:
            raise ValueError("invalid axis, use 0 for rows, or 1 for columns")

    def _sparse_min_max(X, axis):
        return (
            _sparse_min_or_max(X, axis, np.minimum),
            _sparse_min_or_max(X, axis, np.maximum),
        )

    def _sparse_nan_min_max(X, axis):
        return (
            _sparse_min_or_max(X, axis, np.fmin),
            _sparse_min_or_max(X, axis, np.fmax),
        )
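

# Illustrative sketch (not part of the original module): both branches expose
# the same interface, returning per-axis minima and maxima as dense 1d arrays.
def _demo_sparse_min_max():
    X = scipy.sparse.csr_matrix(np.array([[0.0, 2.0], [-1.0, 0.0]]))
    the_min, the_max = _sparse_min_max(X, axis=0)
    return the_min, the_max  # array([-1., 0.]), array([0., 2.])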


###############################################################################
# Backport of Python 3.9's importlib.resources
# TODO: Remove when Python 3.9 is the minimum supported version


def _open_text(data_module, data_file_name):
    if sys.version_info >= (3, 9):
        return resources.files(data_module).joinpath(data_file_name).open("r")
    else:
        return resources.open_text(data_module, data_file_name)


def _open_binary(data_module, data_file_name):
    if sys.version_info >= (3, 9):
        return resources.files(data_module).joinpath(data_file_name).open("rb")
    else:
        return resources.open_binary(data_module, data_file_name)


def _read_text(descr_module, descr_file_name):
    if sys.version_info >= (3, 9):
        return resources.files(descr_module).joinpath(descr_file_name).read_text()
    else:
        return resources.read_text(descr_module, descr_file_name)


def _path(data_module, data_file_name):
    if sys.version_info >= (3, 9):
        return resources.as_file(resources.files(data_module).joinpath(data_file_name))
    else:
        return resources.path(data_module, data_file_name)


def _is_resource(data_module, data_file_name):
    if sys.version_info >= (3, 9):
        return resources.files(data_module).joinpath(data_file_name).is_file()
    else:
        return resources.is_resource(data_module, data_file_name)


def _contents(data_module):
    if sys.version_info >= (3, 9):
        return (
            resource.name
            for resource in resources.files(data_module).iterdir()
            if resource.is_file()
        )
    else:
        return resources.contents(data_module)
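

# Illustrative sketch (not part of the original module): the helpers hide the
# split between the pre-3.9 `resources.*` functions and the `files()`
# traversable API. Module and file names below are hypothetical placeholders.
def _demo_read_text():
    return _read_text("sklearn.datasets.descr", "iris.rst")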


# For NumPy >= 1.25, exceptions and warnings were moved
# to a dedicated submodule.
if np_version >= parse_version("1.25.0"):
    from numpy.exceptions import ComplexWarning, VisibleDeprecationWarning
else:
    from numpy import ComplexWarning, VisibleDeprecationWarning  # type: ignore # noqa


# TODO: Remove when SciPy 1.6 is the minimum supported version
try:
    from scipy.integrate import trapezoid  # type: ignore # noqa
except ImportError:
    from scipy.integrate import trapz as trapezoid  # type: ignore # noqa
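

# Illustrative sketch (not part of the original module): `trapezoid` integrates
# sampled values with the trapezoidal rule under either SciPy spelling.
def _demo_trapezoid():
    x = np.linspace(0.0, 1.0, 101)
    return trapezoid(x**2, x)  # approximately 1/3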


# TODO: Remove when Pandas > 2.2 is the minimum supported version
def pd_fillna(pd, frame):
    pd_version = parse_version(pd.__version__).base_version
    if parse_version(pd_version) < parse_version("2.2"):
        frame = frame.fillna(value=np.nan)
    else:
        with pd.option_context("future.no_silent_downcasting", True):
            frame = frame.fillna(value=np.nan).infer_objects(copy=False)
    return frame
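

# Illustrative sketch (not part of the original module): normalize missing
# values to np.nan without triggering pandas' silent-downcasting FutureWarning.
def _demo_pd_fillna():
    import pandas as pd  # assumed to be installed for this sketch

    frame = pd.DataFrame({"a": [1.0, None, 3.0]}, dtype=object)
    # Missing entries become np.nan; on pandas >= 2.2 the dtype is inferred
    # explicitly instead of being downcast silently.
    return pd_fillna(pd, frame)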


# TODO: remove when SciPy 1.12 is the minimum supported version
def _preserve_dia_indices_dtype(
    sparse_container, original_container_format, requested_sparse_format
):
    """Preserve indices dtype for SciPy < 1.12 when converting from DIA to CSR/COO.

    For SciPy < 1.12, the indices of DIA arrays are upcast to `np.int64`, which
    is inconsistent with DIA matrices. We downcast the indices dtype to
    `np.int32` to be consistent with DIA matrices.

    The converted index arrays are assigned back in place to the sparse
    container.

    Parameters
    ----------
    sparse_container : sparse container
        Sparse container to be checked.

    original_container_format : str
        The format of the container before conversion.

    requested_sparse_format : str
        The sparse format to which `sparse_container` was converted.

    Notes
    -----
    See https://github.com/scipy/scipy/issues/19245 for more details.
    """
    if original_container_format == "dia_array" and requested_sparse_format in (
        "csr",
        "coo",
    ):
        if requested_sparse_format == "csr":
            index_dtype = _smallest_admissible_index_dtype(
                arrays=(sparse_container.indptr, sparse_container.indices),
                maxval=max(sparse_container.nnz, sparse_container.shape[1]),
                check_contents=True,
            )
            sparse_container.indices = sparse_container.indices.astype(
                index_dtype, copy=False
            )
            sparse_container.indptr = sparse_container.indptr.astype(
                index_dtype, copy=False
            )
        else:  # requested_sparse_format == "coo"
            index_dtype = _smallest_admissible_index_dtype(
                maxval=max(sparse_container.shape)
            )
            sparse_container.row = sparse_container.row.astype(index_dtype, copy=False)
            sparse_container.col = sparse_container.col.astype(index_dtype, copy=False)


# TODO: remove when SciPy 1.12 is the minimum supported version
def _smallest_admissible_index_dtype(arrays=(), maxval=None, check_contents=False):
    """Based on the input (integer) arrays `arrays`, determine a suitable
    index data type that can hold the data in the arrays.

    This function returns `np.int64` if it is required by `maxval`, by the
    precision of the dtypes of the arrays passed as arguments, or by their
    contents (when `check_contents is True`). If none of these conditions
    requires `np.int64`, this function returns `np.int32`.

    Parameters
    ----------
    arrays : ndarray or tuple of ndarrays, default=()
        Input arrays whose types/contents to check.

    maxval : float, default=None
        Maximum value needed.

    check_contents : bool, default=False
        Whether to check the values in the arrays and not just their types.
        By default, check only the types.

    Returns
    -------
    dtype : {np.int32, np.int64}
        Suitable index data type (int32 or int64).
    """
    int32min = np.int32(np.iinfo(np.int32).min)
    int32max = np.int32(np.iinfo(np.int32).max)

    if maxval is not None:
        if maxval > np.iinfo(np.int64).max:
            raise ValueError(
                f"maxval={maxval} is too large to be represented as np.int64."
            )
        if maxval > int32max:
            return np.int64

    if isinstance(arrays, np.ndarray):
        arrays = (arrays,)

    for arr in arrays:
        if not isinstance(arr, np.ndarray):
            raise TypeError(
                f"Arrays should be of type np.ndarray, got {type(arr)} instead."
            )
        if not np.issubdtype(arr.dtype, np.integer):
            raise ValueError(
                f"Array dtype {arr.dtype} is not supported for index dtype. We expect "
                "integral values."
            )
        if not np.can_cast(arr.dtype, np.int32):
            if not check_contents:
                # when `check_contents` is False, we stay on the safe side and
                # return np.int64.
                return np.int64
            if arr.size == 0:
                # a bigger type is not needed yet, let's look at the next array
                continue
            else:
                maxval = arr.max()
                minval = arr.min()
                if minval < int32min or maxval > int32max:
                    # a bigger index type is actually needed
                    return np.int64

    return np.int32
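

# Illustrative sketch (not part of the original module): int32 indices are kept
# whenever the values fit, and int64 is chosen only when actually required.
def _demo_index_dtype():
    small = np.array([0, 1, 2], dtype=np.int64)
    assert _smallest_admissible_index_dtype(small, check_contents=True) is np.int32
    assert _smallest_admissible_index_dtype(maxval=2**40) is np.int64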


# TODO: Remove when SciPy 1.12 is the minimum supported version
if sp_version < parse_version("1.12"):
    from ..externals._scipy.sparse.csgraph import laplacian  # type: ignore # noqa
else:
    from scipy.sparse.csgraph import laplacian  # type: ignore # noqa # pragma: no cover