1import importlib
2from functools import wraps
3from typing import Protocol, runtime_checkable
4
5import numpy as np
6from scipy.sparse import issparse
7
8from .._config import get_config
9from ._available_if import available_if
10
11
def check_library_installed(library):
    """Import `library` and return the module, or fail with a helpful error.

    Parameters
    ----------
    library : str
        Name of the module to import, as accepted by
        `importlib.import_module`.

    Returns
    -------
    module : module
        The imported module object.

    Raises
    ------
    ImportError
        If the import fails. The original import error is chained as the
        cause.
    """
    try:
        module = importlib.import_module(library)
    except ImportError as import_error:
        message = (
            f"Setting output container to '{library}' requires {library} to be"
            " installed"
        )
        raise ImportError(message) from import_error
    else:
        return module
21
22
def get_columns(columns):
    """Resolve column names that may be given lazily as a callable.

    Parameters
    ----------
    columns : callable, array-like or None
        Either the column names themselves or a zero-argument callable
        producing them.

    Returns
    -------
    columns : object or None
        `columns` unchanged when it is not callable; otherwise the value
        returned by calling it, or `None` if the call raised any
        `Exception`.
    """
    if not callable(columns):
        return columns

    try:
        resolved = columns()
    except Exception:
        # Best effort: a failing names callable means "no column names".
        resolved = None
    return resolved
30
31
@runtime_checkable
class ContainerAdapterProtocol(Protocol):
    """Structural interface implemented by output-container adapters.

    The methods below only carry documentation; concrete adapters provide
    the actual behavior.
    """

    # Key under which `ContainerAdaptersManager.register` stores the adapter.
    container_lib: str

    def create_container(self, X_output, X_original, columns):
        """Create container from `X_output` with additional metadata.

        Parameters
        ----------
        X_output : {ndarray, dataframe}
            Data to wrap.

        X_original : {ndarray, dataframe}
            Original input dataframe. This is used to extract the metadata that should
            be passed to `X_output`, e.g. pandas row index.

        columns : callable, ndarray, or None
            The column names or a callable that returns the column names. The
            callable is useful if the column names require some computation. If `None`,
            then no columns are passed to the container's constructor.

        Returns
        -------
        wrapped_output : container_type
            `X_output` wrapped into the container type.
        """

    def is_supported_container(self, X):
        """Return True if X is a supported container.

        Parameters
        ----------
        X : container
            Container to be checked.

        Returns
        -------
        is_supported_container : bool
            True if X is a supported container.
        """

    def rename_columns(self, X, columns):
        """Rename columns in `X`.

        Parameters
        ----------
        X : container
            Container whose columns are renamed.

        columns : ndarray of str
            New column names for `X`.

        Returns
        -------
        updated_container : container
            Container with new names.
        """

    def hstack(self, Xs):
        """Stack containers horizontally (column-wise).

        Parameters
        ----------
        Xs : list of containers
            List of containers to stack.

        Returns
        -------
        stacked_Xs : container
            Stacked containers.
        """
103
104
class PandasAdapter:
    """Container adapter that produces and manipulates pandas DataFrames."""

    # Key under which this adapter is registered and selected.
    container_lib = "pandas"

    def create_container(self, X_output, X_original, columns):
        """Wrap `X_output` in a pandas DataFrame.

        Parameters
        ----------
        X_output : {ndarray, dataframe}
            Data to wrap.

        X_original : {ndarray, dataframe}
            Original input. When it is a pandas DataFrame its index is reused
            for a newly created DataFrame.

        columns : callable, ndarray, or None
            Column names, or a callable returning them. `None` (or a callable
            that fails) leaves the column names untouched / default.

        Returns
        -------
        wrapped_output : pandas.DataFrame
            `X_output` itself when it already is a DataFrame (its columns are
            overwritten in place when names are available and its index is
            kept), otherwise a new DataFrame built without copying the data.
        """
        pd = check_library_installed("pandas")
        columns = get_columns(columns)
        index = X_original.index if isinstance(X_original, pd.DataFrame) else None

        if isinstance(X_output, pd.DataFrame):
            if columns is not None:
                X_output.columns = columns
            return X_output

        return pd.DataFrame(X_output, index=index, columns=columns, copy=False)

    def is_supported_container(self, X):
        """Return True if `X` is a pandas DataFrame."""
        pd = check_library_installed("pandas")
        return isinstance(X, pd.DataFrame)

    def rename_columns(self, X, columns):
        """Return a DataFrame like `X` whose columns are relabelled positionally.

        Parameters
        ----------
        X : pandas.DataFrame
            Frame whose columns are renamed. It is not modified.

        columns : array-like of str
            New column names, one per column of `X` and in the same order.
            pandas raises if the number of names does not match.

        Returns
        -------
        updated_container : pandas.DataFrame
            New DataFrame carrying the new column names.
        """
        # A dict-based `rename` cannot be used here: `X` may contain duplicated
        # column names (e.g. after `hstack`), and a mapping keyed by the old
        # names would collapse them and assign the wrong labels.
        return X.set_axis(columns, axis="columns")

    def hstack(self, Xs):
        """Concatenate the DataFrames in `Xs` column-wise."""
        pd = check_library_installed("pandas")
        return pd.concat(Xs, axis=1)
130
131
class PolarsAdapter:
    """Container adapter that produces and manipulates polars DataFrames."""

    # Key under which this adapter is registered and selected.
    container_lib = "polars"

    def create_container(self, X_output, X_original, columns):
        """Wrap `X_output` in a polars DataFrame.

        Parameters
        ----------
        X_output : {ndarray, dataframe}
            Data to wrap.

        X_original : {ndarray, dataframe}
            Original input. Not used by this adapter.

        columns : callable, ndarray, or None
            Column names, or a callable returning them. An ndarray of names is
            converted to a list before use.

        Returns
        -------
        wrapped_output : polars.DataFrame
            `X_output` (renamed when names are available) if it already is a
            polars DataFrame, otherwise a new row-oriented DataFrame.
        """
        pl = check_library_installed("polars")
        names = get_columns(columns)
        if isinstance(names, np.ndarray):
            names = names.tolist()

        if not isinstance(X_output, pl.DataFrame):
            return pl.DataFrame(X_output, schema=names, orient="row")

        if names is None:
            return X_output
        return self.rename_columns(X_output, names)

    def is_supported_container(self, X):
        """Return True if `X` is a polars DataFrame."""
        polars = check_library_installed("polars")
        return isinstance(X, polars.DataFrame)

    def rename_columns(self, X, columns):
        """Rename the columns of `X`, pairing old and new names by position."""
        old_to_new = {old: new for old, new in zip(X.columns, columns)}
        return X.rename(old_to_new)

    def hstack(self, Xs):
        """Concatenate the DataFrames in `Xs` column-wise."""
        polars = check_library_installed("polars")
        return polars.concat(Xs, how="horizontal")
159
160
class ContainerAdaptersManager:
    """Registry mapping a container library name to its adapter instance."""

    def __init__(self):
        # Maps `adapter.container_lib` -> adapter.
        self.adapters = {}

    @property
    def supported_outputs(self):
        """Set of valid output names: "default" plus every registered key."""
        return {"default", *self.adapters}

    def register(self, adapter):
        """Store `adapter` under its `container_lib`, replacing any previous one."""
        key = adapter.container_lib
        self.adapters[key] = adapter
171
172
# Module-level registry of output-container adapters, populated at import time
# with the pandas and polars adapters (keyed by each adapter's `container_lib`).
ADAPTERS_MANAGER = ContainerAdaptersManager()
ADAPTERS_MANAGER.register(PandasAdapter())
ADAPTERS_MANAGER.register(PolarsAdapter())
176
177
def _get_container_adapter(method, estimator=None):
    """Return the adapter configured for `method`, or `None` if there is none.

    The dense output configuration is resolved with `_get_output_config`;
    a value without a registered adapter (such as "default") yields `None`.
    """
    configured = _get_output_config(method, estimator)["dense"]
    return ADAPTERS_MANAGER.adapters.get(configured)
185
186
def _get_output_config(method, estimator=None):
    """Get output config based on estimator and global configuration.

    Parameters
    ----------
    method : {"transform"}
        Estimator's method for which the output container is looked up.

    estimator : estimator instance or None
        Estimator to get the output configuration from. If it has no
        `_sklearn_output_config` entry for `method` (or is `None`), the
        global configuration key `f"{method}_output"` is used.

    Returns
    -------
    config : dict
        Dictionary with keys:

        - "dense": specifies the dense container for `method`. This is
          `"default"` or the name of a registered container adapter.

    Raises
    ------
    ValueError
        If the resolved value is not one of the supported outputs.
    """
    per_estimator = getattr(estimator, "_sklearn_output_config", {})
    dense_config = (
        per_estimator[method]
        if method in per_estimator
        else get_config()[f"{method}_output"]
    )

    allowed = ADAPTERS_MANAGER.supported_outputs
    if dense_config in allowed:
        return {"dense": dense_config}

    raise ValueError(
        f"output config must be in {sorted(allowed)}, got {dense_config}"
    )
220
221
def _wrap_data_with_container(method, data_to_wrap, original_input, estimator):
    """Wrap output with container based on an estimator's or global config.

    Parameters
    ----------
    method : {"transform"}
        Estimator's method to get container output for.

    data_to_wrap : {ndarray, dataframe}
        Data to wrap with container.

    original_input : {ndarray, dataframe}
        Original input of function.

    estimator : estimator instance
        Estimator to get the output configuration from.

    Returns
    -------
    output : {ndarray, dataframe}
        `data_to_wrap` unchanged if the output config is "default" or the
        estimator is not configured for wrapping. Otherwise the container
        created by the adapter registered for the configured output.

    Raises
    ------
    ValueError
        If wrapping is requested but `data_to_wrap` is a scipy sparse matrix.
    """
    dense_config = _get_output_config(method, estimator)["dense"]

    if dense_config == "default":
        return data_to_wrap
    if not _auto_wrap_is_configured(estimator):
        return data_to_wrap

    if issparse(data_to_wrap):
        raise ValueError(
            "The transformer outputs a scipy sparse matrix. "
            "Try to set the transformer output to a dense array or disable "
            f"{dense_config.capitalize()} output with set_output(transform='default')."
        )

    # Column names are passed lazily; the adapter decides when to compute them.
    return ADAPTERS_MANAGER.adapters[dense_config].create_container(
        data_to_wrap,
        original_input,
        columns=estimator.get_feature_names_out,
    )
266
267
def _wrap_method_output(f, method):
    """Wrapper used by `_SetOutputMixin` to automatically wrap methods.

    Returns a function with the signature `(self, X, *args, **kwargs)` that
    calls `f` and passes its result through `_wrap_data_with_container`.
    """

    @wraps(f)
    def wrapped(self, X, *args, **kwargs):
        result = f(self, X, *args, **kwargs)
        if not isinstance(result, tuple):
            return _wrap_data_with_container(method, result, X, self)

        # only wrap the first output for cross decomposition
        head = _wrap_data_with_container(method, result[0], X, self)
        rewrapped = (head, *result[1:])

        # Support for namedtuples `_make` is a documented API for namedtuples:
        # https://docs.python.org/3/library/collections.html#collections.somenamedtuple._make
        result_type = type(result)
        if hasattr(result_type, "_make"):
            return result_type._make(rewrapped)
        return rewrapped

    return wrapped
289
290
291def _auto_wrap_is_configured(estimator):
292 """Return True if estimator is configured for auto-wrapping the transform method.
293
294 `_SetOutputMixin` sets `_sklearn_auto_wrap_output_keys` to `set()` if auto wrapping
295 is manually disabled.
296 """
297 auto_wrap_output_keys = getattr(estimator, "_sklearn_auto_wrap_output_keys", set())
298 return (
299 hasattr(estimator, "get_feature_names_out")
300 and "transform" in auto_wrap_output_keys
301 )
302
303
class _SetOutputMixin:
    """Mixin that dynamically wraps methods to return container based on config.

    Currently `_SetOutputMixin` wraps `transform` and `fit_transform` and
    configures their output based on `set_output` or the global configuration.

    `set_output` is only defined if `get_feature_names_out` is defined and
    `auto_wrap_output_keys` is the default value.
    """

    def __init_subclass__(cls, auto_wrap_output_keys=("transform",), **kwargs):
        super().__init_subclass__(**kwargs)

        # Only `None` (wrapping disabled) or a tuple of keys is accepted.
        if not isinstance(auto_wrap_output_keys, tuple) and (
            auto_wrap_output_keys is not None
        ):
            raise ValueError("auto_wrap_output_keys must be None or a tuple of keys.")

        cls._sklearn_auto_wrap_output_keys = set()
        if auto_wrap_output_keys is None:
            return

        # (method name, configuration key) pairs that may be wrapped.
        wrappable = (
            ("transform", "transform"),
            ("fit_transform", "transform"),
        )
        for method_name, config_key in wrappable:
            if hasattr(cls, method_name) and config_key in auto_wrap_output_keys:
                cls._sklearn_auto_wrap_output_keys.add(config_key)

                # Only wrap methods defined by cls itself; inherited ones were
                # handled when the defining class was created.
                if method_name in cls.__dict__:
                    setattr(
                        cls,
                        method_name,
                        _wrap_method_output(getattr(cls, method_name), config_key),
                    )

    @available_if(_auto_wrap_is_configured)
    def set_output(self, *, transform=None):
        """Set output container.

        See :ref:`sphx_glr_auto_examples_miscellaneous_plot_set_output.py`
        for an example on how to use the API.

        Parameters
        ----------
        transform : {"default", "pandas", "polars"}, default=None
            Configure output of `transform` and `fit_transform`.

            - `"default"`: Default output format of a transformer
            - `"pandas"`: DataFrame output
            - `"polars"`: Polars output
            - `None`: Transform configuration is unchanged

            .. versionadded:: 1.4
                `"polars"` option was added.

        Returns
        -------
        self : estimator instance
            Estimator instance.
        """
        if transform is not None:
            # Create the per-instance config lazily on first use.
            if not hasattr(self, "_sklearn_output_config"):
                self._sklearn_output_config = {}
            self._sklearn_output_config["transform"] = transform
        return self
379
380
381def _safe_set_output(estimator, *, transform=None):
382 """Safely call estimator.set_output and error if it not available.
383
384 This is used by meta-estimators to set the output for child estimators.
385
386 Parameters
387 ----------
388 estimator : estimator instance
389 Estimator instance.
390
391 transform : {"default", "pandas"}, default=None
392 Configure output of the following estimator's methods:
393
394 - `"transform"`
395 - `"fit_transform"`
396
397 If `None`, this operation is a no-op.
398
399 Returns
400 -------
401 estimator : estimator instance
402 Estimator instance.
403 """
404 set_output_for_transform = (
405 hasattr(estimator, "transform")
406 or hasattr(estimator, "fit_transform")
407 and transform is not None
408 )
409 if not set_output_for_transform:
410 # If estimator can not transform, then `set_output` does not need to be
411 # called.
412 return
413
414 if not hasattr(estimator, "set_output"):
415 raise ValueError(
416 f"Unable to configure output for {estimator} because `set_output` "
417 "is not available."
418 )
419 return estimator.set_output(transform=transform)