1"""
2The :mod:`sklearn.utils.discovery` module includes utilities to discover
3objects (i.e. estimators, displays, functions) from the `sklearn` package.
4"""
5
6import inspect
7import pkgutil
8from importlib import import_module
9from operator import itemgetter
10from pathlib import Path
11
12_MODULE_TO_IGNORE = {
13 "tests",
14 "externals",
15 "setup",
16 "conftest",
17 "experimental",
18 "estimator_checks",
19}
20
21
def all_estimators(type_filter=None):
    """Get a list of all estimators from `sklearn`.

    This function crawls the module and gets all classes that inherit
    from BaseEstimator. Classes that are defined in test-modules are not
    included.

    Parameters
    ----------
    type_filter : {"classifier", "regressor", "cluster", "transformer"} \
            or list of such str, default=None
        Which kind of estimators should be returned. If None, no filter is
        applied and all estimators are returned.  Possible values are
        'classifier', 'regressor', 'cluster' and 'transformer' to get
        estimators only of these specific types, or a list of these to
        get the estimators that fit at least one of the types.

    Returns
    -------
    estimators : list of tuples
        List of (name, class), where ``name`` is the class name as string
        and ``class`` is the actual type of the class. The list is free of
        duplicates and sorted by ``name``.

    Raises
    ------
    ValueError
        If `type_filter` contains a value that is not one of the supported
        filter names.
    """
    # lazy import to avoid circular imports from sklearn.base
    from ..base import (
        BaseEstimator,
        ClassifierMixin,
        ClusterMixin,
        RegressorMixin,
        TransformerMixin,
    )
    from . import IS_PYPY
    from ._testing import ignore_warnings

    def is_abstract(c):
        # Abstract classes carry a non-empty ``__abstractmethods__`` set.
        return bool(getattr(c, "__abstractmethods__", None))

    all_classes = []
    root = str(Path(__file__).parent.parent)  # sklearn package
    # Ignore deprecation warnings triggered at import time and from walking
    # packages
    with ignore_warnings(category=FutureWarning):
        for _, module_name, _ in pkgutil.walk_packages(path=[root], prefix="sklearn."):
            module_parts = module_name.split(".")
            # Skip ignored components and any private (underscore) module.
            if (
                any(part in _MODULE_TO_IGNORE for part in module_parts)
                or "._" in module_name
            ):
                continue
            module = import_module(module_name)
            classes = inspect.getmembers(module, inspect.isclass)
            classes = [
                (name, est_cls) for name, est_cls in classes if not name.startswith("_")
            ]

            # TODO: Remove when FeatureHasher is implemented in PYPY
            # Skips FeatureHasher for PYPY: drop that single class and keep
            # every other class found in the feature_extraction modules.
            if IS_PYPY and "feature_extraction" in module_name:
                classes = [
                    (name, est_cls)
                    for name, est_cls in classes
                    if name != "FeatureHasher"
                ]

            all_classes.extend(classes)

    all_classes = set(all_classes)

    estimators = [
        c
        for c in all_classes
        if (issubclass(c[1], BaseEstimator) and c[0] != "BaseEstimator")
    ]
    # get rid of abstract base classes
    estimators = [c for c in estimators if not is_abstract(c[1])]

    if type_filter is not None:
        if not isinstance(type_filter, list):
            type_filter = [type_filter]
        else:
            type_filter = list(type_filter)  # copy, the caller's list is not mutated
        filtered_estimators = []
        filters = {
            "classifier": ClassifierMixin,
            "regressor": RegressorMixin,
            "transformer": TransformerMixin,
            "cluster": ClusterMixin,
        }
        for name, mixin in filters.items():
            if name in type_filter:
                # Consume recognised names so that whatever is left over
                # afterwards is known to be invalid.
                type_filter.remove(name)
                filtered_estimators.extend(
                    [est for est in estimators if issubclass(est[1], mixin)]
                )
        estimators = filtered_estimators
        if type_filter:
            raise ValueError(
                "Parameter type_filter must be 'classifier', "
                "'regressor', 'transformer', 'cluster' or "
                "None, got"
                f" {repr(type_filter)}."
            )

    # drop duplicates, sort for reproducibility
    # itemgetter is used to ensure the sort does not extend to the 2nd item of
    # the tuple
    return sorted(set(estimators), key=itemgetter(0))
133
134
def all_displays():
    """Collect every display class exposed by `sklearn`.

    The modules of the package are imported one by one, skipping private
    modules and ignored components, and scanned for public classes whose
    name ends with ``"Display"``.

    Returns
    -------
    displays : list of tuples
        List of (name, class), where ``name`` is the display class name as
        string and ``class`` is the actual type of the class. Duplicates are
        removed and the list is sorted by ``name``.
    """
    # Imported lazily rather than at module level (presumably to avoid a
    # circular import).
    from ._testing import ignore_warnings

    sklearn_dir = Path(__file__).parent.parent  # sklearn package
    found = []
    # Importing and walking the package may emit FutureWarning; silence it.
    with ignore_warnings(category=FutureWarning):
        walker = pkgutil.walk_packages(path=[str(sklearn_dir)], prefix="sklearn.")
        for _, modname, _ in walker:
            # Private modules (an underscore-prefixed component) are skipped.
            if "._" in modname:
                continue
            # So is anything with an ignored component in its dotted path.
            if not _MODULE_TO_IGNORE.isdisjoint(modname.split(".")):
                continue

            members = inspect.getmembers(import_module(modname), inspect.isclass)
            for attr_name, candidate in members:
                if attr_name.startswith("_") or not attr_name.endswith("Display"):
                    continue
                found.append((attr_name, candidate))

    # Sort on the name only so that classes themselves are never compared.
    return sorted(set(found), key=itemgetter(0))
169
170
171def _is_checked_function(item):
172 if not inspect.isfunction(item):
173 return False
174
175 if item.__name__.startswith("_"):
176 return False
177
178 mod = item.__module__
179 if not mod.startswith("sklearn.") or mod.endswith("estimator_checks"):
180 return False
181
182 return True
183
184
def all_functions():
    """Collect every public function exposed by `sklearn`.

    The modules of the package are imported one by one, skipping private
    modules and ignored components, and their members are filtered with
    `_is_checked_function`.

    Returns
    -------
    functions : list of tuples
        List of (name, function), where ``name`` is the function name as
        string (taken from ``function.__name__``) and ``function`` is the
        actual function. Duplicates are removed and the list is sorted by
        ``name``.
    """
    # Imported lazily rather than at module level (presumably to avoid a
    # circular import).
    from ._testing import ignore_warnings

    sklearn_dir = Path(__file__).parent.parent  # sklearn package
    collected = []
    # Importing and walking the package may emit FutureWarning; silence it.
    with ignore_warnings(category=FutureWarning):
        walker = pkgutil.walk_packages(path=[str(sklearn_dir)], prefix="sklearn.")
        for _, modname, _ in walker:
            # Private modules (an underscore-prefixed component) are skipped.
            if "._" in modname:
                continue
            # So is anything with an ignored component in its dotted path.
            if not _MODULE_TO_IGNORE.isdisjoint(modname.split(".")):
                continue

            module = import_module(modname)
            for attr_name, func in inspect.getmembers(module, _is_checked_function):
                # The attribute name decides visibility, but the function's
                # own ``__name__`` is what gets reported.
                if not attr_name.startswith("_"):
                    collected.append((func.__name__, func))

    # drop duplicates, sort for reproducibility
    # Sort on the name only so that functions themselves are never compared.
    return sorted(set(collected), key=itemgetter(0))