1""" 

2The :mod:`sklearn.utils.metaestimators` module includes utilities for meta-estimators. 

3""" 

4 

5# Author: Joel Nothman 

6# Andreas Mueller 

7# License: BSD 

8from abc import ABCMeta, abstractmethod 

9from contextlib import suppress 

10from typing import Any, List 

11 

12import numpy as np 

13 

14from ..base import BaseEstimator 

15from ..utils import _safe_indexing 

16from ..utils._tags import _safe_tags 

17from ._available_if import available_if 

18 

19__all__ = ["available_if"] 

20 

21 

class _BaseComposition(BaseEstimator, metaclass=ABCMeta):
    """Handles parameter management for classifiers composed of named estimators."""

    steps: List[Any]

    @abstractmethod
    def __init__(self):
        pass

    def _get_params(self, attr, deep=True):
        out = super().get_params(deep=deep)
        if not deep:
            return out

        estimators = getattr(self, attr)
        try:
            out.update(estimators)
        except (TypeError, ValueError):
            # Ignore TypeError for cases where estimators is not a list of
            # (name, estimator) and ignore ValueError when the list is not
            # formatted correctly. This is to prevent errors when calling
            # `set_params`. `BaseEstimator.set_params` calls `get_params` which
            # can error for invalid values for `estimators`.
            return out

        for name, estimator in estimators:
            if hasattr(estimator, "get_params"):
                for key, value in estimator.get_params(deep=True).items():
                    out["%s__%s" % (name, key)] = value
        return out

    def _set_params(self, attr, **params):
        # Ensure strict ordering of parameter setting:
        # 1. All steps
        if attr in params:
            setattr(self, attr, params.pop(attr))
        # 2. Replace items with estimators in params
        items = getattr(self, attr)
        if isinstance(items, list) and items:
            # Get item names used to identify valid names in params.
            # `zip` raises a TypeError when `items` does not contain
            # elements of length 2.
            with suppress(TypeError):
                item_names, _ = zip(*items)
                for name in list(params.keys()):
                    if "__" not in name and name in item_names:
                        self._replace_estimator(attr, name, params.pop(name))

        # 3. Step parameters and other initialisation arguments
        super().set_params(**params)
        return self

    def _replace_estimator(self, attr, name, new_val):
        # assumes `name` is a valid estimator name
        new_estimators = list(getattr(self, attr))
        for i, (estimator_name, _) in enumerate(new_estimators):
            if estimator_name == name:
                new_estimators[i] = (name, new_val)
                break
        setattr(self, attr, new_estimators)

    def _validate_names(self, names):
        if len(set(names)) != len(names):
            raise ValueError("Names provided are not unique: {0!r}".format(list(names)))
        invalid_names = set(names).intersection(self.get_params(deep=False))
        if invalid_names:
            raise ValueError(
                "Estimator names conflict with constructor arguments: {0!r}".format(
                    sorted(invalid_names)
                )
            )
        invalid_names = [name for name in names if "__" in name]
        if invalid_names:
            raise ValueError(
                "Estimator names must not contain __: got {0!r}".format(invalid_names)
            )
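

# Illustrative sketch only (hypothetical, not part of scikit-learn's API):
# a minimal subclass showing how `_BaseComposition._get_params` and
# `_set_params` expose nested estimator parameters via the
# ``<name>__<param>`` convention. The class and attribute names below are
# assumptions made for illustration.
class _ExamplePair(_BaseComposition):
    """Toy composition of named estimators, for illustration only.

    Examples
    --------
    >>> from sklearn.linear_model import LogisticRegression
    >>> pair = _ExamplePair([("clf", LogisticRegression(C=1.0))])
    >>> pair.get_params(deep=True)["clf__C"]
    1.0
    >>> pair.set_params(clf__C=10.0).get_params(deep=True)["clf__C"]
    10.0
    """

    def __init__(self, estimators):
        self.estimators = estimators

    def get_params(self, deep=True):
        # Own constructor params plus ``<name>__<param>`` entries per step.
        return self._get_params("estimators", deep=deep)

    def set_params(self, **params):
        # Accepts ``estimators``, whole-step replacement by name, and
        # ``<name>__<param>`` updates, applied in that order.
        self._set_params("estimators", **params)
        return self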


def _safe_split(estimator, X, y, indices, train_indices=None):
    """Create subset of dataset and properly handle kernels.

    Slice X, y according to indices for cross-validation, but take care of
    precomputed kernel matrices or pairwise affinities / distances.

    If ``estimator._pairwise is True``, X needs to be square and
    we slice rows and columns. If ``train_indices`` is not None,
    we slice rows using ``indices`` (assumed to be the test set) and columns
    using ``train_indices``, indicating the training set.

    Labels y will always be indexed only along the first axis.

    Parameters
    ----------
    estimator : object
        Estimator to determine whether we should slice only rows or rows and
        columns.

    X : array-like, sparse matrix or iterable
        Data to be indexed. If ``estimator._pairwise is True``,
        this needs to be a square array-like or sparse matrix.

    y : array-like, sparse matrix or iterable
        Targets to be indexed.

    indices : array of int
        Rows to select from X and y.
        If ``estimator._pairwise is True`` and ``train_indices is None``
        then ``indices`` will also be used to slice columns.

    train_indices : array of int or None, default=None
        If ``estimator._pairwise is True`` and ``train_indices is not None``,
        then ``train_indices`` will be used to slice the columns of X.

    Returns
    -------
    X_subset : array-like, sparse matrix or list
        Indexed data.

    y_subset : array-like, sparse matrix or list
        Indexed targets.

    """
    if _safe_tags(estimator, key="pairwise"):
        if not hasattr(X, "shape"):
            raise ValueError(
                "Precomputed kernels or affinity matrices have "
                "to be passed as arrays or sparse matrices."
            )
        # X is a precomputed square kernel matrix
        if X.shape[0] != X.shape[1]:
            raise ValueError("X should be a square kernel matrix")
        if train_indices is None:
            X_subset = X[np.ix_(indices, indices)]
        else:
            X_subset = X[np.ix_(indices, train_indices)]
    else:
        X_subset = _safe_indexing(X, indices)

    if y is not None:
        y_subset = _safe_indexing(y, indices)
    else:
        y_subset = None

    return X_subset, y_subset
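

# Usage sketch (illustrative only; the guard keeps it from executing on
# import, e.g. run via ``python -m sklearn.utils.metaestimators``). With a
# pairwise estimator such as SVC(kernel="precomputed"), `_safe_split` slices
# both rows and columns of the kernel matrix, whereas a non-pairwise
# estimator only has its rows sliced. The toy data and index arrays below
# are assumptions made for illustration.
if __name__ == "__main__":
    from sklearn.svm import SVC

    rng = np.random.RandomState(0)
    X_raw = rng.rand(6, 2)
    K = X_raw @ X_raw.T  # square (6, 6) "precomputed" kernel matrix
    y = np.array([0, 1, 0, 1, 0, 1])
    train, test = np.array([0, 1, 2, 3]), np.array([4, 5])

    est = SVC(kernel="precomputed")
    # Training fold: rows and columns both indexed by `train` -> shape (4, 4).
    K_train, y_train = _safe_split(est, K, y, train)
    # Test fold: rows indexed by `test`, columns by `train` -> shape (2, 4).
    K_test, y_test = _safe_split(est, K, y, test, train_indices=train)
    print(K_train.shape, K_test.shape)  # (4, 4) (2, 4)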