1import html
2from contextlib import closing
3from inspect import isclass
4from io import StringIO
5from pathlib import Path
6from string import Template
7
8from .. import __version__, config_context
9from .fixes import parse_version
10
11
class _IDCounter:
    """Hand out ids of the form ``<prefix>-<n>``, with ``n`` counting up from 1.

    Parameters
    ----------
    prefix : str
        Text placed in front of the running number of every generated id.
    """

    def __init__(self, prefix):
        self.prefix = prefix
        self.count = 0

    def get_id(self):
        """Advance the counter by one and return the id built from the new value."""
        current = self.count + 1
        self.count = current
        return f"{self.prefix}-{current}"
22
23
def _get_css_style():
    """Return the text of the stylesheet stored next to this module.

    The stylesheet path is this module's own path with its suffix replaced by
    ``.css``; the file is decoded as UTF-8.
    """
    stylesheet_path = Path(__file__).with_suffix(".css")
    return stylesheet_path.read_text(encoding="utf-8")
26
27
# Module-level id generators: each call to `get_id()` yields the next sequential id
# for its prefix, so ids handed out by one counter are never repeated in a process.
_CONTAINER_ID_COUNTER = _IDCounter("sk-container-id")
_ESTIMATOR_ID_COUNTER = _IDCounter("sk-estimator-id")
# Stylesheet text, read once at import time from the `.css` file that sits next to
# this module.
_CSS_STYLE = _get_css_style()
31
32
class _VisualBlock:
    """Description of how an estimator is laid out in the HTML diagram.

    Parameters
    ----------
    kind : {'serial', 'parallel', 'single'}
        Kind of HTML block.

    estimators : list of estimators or `_VisualBlock`s or a single estimator
        For 'serial' and 'parallel' blocks, a list of estimators.
        For a 'single' block, the estimator itself.

    names : list of str, default=None
        For 'serial' and 'parallel' blocks, one name per entry of `estimators`;
        when left to None, it is replaced by a tuple of None of the same length.
        For a 'single' block, a single string naming the estimator.

    name_details : list of str, str, or None, default=None
        For 'serial' and 'parallel' blocks, one detail entry per entry of `names`;
        when left to None, it is replaced by a tuple of None of the same length.
        For a 'single' block, a single string describing the estimator.

    dash_wrapped : bool, default=True
        If true, wrapped HTML element will be wrapped with a dashed border.
        Only active when kind != 'single'.
    """

    def __init__(
        self, kind, estimators, *, names=None, name_details=None, dash_wrapped=True
    ):
        self.kind = kind
        self.estimators = estimators
        self.dash_wrapped = dash_wrapped

        # Only composite blocks get per-estimator placeholders; a 'single' block
        # keeps whatever was passed in (including None).
        is_composite = kind in ("parallel", "serial")
        if is_composite and names is None:
            names = (None,) * len(estimators)
        if is_composite and name_details is None:
            name_details = (None,) * len(estimators)

        self.names = names
        self.name_details = name_details

    def _sk_visual_block_(self):
        """Return this block itself: a `_VisualBlock` is its own visual description."""
        return self
79
80
def _write_label_html(
    out,
    name,
    name_details,
    outer_class="sk-label-container",
    inner_class="sk-label",
    checked=False,
    doc_link="",
    is_fitted_css_class="",
    is_fitted_icon="",
):
    """Write labeled html with or without a dropdown with named details.

    Parameters
    ----------
    out : file-like object
        The file to write the HTML representation to.
    name : str
        The label for the estimator. It corresponds either to the estimator class name
        for a simple estimator or in the case of a `Pipeline` and `ColumnTransformer`,
        it corresponds to the name of the step. It is HTML-escaped before being
        written.
    name_details : str
        The details to show as content in the dropdown part of the toggleable label. It
        can contain information such as non-default parameters or column information for
        `ColumnTransformer`. It is converted with `str` and HTML-escaped. When None,
        a plain `<label>` without dropdown is written.
    outer_class : {"sk-label-container", "sk-item"}, default="sk-label-container"
        The CSS class for the outer container.
    inner_class : {"sk-label", "sk-estimator"}, default="sk-label"
        The CSS class for the inner container.
    checked : bool, default=False
        Whether the dropdown is folded or not. With a single estimator, we intend to
        unfold the content.
    doc_link : str, default=""
        The link to the documentation for the estimator. If an empty string, no link is
        added to the diagram. This can be generated for an estimator if it uses the
        `_HTMLDocumentationLinkMixin`. The URL is HTML-escaped before being placed in
        the `href` attribute.
    is_fitted_css_class : {"", "fitted"}
        The CSS class to indicate whether or not the estimator is fitted. The
        empty string means that the estimator is not fitted and "fitted" means that the
        estimator is fitted.
    is_fitted_icon : str, default=""
        The HTML representation to show the fitted information in the diagram. An empty
        string means that no information is shown. It is written as-is (it is already
        HTML).
    """
    # we need to add some padding to the left of the label to be sure it is centered
    padding_label = " " if is_fitted_icon else ""  # add padding for the "i" char

    out.write(
        f'<div class="{outer_class}"><div'
        f' class="{inner_class} {is_fitted_css_class} sk-toggleable">'
    )
    name = html.escape(name)

    if name_details is not None:
        name_details = html.escape(str(name_details))
        label_class = (
            f"sk-toggleable__label {is_fitted_css_class} sk-toggleable__label-arrow"
        )

        checked_str = "checked" if checked else ""
        est_id = _ESTIMATOR_ID_COUNTER.get_id()

        if doc_link:
            # `name` is always a string here (it went through `html.escape` above),
            # so the tooltip can always mention it.
            doc_label = f"<span>Documentation for {name}</span>"
            # The URL may come from a user-supplied template or parameter generator:
            # escape it so that `"`, `&`, `<` or `>` cannot break out of the `href`
            # attribute. URLs without such characters are left untouched.
            doc_href = html.escape(doc_link, quote=True)
            doc_link = (
                f'<a class="sk-estimator-doc-link {is_fitted_css_class}"'
                f' rel="noreferrer" target="_blank" href="{doc_href}">?{doc_label}</a>'
            )
            padding_label += " "  # add additional padding for the "?" char

        fmt_str = (
            '<input class="sk-toggleable__control sk-hidden--visually"'
            f' id="{est_id}" '
            f'type="checkbox" {checked_str}><label for="{est_id}" '
            f'class="{label_class} {is_fitted_css_class}">{padding_label}{name}'
            f"{doc_link}{is_fitted_icon}</label><div "
            f'class="sk-toggleable__content {is_fitted_css_class}">'
            f"<pre>{name_details}</pre></div> "
        )
        out.write(fmt_str)
    else:
        out.write(f"<label>{name}</label>")
    out.write("</div></div>")  # outer_class inner_class
166
167
def _get_visual_block(estimator):
    """Describe how `estimator` should be drawn, as a `_VisualBlock`.

    The description is looked up in this order:

    - the result of `estimator._sk_visual_block_()` when that attribute exists; if
      the call raises, a 'single' block labelled with the class name is used instead;
    - a 'single' block for a string or for None;
    - a 'parallel' block when `estimator` is an instance exposing `get_params` and
      at least one of its shallow parameters looks like an estimator instance (has
      `get_params` and `fit`);
    - otherwise a 'single' block labelled with the class name.
    """

    def _single(label, details):
        return _VisualBlock("single", estimator, names=label, name_details=details)

    if hasattr(estimator, "_sk_visual_block_"):
        try:
            return estimator._sk_visual_block_()
        except Exception:
            # best effort: fall back to the generic single-block rendering
            return _single(estimator.__class__.__name__, str(estimator))

    if isinstance(estimator, str):
        return _single(estimator, estimator)
    if estimator is None:
        return _single("None", "None")

    # check if estimator looks like a meta estimator (wraps estimators)
    if hasattr(estimator, "get_params") and not isclass(estimator):
        wrapped = [
            (param_name, param_value)
            for param_name, param_value in estimator.get_params(deep=False).items()
            if hasattr(param_value, "get_params")
            and hasattr(param_value, "fit")
            and not isclass(param_value)
        ]
        if wrapped:
            sub_estimators = [value for _, value in wrapped]
            labels = [f"{key}: {value.__class__.__name__}" for key, value in wrapped]
            details = [str(value) for _, value in wrapped]
            return _VisualBlock(
                "parallel", sub_estimators, names=labels, name_details=details
            )

    return _single(estimator.__class__.__name__, str(estimator))
209
210
def _write_estimator_html(
    out,
    estimator,
    estimator_label,
    estimator_label_details,
    is_fitted_css_class,
    is_fitted_icon="",
    first_call=False,
):
    """Write estimator to html in serial, parallel, or by itself (single).

    For 'serial' and 'parallel' blocks, this function recurses into each child.

    Parameters
    ----------
    out : file-like object
        The file to write the HTML representation to.
    estimator : estimator object
        The estimator to visualize. It can also be a `_VisualBlock` instance.
    estimator_label : str
        The label for the estimator. It corresponds either to the estimator class name
        for simple estimator or in the case of `Pipeline` and `ColumnTransformer`, it
        corresponds to the name of the step.
    estimator_label_details : str
        The details to show as content in the dropdown part of the toggleable label.
        It can contain information as non-default parameters or column information for
        `ColumnTransformer`.
    is_fitted_css_class : {"", "fitted"}
        The CSS class to indicate whether or not the estimator is fitted or not. The
        empty string means that the estimator is not fitted and "fitted" means that the
        estimator is fitted.
    is_fitted_icon : str, default=""
        The HTML representation to show the fitted information in the diagram. An empty
        string means that no information is shown. If the estimator to be shown is not
        the first estimator (i.e. `first_call=False`), `is_fitted_icon` is always an
        empty string.
    first_call : bool, default=False
        Whether this is the first time this function is called.
    """
    if not first_call:
        # nested estimators never show the fitted icon and are described with their
        # changed parameters only
        is_fitted_icon = ""
        with config_context(print_changed_only=True):
            est_block = _get_visual_block(estimator)
    else:
        est_block = _get_visual_block(estimator)

    # `estimator` can also be an instance of `_VisualBlock`
    doc_link = estimator._get_doc_link() if hasattr(estimator, "_get_doc_link") else ""

    if est_block.kind == "single":
        _write_label_html(
            out,
            est_block.names,
            est_block.name_details,
            outer_class="sk-item",
            inner_class="sk-estimator",
            checked=first_call,
            doc_link=doc_link,
            is_fitted_css_class=is_fitted_css_class,
            is_fitted_icon=is_fitted_icon,
        )
        return

    if est_block.kind not in ("serial", "parallel"):
        # unknown kinds produce no output
        return

    if first_call or est_block.dash_wrapped:
        dash_cls = " sk-dashed-wrapped"
    else:
        dash_cls = ""
    out.write(f'<div class="sk-item{dash_cls}">')

    if estimator_label:
        _write_label_html(
            out,
            estimator_label,
            estimator_label_details,
            doc_link=doc_link,
            is_fitted_css_class=is_fitted_css_class,
            is_fitted_icon=is_fitted_icon,
        )

    kind = est_block.kind
    is_serial = kind == "serial"
    out.write(f'<div class="sk-{kind}">')

    children = zip(est_block.estimators, est_block.names, est_block.name_details)
    for child, child_name, child_details in children:
        if is_serial:
            to_render = child
        else:  # parallel
            out.write('<div class="sk-parallel-item">')
            # wrap element in a serial visualblock
            to_render = _VisualBlock("serial", [child], dash_wrapped=False)

        _write_estimator_html(
            out,
            to_render,
            child_name,
            child_details,
            is_fitted_css_class=is_fitted_css_class,
        )

        if not is_serial:
            out.write("</div>")  # sk-parallel-item

    out.write("</div></div>")
315
316
def estimator_html_repr(estimator):
    """Build a HTML representation of an estimator.

    Read more in the :ref:`User Guide <visualizing_composite_estimators>`.

    Parameters
    ----------
    estimator : estimator object
        The estimator to visualize.

    Returns
    -------
    html: str
        HTML representation of estimator.
    """
    from sklearn.exceptions import NotFittedError
    from sklearn.utils.validation import check_is_fitted

    # Start from the "not fitted" state; it is only upgraded when `estimator` has a
    # `fit` attribute and `check_is_fitted` does not raise `NotFittedError`.
    status_label = "<span>Not fitted</span>"
    is_fitted_css_class = ""
    if hasattr(estimator, "fit"):
        try:
            check_is_fitted(estimator)
        except NotFittedError:
            pass
        else:
            status_label = "<span>Fitted</span>"
            is_fitted_css_class = "fitted"

    is_fitted_icon = (
        '<span class="sk-estimator-doc-link '
        + is_fitted_css_class
        + '">i'
        + status_label
        + "</span>"
    )

    with closing(StringIO()) as out:
        container_id = _CONTAINER_ID_COUNTER.get_id()
        # the stylesheet is a template in which `$id` stands for the container id
        scoped_css = Template(_CSS_STYLE).substitute(id=container_id)
        estimator_str = str(estimator)

        # The fallback message is shown by default and loading the CSS sets
        # div.sk-text-repr-fallback to display: none to hide the fallback message.
        #
        # If the notebook is trusted, the CSS is loaded which hides the fallback
        # message. If the notebook is not trusted, then the CSS is not loaded and the
        # fallback message is shown by default.
        #
        # The reverse logic applies to HTML repr div.sk-container.
        # div.sk-container is hidden by default and loading the CSS displays it.
        fallback_msg = (
            "In a Jupyter environment, please rerun this cell to show the HTML"
            " representation or trust the notebook. <br />On GitHub, the"
            " HTML representation is unable to render, please try loading this page"
            " with nbviewer.org."
        )
        header_parts = [
            f"<style>{scoped_css}</style>",
            f'<div id="{container_id}" class="sk-top-container">',
            '<div class="sk-text-repr-fallback">',
            f"<pre>{html.escape(estimator_str)}</pre><b>{fallback_msg}</b>",
            "</div>",
            '<div class="sk-container" hidden>',
        ]
        out.write("".join(header_parts))

        _write_estimator_html(
            out,
            estimator,
            estimator.__class__.__name__,
            estimator_str,
            first_call=True,
            is_fitted_css_class=is_fitted_css_class,
            is_fitted_icon=is_fitted_icon,
        )
        # closes div.sk-container and div.sk-top-container
        out.write("</div></div>")

        return out.getvalue()
396
397
class _HTMLDocumentationLinkMixin:
    """Mixin class allowing to generate a link to the API documentation.

    This mixin relies on three attributes:
    - `_doc_link_module`: it corresponds to the root module (e.g. `sklearn`). Using this
      mixin, the default value is `sklearn`.
    - `_doc_link_template`: it corresponds to the template used to generate the
      link to the API documentation. Using this mixin, the default value is
      `"https://scikit-learn.org/{version_url}/modules/generated/
      {estimator_module}.{estimator_name}.html"`.
    - `_doc_link_url_param_generator`: it corresponds to a function that generates the
      parameters to be used in the template when the estimator module and name are not
      sufficient.

    The method :meth:`_get_doc_link` generates the link to the API documentation for a
    given estimator.

    This mixin provides all the necessary states for
    :func:`sklearn.utils.estimator_html_repr` to generate a link to the API
    documentation for the estimator HTML diagram.

    Examples
    --------
    If the default values for `_doc_link_module`, `_doc_link_template` are not suitable,
    then you can override them:
    >>> from sklearn.base import BaseEstimator
    >>> estimator = BaseEstimator()
    >>> estimator._doc_link_template = "https://website.com/{single_param}.html"
    >>> def url_param_generator(estimator):
    ...     return {"single_param": estimator.__class__.__name__}
    >>> estimator._doc_link_url_param_generator = url_param_generator
    >>> estimator._get_doc_link()
    'https://website.com/BaseEstimator.html'
    """

    _doc_link_module = "sklearn"
    _doc_link_url_param_generator = None

    @property
    def _doc_link_template(self):
        """Template of the documentation URL.

        Returns the value stored through the setter when there is one; otherwise the
        default scikit-learn URL template, pointing to the "dev" docs for a
        development version and to "<major>.<minor>" for any other version.
        """
        parsed_version = parse_version(__version__)
        if parsed_version.dev is not None:
            version_url = "dev"
        else:
            version_url = f"{parsed_version.major}.{parsed_version.minor}"
        default_template = (
            f"https://scikit-learn.org/{version_url}/modules/generated/"
            "{estimator_module}.{estimator_name}.html"
        )
        return getattr(self, "__doc_link_template", default_template)

    @_doc_link_template.setter
    def _doc_link_template(self, value):
        # stored under a literal attribute name, which the getter reads back
        setattr(self, "__doc_link_template", value)

    def _get_doc_link(self):
        """Generates a link to the API documentation for a given estimator.

        This method generates the link to the estimator's documentation page
        by using the template defined by the attribute `_doc_link_template`.

        Returns
        -------
        url : str
            The URL to the API documentation for this estimator. If the estimator does
            not belong to module `_doc_link_module`, the empty string (i.e. `""`) is
            returned.
        """
        module_parts = self.__class__.__module__.split(".")
        if module_parts[0] != self._doc_link_module:
            return ""

        if self._doc_link_url_param_generator is not None:
            return self._doc_link_template.format(
                **self._doc_link_url_param_generator(self)
            )

        estimator_name = self.__class__.__name__
        # private sub-modules (leading underscore) are not part of the public path
        public_parts = [part for part in module_parts if not part.startswith("_")]
        estimator_module = ".".join(public_parts)
        return self._doc_link_template.format(
            estimator_module=estimator_module, estimator_name=estimator_name
        )