Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/scikit_learn-1.4.dev0-py3.8-linux-x86_64.egg/sklearn/utils/discovery.py: 12%

83 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-12-12 06:31 +0000

1""" 

2The :mod:`sklearn.utils.discovery` module includes utilities to discover 

3objects (i.e. estimators, displays, functions) from the `sklearn` package. 

4""" 

5 

6import inspect 

7import pkgutil 

8from importlib import import_module 

9from operator import itemgetter 

10from pathlib import Path 

11 

12_MODULE_TO_IGNORE = { 

13 "tests", 

14 "externals", 

15 "setup", 

16 "conftest", 

17 "experimental", 

18 "estimator_checks", 

19} 

20 

21 

def all_estimators(type_filter=None):
    """Get a list of all estimators from `sklearn`.

    This function crawls the module and gets all classes that inherit
    from BaseEstimator. Classes that are defined in test-modules are not
    included.

    Parameters
    ----------
    type_filter : {"classifier", "regressor", "cluster", "transformer"} \
            or list of such str, default=None
        Which kind of estimators should be returned. If None, no filter is
        applied and all estimators are returned.  Possible values are
        'classifier', 'regressor', 'cluster' and 'transformer' to get
        estimators only of these specific types, or a list of these to
        get the estimators that fit at least one of the types.

    Returns
    -------
    estimators : list of tuples
        List of (name, class), where ``name`` is the class name as string
        and ``class`` is the actual type of the class. The list is free of
        duplicates and sorted by name.

    Raises
    ------
    ValueError
        If `type_filter` contains a value other than 'classifier',
        'regressor', 'transformer' or 'cluster'.
    """
    # lazy import to avoid circular imports from sklearn.base
    from ..base import (
        BaseEstimator,
        ClassifierMixin,
        ClusterMixin,
        RegressorMixin,
        TransformerMixin,
    )
    from . import IS_PYPY
    from ._testing import ignore_warnings

    def is_abstract(c):
        # Abstract classes carry a non-empty ``__abstractmethods__``; classes
        # without the attribute (or with an empty one) are concrete.
        return bool(getattr(c, "__abstractmethods__", ()))

    all_classes = []
    root = str(Path(__file__).parent.parent)  # sklearn package
    # Ignore deprecation warnings triggered at import time and from walking
    # packages
    with ignore_warnings(category=FutureWarning):
        for _, module_name, _ in pkgutil.walk_packages(path=[root], prefix="sklearn."):
            module_parts = module_name.split(".")
            if (
                any(part in _MODULE_TO_IGNORE for part in module_parts)
                or "._" in module_name
            ):
                # private modules and explicitly ignored sub-packages
                continue
            module = import_module(module_name)
            classes = inspect.getmembers(module, inspect.isclass)
            classes = [
                (name, est_cls) for name, est_cls in classes if not name.startswith("_")
            ]

            # TODO: Remove when FeatureHasher is implemented in PYPY
            # Skips FeatureHasher for PYPY: drop it and keep every other
            # class found in the feature_extraction modules.
            if IS_PYPY and "feature_extraction" in module_name:
                classes = [
                    (name, est_cls)
                    for name, est_cls in classes
                    if name != "FeatureHasher"
                ]

            all_classes.extend(classes)

    # the same class is usually reachable from several modules
    all_classes = set(all_classes)

    estimators = [
        c
        for c in all_classes
        if (issubclass(c[1], BaseEstimator) and c[0] != "BaseEstimator")
    ]
    # get rid of abstract base classes
    estimators = [c for c in estimators if not is_abstract(c[1])]

    if type_filter is not None:
        if not isinstance(type_filter, list):
            type_filter = [type_filter]
        else:
            type_filter = list(type_filter)  # copy
        filtered_estimators = []
        filters = {
            "classifier": ClassifierMixin,
            "regressor": RegressorMixin,
            "transformer": TransformerMixin,
            "cluster": ClusterMixin,
        }
        for name, mixin in filters.items():
            if name in type_filter:
                # consume recognised names so that leftovers can be reported
                type_filter.remove(name)
                filtered_estimators.extend(
                    [est for est in estimators if issubclass(est[1], mixin)]
                )
        estimators = filtered_estimators
        if type_filter:
            raise ValueError(
                "Parameter type_filter must be 'classifier', "
                "'regressor', 'transformer', 'cluster' or "
                "None, got"
                f" {repr(type_filter)}."
            )

    # drop duplicates, sort for reproducibility
    # itemgetter is used to ensure the sort does not extend to the 2nd item of
    # the tuple
    return sorted(set(estimators), key=itemgetter(0))

133 

134 

def all_displays():
    """Get a list of all displays from `sklearn`.

    Every non-ignored, non-private module of the package is imported and
    scanned for public classes whose name ends with ``"Display"``.

    Returns
    -------
    displays : list of tuples
        List of (name, class), where ``name`` is the display class name as
        string and ``class`` is the actual type of the class. Duplicates are
        removed and the list is sorted by name.
    """
    # imported lazily, inside the function, to avoid circular imports
    from ._testing import ignore_warnings

    found = []
    sklearn_root = str(Path(__file__).parent.parent)  # sklearn package
    # FutureWarnings raised while importing / walking the packages are muted
    with ignore_warnings(category=FutureWarning):
        walker = pkgutil.walk_packages(path=[sklearn_root], prefix="sklearn.")
        for _, modname, _ in walker:
            # skip private modules
            if "._" in modname:
                continue
            # skip modules living under an ignored path component
            if not _MODULE_TO_IGNORE.isdisjoint(modname.split(".")):
                continue

            members = inspect.getmembers(import_module(modname), inspect.isclass)
            for cls_name, cls in members:
                if cls_name.endswith("Display") and not cls_name.startswith("_"):
                    found.append((cls_name, cls))

    # itemgetter keeps the sort on the name only, never comparing the classes
    unique_displays = set(found)
    return sorted(unique_displays, key=itemgetter(0))

169 

170 

171def _is_checked_function(item): 

172 if not inspect.isfunction(item): 

173 return False 

174 

175 if item.__name__.startswith("_"): 

176 return False 

177 

178 mod = item.__module__ 

179 if not mod.startswith("sklearn.") or mod.endswith("estimator_checks"): 

180 return False 

181 

182 return True 

183 

184 

def all_functions():
    """Get a list of all functions from `sklearn`.

    Every non-ignored, non-private module of the package is imported and its
    members are filtered with `_is_checked_function`.

    Returns
    -------
    functions : list of tuples
        List of (name, function), where ``name`` is the function name as
        string and ``function`` is the actual function. Duplicates are
        removed and the list is sorted by name.
    """
    # imported lazily, inside the function, to avoid circular imports
    from ._testing import ignore_warnings

    collected = []
    sklearn_root = str(Path(__file__).parent.parent)  # sklearn package
    # FutureWarnings raised while importing / walking the packages are muted
    with ignore_warnings(category=FutureWarning):
        walker = pkgutil.walk_packages(path=[sklearn_root], prefix="sklearn.")
        for _, modname, _ in walker:
            # skip private modules
            if "._" in modname:
                continue
            # skip modules living under an ignored path component
            if not _MODULE_TO_IGNORE.isdisjoint(modname.split(".")):
                continue

            members = inspect.getmembers(import_module(modname), _is_checked_function)
            for attr_name, func in members:
                # the filter is on the attribute name, while the reported
                # name is the function's own ``__name__``
                if attr_name.startswith("_"):
                    continue
                collected.append((func.__name__, func))

    # drop duplicates, then sort on the name only so that the function
    # objects themselves are never compared
    unique_functions = set(collected)
    return sorted(unique_functions, key=itemgetter(0))