1"""
2The :mod:`sklearn.utils.metaestimators` module includes utilities for meta-estimators.
3"""
4
5# Author: Joel Nothman
6# Andreas Mueller
7# License: BSD
8from abc import ABCMeta, abstractmethod
9from contextlib import suppress
10from typing import Any, List
11
12import numpy as np
13
14from ..base import BaseEstimator
15from ..utils import _safe_indexing
16from ..utils._tags import _safe_tags
17from ._available_if import available_if
18
# Only `available_if` (imported above from `._available_if`) is exported by
# `from ... import *`; the other names defined in this module are
# underscore-prefixed helpers.
__all__ = ["available_if"]
20
21
class _BaseComposition(BaseEstimator, metaclass=ABCMeta):
    """Handles parameter management for classifiers composed of named estimators.

    The helpers below operate on an attribute (whose name is passed as
    ``attr``) that is expected to hold a list of ``(name, estimator)``
    tuples. They expose each named estimator, and the parameters of each
    estimator, through the ``<name>`` / ``<name>__<param>`` naming scheme.
    """

    steps: List[Any]

    @abstractmethod
    def __init__(self):
        pass

    def _get_params(self, attr, deep=True):
        """Return this object's parameters, including those of named estimators.

        Parameters
        ----------
        attr : str
            Name of the attribute holding the list of ``(name, estimator)``
            tuples.

        deep : bool, default=True
            If False, return only what the parent class's ``get_params``
            reports. If True, additionally add every named estimator under
            its name and, for estimators exposing ``get_params``, each of
            their parameters under the key ``<name>__<param>``.

        Returns
        -------
        params : dict
            Mapping of parameter names to values.
        """
        out = super().get_params(deep=deep)
        if not deep:
            return out

        estimators = getattr(self, attr)
        try:
            out.update(estimators)
        except (TypeError, ValueError):
            # Ignore TypeError for cases where estimators is not a list of
            # (name, estimator) and ignore ValueError when the list is not
            # formatted correctly. This is to prevent errors when calling
            # `set_params`. `BaseEstimator.set_params` calls `get_params` which
            # can error for invalid values for `estimators`.
            return out

        for name, estimator in estimators:
            if hasattr(estimator, "get_params"):
                for key, value in estimator.get_params(deep=True).items():
                    out[f"{name}__{key}"] = value
        return out

    def _set_params(self, attr, **params):
        """Set parameters, handling the named-estimator list stored in ``attr``.

        Parameters are applied in a strict order so that later steps see the
        result of earlier ones:

        1. the whole ``attr`` list, if given in ``params``;
        2. replacement of individual estimators addressed by their name;
        3. everything else, delegated to the parent class's ``set_params``.

        Parameters
        ----------
        attr : str
            Name of the attribute holding the list of ``(name, estimator)``
            tuples.

        **params : dict
            Parameters to set.

        Returns
        -------
        self : object
            This instance.
        """
        # 1. All steps
        if attr in params:
            setattr(self, attr, params.pop(attr))

        # 2. Replace items with estimators in params
        items = getattr(self, attr)
        if isinstance(items, list) and items:
            # Collect the item names used to recognise estimator replacements
            # in `params`. Malformed `items` must not make `set_params` fail
            # here: `zip(*items)` raises TypeError when an element is not
            # iterable, and the unpacking raises ValueError when the elements
            # do not all have length 2. In both cases no name is recognised,
            # mirroring the tolerance of `_get_params`. Only the unpacking is
            # guarded so that errors raised while replacing an estimator are
            # not silently swallowed.
            item_names = ()
            with suppress(TypeError, ValueError):
                item_names, _ = zip(*items)

            # Iterate over a snapshot of the keys: `params` is mutated below.
            for name in list(params):
                if "__" not in name and name in item_names:
                    self._replace_estimator(attr, name, params.pop(name))

        # 3. Step parameters and other initialisation arguments
        super().set_params(**params)
        return self

    def _replace_estimator(self, attr, name, new_val):
        """Replace the first estimator called ``name`` in ``attr`` by ``new_val``.

        The list stored in ``attr`` is copied, updated and assigned back, so
        the previously stored list object is left unmodified.

        Parameters
        ----------
        attr : str
            Name of the attribute holding the list of ``(name, estimator)``
            tuples.

        name : str
            Name of the estimator to replace.

        new_val : object
            Value stored in place of the current estimator.
        """
        # assumes `name` is a valid estimator name
        new_estimators = list(getattr(self, attr))
        for i, (estimator_name, _) in enumerate(new_estimators):
            if estimator_name == name:
                new_estimators[i] = (name, new_val)
                break
        setattr(self, attr, new_estimators)

    def _validate_names(self, names):
        """Check that estimator names are usable as parameter names.

        Parameters
        ----------
        names : sequence of str
            Estimator names to validate.

        Raises
        ------
        ValueError
            If the names are not unique, if a name is also a parameter
            returned by ``self.get_params(deep=False)``, or if a name
            contains ``__`` (reserved as the nested-parameter separator).
        """
        unique_names = set(names)
        if len(unique_names) != len(names):
            raise ValueError(f"Names provided are not unique: {list(names)!r}")

        invalid_names = unique_names.intersection(self.get_params(deep=False))
        if invalid_names:
            raise ValueError(
                "Estimator names conflict with constructor arguments: "
                f"{sorted(invalid_names)!r}"
            )

        invalid_names = [name for name in names if "__" in name]
        if invalid_names:
            raise ValueError(
                f"Estimator names must not contain __: got {invalid_names!r}"
            )
98
99
def _safe_split(estimator, X, y, indices, train_indices=None):
    """Extract the samples selected by ``indices`` from ``X`` and ``y``.

    Cross-validation slicing helper that is aware of precomputed
    kernel matrices and pairwise affinities / distances.

    When the estimator's ``pairwise`` tag is set, ``X`` must be square and is
    sliced along both axes: rows are taken from ``indices`` (assumed to be
    the test set) and columns from ``train_indices`` when it is given, or
    from ``indices`` again when it is None. Otherwise only the rows of ``X``
    are selected.

    Labels ``y`` are always indexed only along the first axis.

    Parameters
    ----------
    estimator : object
        Estimator whose ``pairwise`` tag decides whether only rows, or both
        rows and columns, of ``X`` are sliced.

    X : array-like, sparse matrix or iterable
        Data to be indexed. For a pairwise estimator this needs to be a
        square array-like or sparse matrix.

    y : array-like, sparse matrix or iterable
        Targets to be indexed. May be None.

    indices : array of int
        Rows to select from ``X`` and ``y``. For a pairwise estimator with
        ``train_indices`` set to None, ``indices`` is also used to slice the
        columns of ``X``.

    train_indices : array of int or None, default=None
        For a pairwise estimator, columns of ``X`` to select when not None.

    Returns
    -------
    X_subset : array-like, sparse matrix or list
        Indexed data.

    y_subset : array-like, sparse matrix or list
        Indexed targets, or None when ``y`` is None.

    Raises
    ------
    ValueError
        If the estimator is pairwise and ``X`` has no ``shape`` attribute or
        is not square.
    """
    pairwise = _safe_tags(estimator, key="pairwise")

    if not pairwise:
        # Plain feature data: only rows are selected.
        X_subset = _safe_indexing(X, indices)
    else:
        if not hasattr(X, "shape"):
            raise ValueError(
                "Precomputed kernels or affinity matrices have "
                "to be passed as arrays or sparse matrices."
            )
        # X is a precomputed square kernel matrix
        if X.shape[0] != X.shape[1]:
            raise ValueError("X should be a square kernel matrix")
        # Columns follow the training set when one is given, otherwise the
        # same samples as the rows.
        column_indices = indices if train_indices is None else train_indices
        X_subset = X[np.ix_(indices, column_indices)]

    y_subset = None if y is None else _safe_indexing(y, indices)

    return X_subset, y_subset