Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/scikit_learn-1.4.dev0-py3.8-linux-x86_64.egg/sklearn/utils/_set_output.py: 53%

147 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-12-12 06:31 +0000

1import importlib 

2from functools import wraps 

3from typing import Protocol, runtime_checkable 

4 

5import numpy as np 

6from scipy.sparse import issparse 

7 

8from .._config import get_config 

9from ._available_if import available_if 

10 

11 

def check_library_installed(library):
    """Import `library` and return the module.

    Parameters
    ----------
    library : str
        Name of the module to import.

    Returns
    -------
    module : module
        The imported module.

    Raises
    ------
    ImportError
        If `library` cannot be imported. The original import error is chained
        as the cause.
    """
    try:
        module = importlib.import_module(library)
    except ImportError as exc:
        message = (
            f"Setting output container to '{library}' requires {library} to be"
            " installed"
        )
        raise ImportError(message) from exc
    return module

21 

22 

def get_columns(columns):
    """Resolve `columns` to concrete column names.

    Parameters
    ----------
    columns : callable or object
        Either the column names themselves or a zero-argument callable that
        returns them.

    Returns
    -------
    columns : object or None
        `columns` unchanged when it is not callable; otherwise the result of
        calling it, or `None` if the call raises any `Exception`.
    """
    if not callable(columns):
        return columns

    # Best effort: a failing callable means "no column names available".
    try:
        return columns()
    except Exception:
        return None

30 

31 

@runtime_checkable
class ContainerAdapterProtocol(Protocol):
    """Structural interface that every output-container adapter implements."""

    # Name of the library providing the container; used as the registry key.
    container_lib: str

    def create_container(self, X_output, X_original, columns):
        """Build a container holding `X_output` plus metadata.

        Parameters
        ----------
        X_output : {ndarray, dataframe}
            The data that should be wrapped.

        X_original : {ndarray, dataframe}
            The original input. Metadata to carry over to the wrapped output
            (e.g. the pandas row index) is taken from it.

        columns : callable, ndarray, or None
            Column names, or a callable returning them. A callable is handy
            when computing the names is not free. With `None`, no column names
            are handed to the container's constructor.

        Returns
        -------
        wrapped_output : container_type
            `X_output` wrapped into the adapter's container type.
        """

    def is_supported_container(self, X):
        """Tell whether `X` is a container this adapter supports.

        Parameters
        ----------
        X : container
            Container to check.

        Returns
        -------
        is_supported_container : bool
            True if `X` is a supported container.
        """

    def rename_columns(self, X, columns):
        """Give the columns of `X` new names.

        Parameters
        ----------
        X : container
            Container whose columns are renamed.

        columns : ndarray of str
            New column names for `X`.

        Returns
        -------
        updated_container : container
            Container carrying the new names.
        """

    def hstack(self, Xs):
        """Concatenate containers column-wise.

        Parameters
        ----------
        Xs : list of containers
            Containers to stack horizontally.

        Returns
        -------
        stacked_Xs : container
            The horizontally stacked container.
        """

103 

104 

class PandasAdapter:
    """Container adapter that produces and manipulates pandas DataFrames."""

    container_lib = "pandas"

    def create_container(self, X_output, X_original, columns):
        """Wrap `X_output` into a pandas DataFrame.

        Parameters
        ----------
        X_output : {ndarray, dataframe}
            Data to wrap. When it already is a pandas DataFrame it is returned
            as the same object; its columns are overwritten in place if
            `columns` resolves to something other than `None`.

        X_original : {ndarray, dataframe}
            Original input. Its index is reused for a newly built DataFrame
            when it is a pandas DataFrame.

        columns : callable, ndarray, or None
            Column names or a callable returning them (resolved through
            `get_columns`).

        Returns
        -------
        wrapped_output : pandas DataFrame
            `X_output` as a DataFrame.
        """
        pd = check_library_installed("pandas")
        columns = get_columns(columns)
        index = X_original.index if isinstance(X_original, pd.DataFrame) else None

        if isinstance(X_output, pd.DataFrame):
            if columns is not None:
                X_output.columns = columns
            return X_output

        return pd.DataFrame(X_output, index=index, columns=columns, copy=False)

    def is_supported_container(self, X):
        """Return True if `X` is a pandas DataFrame."""
        pd = check_library_installed("pandas")
        return isinstance(X, pd.DataFrame)

    def rename_columns(self, X, columns):
        """Return a copy of `X` whose columns are renamed by position.

        Parameters
        ----------
        X : pandas DataFrame
            DataFrame whose columns are renamed. It is not modified.

        columns : iterable of str
            New names, matched to the existing columns by position. Pairing
            stops at the shorter of the two sequences; unpaired existing
            columns keep their name.

        Returns
        -------
        updated_container : pandas DataFrame
            Copy of `X` with the new column names.
        """
        # Rename positionally instead of through an ``{old: new}`` mapping: a
        # mapping collapses duplicated column labels onto a single new name.
        new_labels = list(X.columns)
        for position, label in zip(range(len(new_labels)), columns):
            new_labels[position] = label

        renamed = X.copy()
        renamed.columns = new_labels
        # Assigning a plain list drops the axis name(s); restore them when the
        # number of levels is unchanged.
        if renamed.columns.nlevels == X.columns.nlevels:
            renamed.columns = renamed.columns.set_names(list(X.columns.names))
        return renamed

    def hstack(self, Xs):
        """Concatenate the DataFrames in `Xs` column-wise."""
        pd = check_library_installed("pandas")
        return pd.concat(Xs, axis=1)

130 

131 

class PolarsAdapter:
    """Container adapter that produces and manipulates polars DataFrames."""

    container_lib = "polars"

    def create_container(self, X_output, X_original, columns):
        """Wrap `X_output` into a polars DataFrame.

        Parameters
        ----------
        X_output : {ndarray, dataframe}
            Data to wrap. An existing polars DataFrame is returned as-is, or
            renamed through `rename_columns` when column names are available.

        X_original : {ndarray, dataframe}
            Original input. Not used by this adapter.

        columns : callable, ndarray, or None
            Column names or a callable returning them (resolved through
            `get_columns`). A NumPy array of names is converted to a list.

        Returns
        -------
        wrapped_output : polars DataFrame
            `X_output` as a polars DataFrame.
        """
        pl = check_library_installed("polars")
        names = get_columns(columns)
        if isinstance(names, np.ndarray):
            names = names.tolist()

        if not isinstance(X_output, pl.DataFrame):
            return pl.DataFrame(X_output, schema=names, orient="row")

        if names is None:
            return X_output
        return self.rename_columns(X_output, names)

    def is_supported_container(self, X):
        """Return True if `X` is a polars DataFrame."""
        polars = check_library_installed("polars")
        return isinstance(X, polars.DataFrame)

    def rename_columns(self, X, columns):
        """Return `X` renamed with a mapping from current to new column names."""
        old_to_new = dict(zip(X.columns, columns))
        return X.rename(old_to_new)

    def hstack(self, Xs):
        """Concatenate the DataFrames in `Xs` horizontally."""
        polars = check_library_installed("polars")
        return polars.concat(Xs, how="horizontal")

159 

160 

class ContainerAdaptersManager:
    """Registry of container adapters keyed by their `container_lib` name."""

    def __init__(self):
        # Maps a container library name to its adapter instance.
        self.adapters = {}

    @property
    def supported_outputs(self):
        """Set of valid output names: "default" plus every registered adapter."""
        return {"default", *self.adapters}

    def register(self, adapter):
        """Store `adapter` under its `container_lib`, replacing any previous one."""
        key = adapter.container_lib
        self.adapters[key] = adapter

171 

172 

# Module-level registry of output-container adapters. Each adapter is stored
# under its `container_lib` name ("pandas" and "polars" here).
ADAPTERS_MANAGER = ContainerAdaptersManager()
ADAPTERS_MANAGER.register(PandasAdapter())
ADAPTERS_MANAGER.register(PolarsAdapter())

176 

177 

def _get_container_adapter(method, estimator=None):
    """Look up the adapter matching the configured dense output.

    Parameters
    ----------
    method : str
        Estimator method whose output configuration is looked up.

    estimator : estimator instance or None
        Estimator to read the output configuration from.

    Returns
    -------
    adapter : adapter instance or None
        The registered adapter for the configured dense container, or `None`
        when no adapter is registered under that name (e.g. "default").
    """
    container_name = _get_output_config(method, estimator)["dense"]
    return ADAPTERS_MANAGER.adapters.get(container_name)

185 

186 

def _get_output_config(method, estimator=None):
    """Resolve the output configuration from the estimator or global config.

    Parameters
    ----------
    method : {"transform"}
        Estimator method whose output container is looked up.

    estimator : estimator instance or None
        Estimator to read the output configuration from. When it has no entry
        for `method` (or is `None`), the global configuration is used.

    Returns
    -------
    config : dict
        Dictionary with keys:

        - "dense": the dense container for `method`. This is one of
          `ADAPTERS_MANAGER.supported_outputs`, i.e. `"default"` or the name
          of a registered adapter.

    Raises
    ------
    ValueError
        If the resolved configuration is not a supported output.
    """
    per_estimator = getattr(estimator, "_sklearn_output_config", {})
    # The estimator-level setting takes precedence over the global one.
    dense_config = (
        per_estimator[method]
        if method in per_estimator
        else get_config()[f"{method}_output"]
    )

    supported_outputs = ADAPTERS_MANAGER.supported_outputs
    if dense_config in supported_outputs:
        return {"dense": dense_config}

    raise ValueError(
        f"output config must be in {sorted(supported_outputs)}, got {dense_config}"
    )

220 

221 

def _wrap_data_with_container(method, data_to_wrap, original_input, estimator):
    """Wrap a method's output according to the estimator's or global config.

    Parameters
    ----------
    method : {"transform"}
        Estimator method to get the container output for.

    data_to_wrap : {ndarray, dataframe}
        Data to wrap with a container.

    original_input : {ndarray, dataframe}
        Original input of the wrapped function.

    estimator : estimator instance
        Estimator to read the output configuration from.

    Returns
    -------
    output : {ndarray, dataframe}
        `data_to_wrap` unchanged when the output config is "default" or the
        estimator is not configured for auto-wrapping. Otherwise the container
        built by the adapter registered for the configured output.

    Raises
    ------
    ValueError
        If wrapping is requested and `data_to_wrap` is a scipy sparse matrix.
    """
    # Resolve (and validate) the configuration before anything else.
    dense_config = _get_output_config(method, estimator)["dense"]

    if dense_config == "default" or not _auto_wrap_is_configured(estimator):
        return data_to_wrap

    if issparse(data_to_wrap):
        raise ValueError(
            "The transformer outputs a scipy sparse matrix. "
            "Try to set the transformer output to a dense array or disable "
            f"{dense_config.capitalize()} output with set_output(transform='default')."
        )

    # The bound method is passed as a callable so that names are only computed
    # when the adapter asks for them.
    return ADAPTERS_MANAGER.adapters[dense_config].create_container(
        data_to_wrap,
        original_input,
        columns=estimator.get_feature_names_out,
    )

266 

267 

def _wrap_method_output(f, method):
    """Return `f` wrapped so that its output goes through the configured container.

    Used by `_SetOutputMixin` to wrap methods automatically.

    Parameters
    ----------
    f : callable
        Method with signature `f(self, X, *args, **kwargs)`.

    method : str
        Configuration key passed on to `_wrap_data_with_container`.

    Returns
    -------
    wrapped : callable
        Wrapper around `f`.
    """

    @wraps(f)
    def wrapped(self, X, *args, **kwargs):
        result = f(self, X, *args, **kwargs)
        if not isinstance(result, tuple):
            return _wrap_data_with_container(method, result, X, self)

        # Tuple outputs (cross decomposition): only the first item is wrapped.
        rewrapped = (
            _wrap_data_with_container(method, result[0], X, self),
            *result[1:],
        )
        # Namedtuples are rebuilt with their documented `_make` constructor:
        # https://docs.python.org/3/library/collections.html#collections.somenamedtuple._make
        tuple_type = type(result)
        if hasattr(tuple_type, "_make"):
            return tuple_type._make(rewrapped)
        return rewrapped

    return wrapped

289 

290 

291def _auto_wrap_is_configured(estimator): 

292 """Return True if estimator is configured for auto-wrapping the transform method. 

293 

294 `_SetOutputMixin` sets `_sklearn_auto_wrap_output_keys` to `set()` if auto wrapping 

295 is manually disabled. 

296 """ 

297 auto_wrap_output_keys = getattr(estimator, "_sklearn_auto_wrap_output_keys", set()) 

298 return ( 

299 hasattr(estimator, "get_feature_names_out") 

300 and "transform" in auto_wrap_output_keys 

301 ) 

302 

303 

class _SetOutputMixin:
    """Mixin that dynamically wraps methods to return a container based on config.

    `_SetOutputMixin` currently wraps `transform` and `fit_transform`; their
    output is configured through `set_output` or the global configuration.

    `set_output` is only available when `get_feature_names_out` is defined and
    "transform" is among the auto-wrapped keys (the default).
    """

    def __init_subclass__(cls, auto_wrap_output_keys=("transform",), **kwargs):
        super().__init_subclass__(**kwargs)

        if auto_wrap_output_keys is not None and not isinstance(
            auto_wrap_output_keys, tuple
        ):
            raise ValueError("auto_wrap_output_keys must be None or a tuple of keys.")

        # Every subclass gets its own set; it stays empty when wrapping is
        # disabled with `auto_wrap_output_keys=None`.
        cls._sklearn_auto_wrap_output_keys = set()
        if auto_wrap_output_keys is None:
            return

        # Mapping from method name to its key in the output configuration.
        method_to_key = {
            "transform": "transform",
            "fit_transform": "transform",
        }

        for method, key in method_to_key.items():
            if hasattr(cls, method) and key in auto_wrap_output_keys:
                cls._sklearn_auto_wrap_output_keys.add(key)

                # Inherited methods were already wrapped on the parent class;
                # only wrap what `cls` defines itself.
                if method in cls.__dict__:
                    original = getattr(cls, method)
                    setattr(cls, method, _wrap_method_output(original, key))

    @available_if(_auto_wrap_is_configured)
    def set_output(self, *, transform=None):
        """Set output container.

        See :ref:`sphx_glr_auto_examples_miscellaneous_plot_set_output.py`
        for an example on how to use the API.

        Parameters
        ----------
        transform : {"default", "pandas", "polars"}, default=None
            Configure output of `transform` and `fit_transform`.

            - `"default"`: Default output format of a transformer
            - `"pandas"`: DataFrame output
            - `"polars"`: Polars output
            - `None`: Transform configuration is unchanged

            .. versionadded:: 1.4
                `"polars"` option was added.

        Returns
        -------
        self : estimator instance
            Estimator instance.
        """
        if transform is not None:
            # The per-instance configuration dict is created on first use.
            if not hasattr(self, "_sklearn_output_config"):
                self._sklearn_output_config = {}
            self._sklearn_output_config["transform"] = transform
        return self

379 

380 

381def _safe_set_output(estimator, *, transform=None): 

382 """Safely call estimator.set_output and error if it not available. 

383 

384 This is used by meta-estimators to set the output for child estimators. 

385 

386 Parameters 

387 ---------- 

388 estimator : estimator instance 

389 Estimator instance. 

390 

391 transform : {"default", "pandas"}, default=None 

392 Configure output of the following estimator's methods: 

393 

394 - `"transform"` 

395 - `"fit_transform"` 

396 

397 If `None`, this operation is a no-op. 

398 

399 Returns 

400 ------- 

401 estimator : estimator instance 

402 Estimator instance. 

403 """ 

404 set_output_for_transform = ( 

405 hasattr(estimator, "transform") 

406 or hasattr(estimator, "fit_transform") 

407 and transform is not None 

408 ) 

409 if not set_output_for_transform: 

410 # If estimator can not transform, then `set_output` does not need to be 

411 # called. 

412 return 

413 

414 if not hasattr(estimator, "set_output"): 

415 raise ValueError( 

416 f"Unable to configure output for {estimator} because `set_output` " 

417 "is not available." 

418 ) 

419 return estimator.set_output(transform=transform)