1"""
2Metadata Routing Utility
3
4In order to better understand the components implemented in this file, one
5needs to understand their relationship to one another.
6
The only relevant public API for end users is the set of
``set_{method}_request`` methods, e.g.
``estimator.set_fit_request(sample_weight=True)``. However, third-party
developers and users who implement custom meta-estimators need to deal with
the objects implemented in this file.
11
12All estimators (should) implement a ``get_metadata_routing`` method, returning
13the routing requests set for the estimator. This method is automatically
14implemented via ``BaseEstimator`` for all simple estimators, but needs a custom
15implementation for meta-estimators.
16
17In non-routing consumers, i.e. the simplest case, e.g. ``SVM``,
18``get_metadata_routing`` returns a ``MetadataRequest`` object.
19
20In routers, e.g. meta-estimators and a multi metric scorer,
21``get_metadata_routing`` returns a ``MetadataRouter`` object.
22
For an object which is both a router and a consumer, e.g. a meta-estimator
which consumes ``sample_weight`` and routes ``sample_weight`` to its
sub-estimators, the routing information includes both information about the
object itself (added via ``MetadataRouter.add_self_request``), as well as the
routing information for its sub-estimators.
28
29A ``MetadataRequest`` instance includes one ``MethodMetadataRequest`` per
30method in ``METHODS``, which includes ``fit``, ``score``, etc.
31
Request values are added to the routing mechanism by adding them to
``MethodMetadataRequest`` instances, e.g.
``metadatarequest.fit.add_request(param="sample_weight", alias="my_weights")``.
This is used in ``set_{method}_request`` which are automatically generated, so
users and developers almost never need to directly call methods on a
``MethodMetadataRequest``.

The ``alias`` above in the ``add_request`` method has to be either a string
(an alias), or one of ``{True (requested), False (unrequested), None (error if
passed)}``. There are some other special values such as ``UNUSED`` and ``WARN``
which are used for purposes such as warning of removing a metadata in a child
class, but not used by the end users.
44
45``MetadataRouter`` includes information about sub-objects' routing and how
46methods are mapped together. For instance, the information about which methods
47of a sub-estimator are called in which methods of the meta-estimator are all
48stored here. Conceptually, this information looks like:
49
50```
51{
52 "sub_estimator1": (
53 mapping=[(caller="fit", callee="transform"), ...],
54 router=MetadataRequest(...), # or another MetadataRouter
55 ),
56 ...
57}
58```
59
60To give the above representation some structure, we use the following objects:
61
- ``(callee, caller)`` is a namedtuple called ``MethodPair``
63
64- The list of ``MethodPair`` stored in the ``mapping`` field is a
65 ``MethodMapping`` object
66
67- ``(mapping=..., router=...)`` is a namedtuple called ``RouterMappingPair``
68
69The ``set_{method}_request`` methods are dynamically generated for estimators
70which inherit from the ``BaseEstimator``. This is done by attaching instances
71of the ``RequestMethod`` descriptor to classes, which is done in the
72``_MetadataRequester`` class, and ``BaseEstimator`` inherits from this mixin.
73This mixin also implements the ``get_metadata_routing``, which meta-estimators
74need to override, but it works for simple consumers as is.
75"""
76
77# Author: Adrin Jalali <adrin.jalali@gmail.com>
78# License: BSD 3 clause
79
80import inspect
81from collections import namedtuple
82from copy import deepcopy
83from typing import TYPE_CHECKING, Optional, Union
84from warnings import warn
85
86from .. import get_config
87from ..exceptions import UnsetMetadataPassedError
88from ._bunch import Bunch
89
# The routing mechanism only knows about the methods listed below. At the
# moment, supporting an additional method means monkeypatching this list.
# Note that if this list is changed or monkeypatched, the corresponding method
# needs to be added under a TYPE_CHECKING condition like the one done here in
# _MetadataRequester
SIMPLE_METHODS = [
    "fit",
    "partial_fit",
    "predict",
    "predict_proba",
    "predict_log_proba",
    "decision_function",
    "score",
    "split",
    "transform",
    "inverse_transform",
]

# Composite methods are made of several simple methods. Their requests cannot
# be set directly; they are derived from the requests of the simple methods
# listed as their value.
COMPOSITE_METHODS = dict(
    fit_transform=["fit", "transform"],
    fit_predict=["fit", "predict"],
)

# Every method name known to the routing mechanism: simple ones first, then
# the composite ones in their declaration order.
METHODS = [*SIMPLE_METHODS, *COMPOSITE_METHODS]
117
118
def _routing_enabled():
    """Tell whether the metadata routing feature is switched on.

    .. versionadded:: 1.3

    Returns
    -------
    enabled : bool
        The value stored under ``"enable_metadata_routing"`` in the
        configuration returned by ``get_config()``. When that key is absent,
        ``False`` is returned.
    """
    config = get_config()
    return config.get("enable_metadata_routing", False)
131
132
def _raise_for_params(params, owner, method):
    """Reject metadata which is passed while metadata routing is disabled.

    .. versionadded:: 1.4

    Parameters
    ----------
    params : dict
        The metadata passed to a method.

    owner : object
        The object to which the method belongs. Only its class name is used,
        to build the error message.

    method : str
        The name of the method, e.g. "fit". If falsy, only the class name of
        ``owner`` appears in the error message.

    Raises
    ------
    ValueError
        If ``params`` is not empty while metadata routing is not enabled.
    """
    # Nothing to report when routing is enabled or when no metadata is given.
    if _routing_enabled() or not params:
        return

    owner_name = owner.__class__.__name__
    caller = f"{owner_name}.{method}" if method else owner_name
    raise ValueError(
        f"Passing extra keyword arguments to {caller} is only supported if"
        " enable_metadata_routing=True, which you can set using"
        " `sklearn.set_config`. See the User Guide"
        " <https://scikit-learn.org/stable/metadata_routing.html> for more"
        f" details. Extra parameters passed are: {set(params)}"
    )
165
166
def _raise_for_unsupported_routing(obj, method, **kwargs):
    """Fail when routing is enabled and metadata reaches a non-routing object.

    Meta-estimators which have not implemented metadata routing use this to
    prevent silent bugs. There is no need to use this function if the
    meta-estimator is not accepting any metadata, especially in `fit`, since
    if a meta-estimator accepts any metadata, they would do that in `fit` as
    well.

    Parameters
    ----------
    obj : estimator
        The estimator for which we're raising the error. Only its class name
        is used, to build the error message.

    method : str
        The method where the error is raised.

    **kwargs : dict
        The metadata passed to the method. Entries whose value is ``None``
        are treated as not passed.

    Raises
    ------
    NotImplementedError
        If metadata routing is enabled and at least one of ``kwargs`` has a
        value other than ``None``.
    """
    provided = {name: value for name, value in kwargs.items() if value is not None}
    if not (_routing_enabled() and provided):
        return

    cls_name = obj.__class__.__name__
    raise NotImplementedError(
        f"{cls_name}.{method} cannot accept given metadata ({set(provided.keys())})"
        f" since metadata routing is not yet implemented for {cls_name}."
    )
194
195
196class _RoutingNotSupportedMixin:
197 """A mixin to be used to remove the default `get_metadata_routing`.
198
199 This is used in meta-estimators where metadata routing is not yet
200 implemented.
201
202 This also makes it clear in our rendered documentation that this method
203 cannot be used.
204 """
205
206 def get_metadata_routing(self):
207 """Raise `NotImplementedError`.
208
209 This estimator does not support metadata routing yet."""
210 raise NotImplementedError(
211 f"{self.__class__.__name__} has not implemented metadata routing yet."
212 )
213
214
# Request values
# ==============
# Each request value needs to be one of the following values, or an alias.
# The special values below are plain strings wrapped in `$` signs; since `$`
# is not allowed in a Python identifier, `request_is_alias` never reports
# them as aliases.

# this is used in `__metadata_request__*` attributes to indicate that a
# metadata is not present even though it may be present in the
# corresponding method's signature. `MethodMetadataRequest.add_request`
# removes an existing request when it is given this value.
UNUSED = "$UNUSED$"

# this is used whenever a default value is changed, and therefore the user
# should explicitly set the value, otherwise a warning is shown. An example
# is when a meta-estimator is only a router, but then becomes also a
# consumer in a new release.
WARN = "$WARN$"

# this is the default used in `set_{method}_request` methods to indicate no
# change requested by the user.
UNCHANGED = "$UNCHANGED$"

# The non-alias values accepted as a request value; this is the list checked
# by `request_is_valid`. Note that `UNCHANGED` is not part of this list.
VALID_REQUEST_VALUES = [False, True, None, UNUSED, WARN]
235
236
def request_is_alias(item):
    """Tell whether the given item can be used as an alias.

    Members of ``VALID_REQUEST_VALUES`` are never aliases. Any other item is
    an alias only if it is a string which is a valid Python identifier.

    Parameters
    ----------
    item : object
        The given item to be checked if it can be an alias.

    Returns
    -------
    result : bool
        Whether the given item is a valid alias.
    """
    # The membership test comes first, as in the rest of this module: special
    # request values must never be reported as aliases.
    is_request_value = item in VALID_REQUEST_VALUES
    return (not is_request_value) and isinstance(item, str) and item.isidentifier()
258
259
def request_is_valid(item):
    """Tell whether the given item is a non-alias request value.

    The check is a membership test against ``VALID_REQUEST_VALUES``.

    Parameters
    ----------
    item : object
        The given item to be checked.

    Returns
    -------
    result : bool
        Whether the given item is one of ``VALID_REQUEST_VALUES``.
    """
    is_request_value = item in VALID_REQUEST_VALUES
    return is_request_value
274
275
276# Metadata Request for Simple Consumers
277# =====================================
278# This section includes MethodMetadataRequest and MetadataRequest which are
279# used in simple consumers.
280
281
class MethodMetadataRequest:
    """Container for the metadata requests of one single method.

    Each instance keeps a ``{param: alias}`` dictionary describing how
    metadata is to be passed to one method of one owner. Refer to
    :class:`MetadataRequest` for how this class is used.

    .. versionadded:: 1.3

    Parameters
    ----------
    owner : str
        A display name for the object owning these requests.

    method : str
        The name of the method to which these requests belong.

    requests : dict of {str: bool, None or str}, default=None
        The initial requests for this method. A falsy value (``None`` or an
        empty dict) is replaced by a new empty dictionary.
    """

    def __init__(self, owner, method, requests=None):
        self.owner = owner
        self.method = method
        self._requests = requests if requests else dict()

    @property
    def requests(self):
        """Dictionary of the form: ``{key: alias}``."""
        return self._requests

    def add_request(
        self,
        *,
        param,
        alias,
    ):
        """Register (or remove) the request value of one metadata.

        Parameters
        ----------
        param : str
            The property for which a request is set.

        alias : str, or {True, False, None}
            Specifies which metadata should be routed to `param`

            - str: the name (or alias) of metadata given to a meta-estimator that
              should be routed to this parameter. An alias equal to `param`
              is stored as ``True``.

            - True: requested

            - False: not requested

            - None: error if passed

            The special value ``UNUSED`` removes the existing request for
            `param` instead of storing anything.

        Returns
        -------
        self : MethodMetadataRequest
            Returns `self`.

        Raises
        ------
        ValueError
            If `alias` is neither an alias nor a valid request value, or if
            ``UNUSED`` is given for a `param` which has no request.
        """
        acceptable = request_is_alias(alias) or request_is_valid(alias)
        if not acceptable:
            raise ValueError(
                f"The alias you're setting for `{param}` should be either a "
                "valid identifier or one of {None, True, False}, but given "
                f"value is: `{alias}`"
            )

        # Requesting a metadata under its own name is a plain request.
        if alias == param:
            alias = True

        if alias == UNUSED:
            if param not in self._requests:
                raise ValueError(
                    f"Trying to remove parameter {param} with UNUSED which doesn't"
                    " exist."
                )
            del self._requests[param]
            return self

        self._requests[param] = alias
        return self

    def _get_param_names(self, return_alias):
        """Get names of the metadata that can be consumed or routed by this method.

        Entries whose request value is ``False`` are left out; every other
        entry contributes one name.

        Parameters
        ----------
        return_alias : bool
            Controls whether original or aliased names should be returned. If
            ``False``, aliases are ignored and original names are returned.

        Returns
        -------
        names : set of str
            A set of strings with the names of the parameters.
        """
        names = set()
        for prop, alias in self._requests.items():
            aliased = not request_is_valid(alias)
            if not aliased and alias is False:
                continue
            names.add(alias if (return_alias and aliased) else prop)
        return names

    def _check_warnings(self, *, params):
        """Warn about passed metadata whose request value is ``WARN``.

        One warning is issued for every metadata which is present in `params`
        and whose request value is ``WARN``.

        Parameters
        ----------
        params : dict
            The metadata passed to a method. ``None`` is treated as an empty
            dictionary.
        """
        passed = params if params is not None else {}
        flagged = set()
        for prop, alias in self._requests.items():
            if alias == WARN and prop in passed:
                flagged.add(prop)

        for param in flagged:
            warn(
                f"Support for {param} has recently been added to this class. "
                "To maintain backward compatibility, it is ignored now. "
                "You can set the request value to False to silence this "
                "warning, or to True to consume and use the metadata."
            )

    def _route_params(self, params):
        """Select and rename the given metadata according to the requests.

        The output of this method can be used directly as the input to the
        corresponding method as extra props.

        Parameters
        ----------
        params : dict
            A dictionary of provided metadata. Entries whose value is ``None``
            are treated as not provided.

        Returns
        -------
        params : Bunch
            A :class:`~sklearn.utils.Bunch` of {prop: value} which can be given to the
            corresponding method.

        Raises
        ------
        UnsetMetadataPassedError
            If a metadata is provided while its request value is ``None``.
        """
        self._check_warnings(params=params)
        provided = {name: value for name, value in params.items() if value is not None}
        routed = Bunch()
        unrequested = dict()
        for prop, alias in self._requests.items():
            if alias is False or alias == WARN:
                # explicitly not requested, or ignored with a warning
                continue
            elif alias is True and prop in provided:
                routed[prop] = provided[prop]
            elif alias is None and prop in provided:
                unrequested[prop] = provided[prop]
            elif alias in provided:
                # the metadata is requested under another name
                routed[prop] = provided[alias]

        if unrequested:
            names = ", ".join(unrequested)
            raise UnsetMetadataPassedError(
                message=(
                    f"[{names}] are passed but are"
                    " not explicitly set as requested or not for"
                    f" {self.owner}.{self.method}"
                ),
                unrequested_params=unrequested,
                routed_params=routed,
            )
        return routed

    def _consumes(self, params):
        """Find which of the given parameters are consumed by this method.

        A parameter is consumed when its request value is ``True`` and its
        own name is given, or when its request value is a string and that
        string is given.

        Parameters
        ----------
        params : iterable of str
            An iterable of parameters to check.

        Returns
        -------
        consumed : set of str
            A set of parameters which are consumed by this method.
        """
        available = set(params)
        consumed = set()
        for prop, alias in self._requests.items():
            if alias is True:
                candidate = prop
            elif isinstance(alias, str):
                candidate = alias
            else:
                continue
            if candidate in available:
                consumed.add(candidate)
        return consumed

    def _serialize(self):
        """Serialize the object.

        Returns
        -------
        obj : dict
            The stored ``{param: alias}`` dictionary itself (not a copy).
        """
        return self._requests

    def __repr__(self):
        serialized = self._serialize()
        return str(serialized)

    def __str__(self):
        return repr(self)
485
486
class MetadataRequest:
    """Holds the metadata requests of a consumer, one entry per method.

    For each name in ``SIMPLE_METHODS``, an instance of
    `MethodMetadataRequest` is available under `metadatarequest.{method}`.
    Requests of the composite methods listed in ``COMPOSITE_METHODS`` are
    computed on attribute access from the simple methods they are made of.

    .. versionadded:: 1.3

    Parameters
    ----------
    owner : str
        The name of the object to which these requests belong.
    """

    # Checks in this module compare this attribute's value instead of using
    # `isinstance`, so that we avoid issues when people vendor this file
    # instead of using it directly from scikit-learn.
    _type = "metadata_request"

    def __init__(self, owner):
        self.owner = owner
        for name in SIMPLE_METHODS:
            setattr(self, name, MethodMetadataRequest(owner=owner, method=name))

    def consumes(self, method, params):
        """Find which of the given parameters are consumed by the given method.

        .. versionadded:: 1.4

        Parameters
        ----------
        method : str
            The name of the method to check.

        params : iterable of str
            An iterable of parameters to check.

        Returns
        -------
        consumed : set of str
            A set of parameters which are consumed by the given method.
        """
        method_request = getattr(self, method)
        return method_request._consumes(params=params)

    def __getattr__(self, name):
        # Python only calls this hook when the regular attribute lookup has
        # failed, i.e. for names which are neither instance attributes nor
        # class attributes. It must return the computed value or raise
        # AttributeError.
        # https://docs.python.org/3/reference/datamodel.html#object.__getattr__
        # Here it is used to build the requests of composite methods on the
        # fly.
        if name not in COMPOSITE_METHODS:
            raise AttributeError(
                f"'{self.__class__.__name__}' object has no attribute '{name}'"
            )

        merged = {}
        for simple_name in COMPOSITE_METHODS[name]:
            part = getattr(self, simple_name)
            shared = set(merged.keys()) & set(part.requests.keys())
            # A metadata requested by several simple methods must have the
            # same request value in all of them.
            conflicts = [key for key in shared if merged[key] != part._requests[key]]
            if conflicts:
                raise ValueError(
                    f"Conflicting metadata requests for {', '.join(conflicts)} while"
                    f" composing the requests for {name}. Metadata with the same name"
                    f" for methods {', '.join(COMPOSITE_METHODS[name])} should have the"
                    " same request value."
                )
            merged.update(part._requests)
        return MethodMetadataRequest(owner=self.owner, method=name, requests=merged)

    def _get_param_names(self, method, return_alias, ignore_self_request=None):
        """Get names of the metadata that can be consumed or routed by specified \
            method.

        The work is delegated to the ``_get_param_names`` method of the object
        stored for `method`.

        Parameters
        ----------
        method : str
            The name of the method for which metadata names are requested.

        return_alias : bool
            Controls whether original or aliased names should be returned. If
            ``False``, aliases are ignored and original names are returned.

        ignore_self_request : bool
            Ignored. Present for API compatibility.

        Returns
        -------
        names : set of str
            A set of strings with the names of the parameters.
        """
        method_request = getattr(self, method)
        return method_request._get_param_names(return_alias=return_alias)

    def _route_params(self, *, method, params):
        """Select the given metadata which should be passed to the method.

        The output of this method can be used directly as the input to the
        corresponding method as extra keyword arguments to pass metadata.

        Parameters
        ----------
        method : str
            The name of the method for which the parameters are requested and
            routed.

        params : dict
            A dictionary of provided metadata.

        Returns
        -------
        params : Bunch
            A :class:`~sklearn.utils.Bunch` of {prop: value} which can be given to the
            corresponding method.
        """
        method_request = getattr(self, method)
        return method_request._route_params(params=params)

    def _check_warnings(self, *, method, params):
        """Warn about passed metadata whose request value is ``WARN``.

        The check is delegated to the object stored for `method`.

        Parameters
        ----------
        method : str
            The name of the method for which the warnings should be checked.

        params : dict
            The metadata passed to a method.
        """
        method_request = getattr(self, method)
        method_request._check_warnings(params=params)

    def _serialize(self):
        """Serialize the object.

        Returns
        -------
        obj : dict
            A dictionary of the form ``{method: serialized requests}``, which
            only contains the simple methods having at least one request.
        """
        serialized = {}
        for name in SIMPLE_METHODS:
            method_request = getattr(self, name)
            if len(method_request.requests) > 0:
                serialized[name] = method_request._serialize()
        return serialized

    def __repr__(self):
        serialized = self._serialize()
        return str(serialized)

    def __str__(self):
        return repr(self)
652
653
654# Metadata Request for Routers
655# ============================
656# This section includes all objects required for MetadataRouter which is used
657# in routers, returned by their ``get_metadata_routing``.
658
# Pairs a method mapping with routing information: ``mapping`` is a
# MethodMapping object and ``router`` is the output of `get_metadata_routing`.
# MetadataRouter stores a collection of these namedtuples.
RouterMappingPair = namedtuple("RouterMappingPair", "mapping router")

# A single method route, i.e. the ``callee`` method of a child object and the
# ``caller`` method of the router in which it is called. A collection of these
# namedtuples is stored in a MethodMapping.
MethodPair = namedtuple("MethodPair", "callee caller")
667
668
class MethodMapping:
    """Keeps track of which child methods are called in which router methods.

    This class is primarily used in a ``get_metadata_routing()`` of a router
    object when defining the mapping between a sub-object (a sub-estimator or a
    scorer) to the router's methods. It stores a list of ``MethodPair``
    namedtuples.

    Iterating through an instance of this class will yield named
    ``MethodPair(callee, caller)`` tuples.

    .. versionadded:: 1.3
    """

    def __init__(self):
        self._routes = []

    def __iter__(self):
        return iter(self._routes)

    def add(self, *, callee, caller):
        """Register one ``(callee, caller)`` pair.

        Parameters
        ----------
        callee : str
            Child object's method name. This method is called in ``caller``.

        caller : str
            Parent estimator's method name in which the ``callee`` is called.

        Returns
        -------
        self : MethodMapping
            Returns self.

        Raises
        ------
        ValueError
            If ``callee`` or ``caller`` is not one of ``METHODS``.
        """
        if callee not in METHODS:
            raise ValueError(
                f"Given callee:{callee} is not a valid method. Valid methods are:"
                f" {METHODS}"
            )
        if caller not in METHODS:
            raise ValueError(
                f"Given caller:{caller} is not a valid method. Valid methods are:"
                f" {METHODS}"
            )
        pair = MethodPair(callee=callee, caller=caller)
        self._routes.append(pair)
        return self

    def _serialize(self):
        """Serialize the object.

        Returns
        -------
        obj : list
            A list with one ``{"callee": ..., "caller": ...}`` dictionary per
            stored pair, in insertion order.
        """
        return [{"callee": pair.callee, "caller": pair.caller} for pair in self._routes]

    @classmethod
    def from_str(cls, route):
        """Build an instance from a string description.

        Parameters
        ----------
        route : str
            A string representing the mapping, it can be:

              - `"one-to-one"`: a one to one mapping for all methods.
              - `"method"`: the name of a single method, such as ``fit``,
                ``transform``, ``score``, etc.

        Returns
        -------
        obj : MethodMapping
            A :class:`~sklearn.utils.metadata_routing.MethodMapping` instance
            constructed from the given string.

        Raises
        ------
        ValueError
            If ``route`` is neither `"one-to-one"` nor one of ``METHODS``.
        """
        mapping = cls()
        if route == "one-to-one":
            names = METHODS
        elif route in METHODS:
            names = [route]
        else:
            raise ValueError("route should be 'one-to-one' or a single method!")

        # Every selected method is mapped to the method of the same name.
        for name in names:
            mapping.add(callee=name, caller=name)
        return mapping

    def __repr__(self):
        serialized = self._serialize()
        return str(serialized)

    def __str__(self):
        return repr(self)
765
766
767class MetadataRouter:
768 """Stores and handles metadata routing for a router object.
769
770 This class is used by router objects to store and handle metadata routing.
    Routing information is stored as a dictionary of the form ``{"object_name":
    RouterMappingPair(method_mapping, routing_info)}``, where ``method_mapping``
773 is an instance of :class:`~sklearn.utils.metadata_routing.MethodMapping` and
774 ``routing_info`` is either a
775 :class:`~sklearn.utils.metadata_routing.MetadataRequest` or a
776 :class:`~sklearn.utils.metadata_routing.MetadataRouter` instance.
777
778 .. versionadded:: 1.3
779
780 Parameters
781 ----------
782 owner : str
783 The name of the object to which these requests belong.
784 """
785
    # this is here for us to use this attribute's value instead of doing
    # `isinstance` in our checks, so that we avoid issues when people vendor
    # this file instead of using it directly from scikit-learn.
789 _type = "metadata_router"
790
791 def __init__(self, owner):
792 self._route_mappings = dict()
793 # `_self_request` is used if the router is also a consumer.
794 # _self_request, (added using `add_self_request()`) is treated
795 # differently from the other objects which are stored in
796 # _route_mappings.
797 self._self_request = None
798 self.owner = owner
799
800 def add_self_request(self, obj):
801 """Add `self` (as a consumer) to the routing.
802
803 This method is used if the router is also a consumer, and hence the
804 router itself needs to be included in the routing. The passed object
805 can be an estimator or a
806 :class:`~sklearn.utils.metadata_routing.MetadataRequest`.
807
808 A router should add itself using this method instead of `add` since it
809 should be treated differently than the other objects to which metadata
810 is routed by the router.
811
812 Parameters
813 ----------
814 obj : object
815 This is typically the router instance, i.e. `self` in a
816 ``get_metadata_routing()`` implementation. It can also be a
817 ``MetadataRequest`` instance.
818
819 Returns
820 -------
821 self : MetadataRouter
822 Returns `self`.
823 """
824 if getattr(obj, "_type", None) == "metadata_request":
825 self._self_request = deepcopy(obj)
826 elif hasattr(obj, "_get_metadata_request"):
827 self._self_request = deepcopy(obj._get_metadata_request())
828 else:
829 raise ValueError(
830 "Given `obj` is neither a `MetadataRequest` nor does it implement the"
831 " required API. Inheriting from `BaseEstimator` implements the required"
832 " API."
833 )
834 return self
835
836 def add(self, *, method_mapping, **objs):
837 """Add named objects with their corresponding method mapping.
838
839 Parameters
840 ----------
841 method_mapping : MethodMapping or str
842 The mapping between the child and the parent's methods. If str, the
843 output of :func:`~sklearn.utils.metadata_routing.MethodMapping.from_str`
844 is used.
845
846 **objs : dict
847 A dictionary of objects from which metadata is extracted by calling
848 :func:`~sklearn.utils.metadata_routing.get_routing_for_object` on them.
849
850 Returns
851 -------
852 self : MetadataRouter
853 Returns `self`.
854 """
855 if isinstance(method_mapping, str):
856 method_mapping = MethodMapping.from_str(method_mapping)
857 else:
858 method_mapping = deepcopy(method_mapping)
859
860 for name, obj in objs.items():
861 self._route_mappings[name] = RouterMappingPair(
862 mapping=method_mapping, router=get_routing_for_object(obj)
863 )
864 return self
865
866 def consumes(self, method, params):
867 """Check whether the given parameters are consumed by the given method.
868
869 .. versionadded:: 1.4
870
871 Parameters
872 ----------
873 method : str
874 The name of the method to check.
875
876 params : iterable of str
877 An iterable of parameters to check.
878
879 Returns
880 -------
881 consumed : set of str
882 A set of parameters which are consumed by the given method.
883 """
884 res = set()
885 if self._self_request:
886 res = res | self._self_request.consumes(method=method, params=params)
887
888 for _, route_mapping in self._route_mappings.items():
889 for callee, caller in route_mapping.mapping:
890 if caller == method:
891 res = res | route_mapping.router.consumes(
892 method=callee, params=params
893 )
894
895 return res
896
897 def _get_param_names(self, *, method, return_alias, ignore_self_request):
898 """Get names of all metadata that can be consumed or routed by specified \
899 method.
900
901 This method returns the names of all metadata, even the ``False``
902 ones.
903
904 Parameters
905 ----------
906 method : str
907 The name of the method for which metadata names are requested.
908
909 return_alias : bool
910 Controls whether original or aliased names should be returned,
911 which only applies to the stored `self`. If no `self` routing
912 object is stored, this parameter has no effect.
913
914 ignore_self_request : bool
915 If `self._self_request` should be ignored. This is used in `_route_params`.
916 If ``True``, ``return_alias`` has no effect.
917
918 Returns
919 -------
920 names : set of str
921 A set of strings with the names of all parameters.
922 """
923 res = set()
924 if self._self_request and not ignore_self_request:
925 res = res.union(
926 self._self_request._get_param_names(
927 method=method, return_alias=return_alias
928 )
929 )
930
931 for name, route_mapping in self._route_mappings.items():
932 for callee, caller in route_mapping.mapping:
933 if caller == method:
934 res = res.union(
935 route_mapping.router._get_param_names(
936 method=callee, return_alias=True, ignore_self_request=False
937 )
938 )
939 return res
940
941 def _route_params(self, *, params, method):
942 """Prepare the given parameters to be passed to the method.
943
944 This is used when a router is used as a child object of another router.
945 The parent router then passes all parameters understood by the child
946 object to it and delegates their validation to the child.
947
948 The output of this method can be used directly as the input to the
949 corresponding method as extra props.
950
951 Parameters
952 ----------
953 method : str
954 The name of the method for which the parameters are requested and
955 routed.
956
957 params : dict
958 A dictionary of provided metadata.
959
960 Returns
961 -------
962 params : Bunch
963 A :class:`~sklearn.utils.Bunch` of {prop: value} which can be given to the
964 corresponding method.
965 """
966 res = Bunch()
967 if self._self_request:
968 res.update(self._self_request._route_params(params=params, method=method))
969
970 param_names = self._get_param_names(
971 method=method, return_alias=True, ignore_self_request=True
972 )
973 child_params = {
974 key: value for key, value in params.items() if key in param_names
975 }
976 for key in set(res.keys()).intersection(child_params.keys()):
977 # conflicts are okay if the passed objects are the same, but it's
978 # an issue if they're different objects.
979 if child_params[key] is not res[key]:
980 raise ValueError(
981 f"In {self.owner}, there is a conflict on {key} between what is"
982 " requested for this estimator and what is requested by its"
983 " children. You can resolve this conflict by using an alias for"
984 " the child estimator(s) requested metadata."
985 )
986
987 res.update(child_params)
988 return res
989
990 def route_params(self, *, caller, params):
991 """Return the input parameters requested by child objects.
992
993 The output of this method is a bunch, which includes the inputs for all
994 methods of each child object that are used in the router's `caller`
995 method.
996
997 If the router is also a consumer, it also checks for warnings of
998 `self`'s/consumer's requested metadata.
999
1000 Parameters
1001 ----------
1002 caller : str
1003 The name of the method for which the parameters are requested and
1004 routed. If called inside the :term:`fit` method of a router, it
1005 would be `"fit"`.
1006
1007 params : dict
1008 A dictionary of provided metadata.
1009
1010 Returns
1011 -------
1012 params : Bunch
1013 A :class:`~sklearn.utils.Bunch` of the form
1014 ``{"object_name": {"method_name": {prop: value}}}`` which can be
1015 used to pass the required metadata to corresponding methods or
1016 corresponding child objects.
1017 """
1018 if self._self_request:
1019 self._self_request._check_warnings(params=params, method=caller)
1020
1021 res = Bunch()
1022 for name, route_mapping in self._route_mappings.items():
1023 router, mapping = route_mapping.router, route_mapping.mapping
1024
1025 res[name] = Bunch()
1026 for _callee, _caller in mapping:
1027 if _caller == caller:
1028 res[name][_callee] = router._route_params(
1029 params=params, method=_callee
1030 )
1031 return res
1032
1033 def validate_metadata(self, *, method, params):
1034 """Validate given metadata for a method.
1035
1036 This raises a ``TypeError`` if some of the passed metadata are not
1037 understood by child objects.
1038
1039 Parameters
1040 ----------
1041 method : str
1042 The name of the method for which the parameters are requested and
1043 routed. If called inside the :term:`fit` method of a router, it
1044 would be `"fit"`.
1045
1046 params : dict
1047 A dictionary of provided metadata.
1048 """
1049 param_names = self._get_param_names(
1050 method=method, return_alias=False, ignore_self_request=False
1051 )
1052 if self._self_request:
1053 self_params = self._self_request._get_param_names(
1054 method=method, return_alias=False
1055 )
1056 else:
1057 self_params = set()
1058 extra_keys = set(params.keys()) - param_names - self_params
1059 if extra_keys:
1060 raise TypeError(
1061 f"{self.owner}.{method} got unexpected argument(s) {extra_keys}, which"
1062 " are not requested metadata in any object."
1063 )
1064
1065 def _serialize(self):
1066 """Serialize the object.
1067
1068 Returns
1069 -------
1070 obj : dict
1071 A serialized version of the instance in the form of a dictionary.
1072 """
1073 res = dict()
1074 if self._self_request:
1075 res["$self_request"] = self._self_request._serialize()
1076 for name, route_mapping in self._route_mappings.items():
1077 res[name] = dict()
1078 res[name]["mapping"] = route_mapping.mapping._serialize()
1079 res[name]["router"] = route_mapping.router._serialize()
1080
1081 return res
1082
1083 def __iter__(self):
1084 if self._self_request:
1085 yield "$self_request", RouterMappingPair(
1086 mapping=MethodMapping.from_str("one-to-one"), router=self._self_request
1087 )
1088 for name, route_mapping in self._route_mappings.items():
1089 yield (name, route_mapping)
1090
1091 def __repr__(self):
1092 return str(self._serialize())
1093
1094 def __str__(self):
1095 return str(repr(self))
1096
1097
def get_routing_for_object(obj=None):
    """Get a ``Metadata{Router, Request}`` instance from the given object.

    This function returns a
    :class:`~sklearn.utils.metadata_routing.MetadataRouter` or a
    :class:`~sklearn.utils.metadata_routing.MetadataRequest` from the given input.

    The result is always a deep copy or a newly constructed instance, so
    changing the output of this function does not change the original object.

    .. versionadded:: 1.3

    Parameters
    ----------
    obj : object
        - If the object provides a `get_metadata_routing` method, return a copy
          of the output of that method.
        - If the object is already a
          :class:`~sklearn.utils.metadata_routing.MetadataRequest` or a
          :class:`~sklearn.utils.metadata_routing.MetadataRouter`, return a copy
          of that.
        - Returns an empty :class:`~sklearn.utils.metadata_routing.MetadataRequest`
          otherwise.

    Returns
    -------
    obj : MetadataRequest or MetadataRouter
        A ``MetadataRequest`` or a ``MetadataRouter`` taken or created from
        the given object.
    """
    # ``hasattr`` is used instead of a try/except since an AttributeError
    # could be raised for other reasons.
    if hasattr(obj, "get_metadata_routing"):
        routing = obj.get_metadata_routing()
        return deepcopy(routing)

    # Requests and routers are recognised through their ``_type`` marker.
    kind = getattr(obj, "_type", None)
    if kind in ("metadata_request", "metadata_router"):
        return deepcopy(obj)

    return MetadataRequest(owner=None)
1138
1139
1140# Request method
1141# ==============
# This section contains the descriptor behind the ``set_{method}_request``
# methods and the machinery which generates those methods dynamically when a
# class is defined (see ``__init_subclass__`` below).
1144
1145# These strings are used to dynamically generate the docstrings for
1146# set_{method}_request methods.
# The three templates below are concatenated (header, one parameter entry per
# metadata key, return section) and used as the ``__doc__`` of a generated
# method. Their text is therefore indented like a method docstring body
# (8 spaces), with nested content (the note body and the parameter/return
# descriptions) indented 4 spaces further, which numpydoc requires.
# A trailing backslash joins the next physical line onto the current one, so
# continuation lines deliberately start at column 0.
REQUESTER_DOC = """        Request metadata passed to the ``{method}`` method.

        Note that this method is only relevant if
        ``enable_metadata_routing=True`` (see :func:`sklearn.set_config`).
        Please see :ref:`User Guide <metadata_routing>` on how the routing
        mechanism works.

        The options for each parameter are:

        - ``True``: metadata is requested, and \
passed to ``{method}`` if provided. The request is ignored if \
metadata is not provided.

        - ``False``: metadata is not requested and the meta-estimator \
will not pass it to ``{method}``.

        - ``None``: metadata is not requested, and the meta-estimator \
will raise an error if the user provides it.

        - ``str``: metadata should be passed to the meta-estimator with \
this given alias instead of the original name.

        The default (``sklearn.utils.metadata_routing.UNCHANGED``) retains the
        existing request. This allows you to change the request for some
        parameters and not others.

        .. versionadded:: 1.3

        .. note::
            This method is only relevant if this estimator is used as a
            sub-estimator of a meta-estimator, e.g. used inside a
            :class:`~sklearn.pipeline.Pipeline`. Otherwise it has no effect.

        Parameters
        ----------
"""
# One entry of the "Parameters" section; formatted once per metadata key.
REQUESTER_DOC_PARAM = """        {metadata} : str, True, False, or None, \
default=sklearn.utils.metadata_routing.UNCHANGED
            Metadata routing for ``{metadata}`` parameter in ``{method}``.

"""
# Closing "Returns" section, appended after the last parameter entry.
REQUESTER_DOC_RETURN = """        Returns
        -------
        self : object
            The updated object.
"""
1193
1194
class RequestMethod:
    """
    A descriptor for request methods.

    Accessing the descriptor returns a ``set_{name}_request`` function whose
    signature and docstring are generated from ``keys``. The function can be
    used bound (``est.set_fit_request(...)``) as well as unbound
    (``Est.set_fit_request(est, ...)``).

    .. versionadded:: 1.3

    Parameters
    ----------
    name : str
        The name of the method for which the request function should be
        created, e.g. ``"fit"`` would create a ``set_fit_request`` function.

    keys : list of str
        A list of strings which are accepted parameters by the created
        function, e.g. ``["sample_weight"]`` if the corresponding method
        accepts it as a metadata.

    validate_keys : bool, default=True
        Whether to check if the requested parameters fit the actual parameters
        of the method.

    Notes
    -----
    This class is a descriptor [1]_ and uses PEP-362 to set the signature of
    the returned function [2]_.

    References
    ----------
    .. [1] https://docs.python.org/3/howto/descriptor.html

    .. [2] https://www.python.org/dev/peps/pep-0362/
    """

    def __init__(self, name, keys, validate_keys=True):
        self.name = name
        self.keys = keys
        self.validate_keys = validate_keys

    def __get__(self, instance, owner):
        # we would want to have a method which accepts only the expected args
        def func(*args, **kw):
            """Updates the request for provided parameters

            This docstring is overwritten below.
            See REQUESTER_DOC for expected functionality
            """
            # When the descriptor is accessed through the class, ``instance``
            # is None and the object to update has to be passed as the first
            # positional argument, exactly like for a regular unbound method.
            if instance is None:
                if not args:
                    raise TypeError(
                        f"set_{self.name}_request() missing 1 required positional"
                        " argument: 'self'"
                    )
                target, args = args[0], args[1:]
            else:
                target = instance

            # Metadata can only be given by keyword; mirror Python's own
            # error for unexpected positional arguments.
            if args:
                raise TypeError(
                    f"set_{self.name}_request() takes 0 positional argument but"
                    f" {len(args)} were given"
                )

            if not _routing_enabled():
                raise RuntimeError(
                    "This method is only available when metadata routing is enabled."
                    " You can enable it using"
                    " sklearn.set_config(enable_metadata_routing=True)."
                )

            if self.validate_keys and (set(kw) - set(self.keys)):
                raise TypeError(
                    f"Unexpected args: {set(kw) - set(self.keys)}. Accepted arguments"
                    f" are: {set(self.keys)}"
                )

            requests = target._get_metadata_request()
            method_metadata_request = getattr(requests, self.name)

            # ``UNCHANGED`` means "keep whatever is currently requested".
            for prop, alias in kw.items():
                if alias is not UNCHANGED:
                    method_metadata_request.add_request(param=prop, alias=alias)
            target._metadata_request = requests

            return target

        # Now we set the relevant attributes of the function so that it seems
        # like a normal method to the end user, with known expected arguments.
        func.__name__ = f"set_{self.name}_request"
        params = [
            inspect.Parameter(
                name="self",
                kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
                annotation=owner,
            )
        ]
        params.extend(
            [
                inspect.Parameter(
                    k,
                    inspect.Parameter.KEYWORD_ONLY,
                    default=UNCHANGED,
                    annotation=Optional[Union[bool, None, str]],
                )
                for k in self.keys
            ]
        )
        func.__signature__ = inspect.Signature(
            params,
            return_annotation=owner,
        )
        doc = REQUESTER_DOC.format(method=self.name)
        for metadata in self.keys:
            doc += REQUESTER_DOC_PARAM.format(metadata=metadata, method=self.name)
        doc += REQUESTER_DOC_RETURN
        func.__doc__ = doc
        return func
1295
1296
class _MetadataRequester:
    """Mixin class for adding metadata request functionality.

    Subclasses get their ``set_{method}_request`` methods installed when they
    are defined, and expose their requests through ``get_metadata_routing``.

    ``BaseEstimator`` inherits from this Mixin.

    .. versionadded:: 1.3
    """

    if TYPE_CHECKING:  # pragma: no cover
        # Never executed at runtime. The ``set_{method}_request`` methods are
        # generated dynamically, which type checkers cannot see, so they are
        # declared here purely for the benefit of static analysis.
        # The list mirrors the methods in SIMPLE_METHODS.
        # fmt: off
        def set_fit_request(self, **kwargs): pass
        def set_partial_fit_request(self, **kwargs): pass
        def set_predict_request(self, **kwargs): pass
        def set_predict_proba_request(self, **kwargs): pass
        def set_predict_log_proba_request(self, **kwargs): pass
        def set_decision_function_request(self, **kwargs): pass
        def set_score_request(self, **kwargs): pass
        def set_split_request(self, **kwargs): pass
        def set_transform_request(self, **kwargs): pass
        def set_inverse_transform_request(self, **kwargs): pass
        # fmt: on

    def __init_subclass__(cls, **kwargs):
        """Install the ``set_{method}_request`` methods on the subclass.

        This relies on PEP-487 [1]_. The default requests of the class are
        collected from its ``__metadata_request__*`` class attributes and from
        its method signatures; a ``set_{method}_request`` descriptor is then
        attached for every method which has at least one requestable
        metadata.

        The ``__metadata_request__*`` class attributes are used when a method
        does not explicitly accept a metadata through its arguments or if the
        developer would like to specify a request value for those metadata
        which are different from the default ``None``.

        References
        ----------
        .. [1] https://www.python.org/dev/peps/pep-0487
        """
        try:
            default_requests = cls._get_default_requests()
        except Exception:
            # Any problem with the defaults (bad values etc.) is deliberately
            # ignored here; it will be raised when ``get_metadata_routing``
            # is called.
            super().__init_subclass__(**kwargs)
            return

        for method_name in SIMPLE_METHODS:
            requested = getattr(default_requests, method_name).requests
            if requested:
                # only methods with something to request get a setter
                setattr(
                    cls,
                    f"set_{method_name}_request",
                    RequestMethod(method_name, sorted(requested.keys())),
                )
        super().__init_subclass__(**kwargs)

    @classmethod
    def _build_request_for_signature(cls, router, method):
        """Build the `MethodMetadataRequest` for a method using its signature.

        Every argument of the method gets ``None`` as its default request
        value, except the first one (usually ``self``), ``X``, ``y``, ``Y``,
        ``Xt``, ``yt``, ``*args``, and ``**kwargs``. If the class has no plain
        function under that name, an empty request is returned.

        Parameters
        ----------
        router : MetadataRequest
            The parent object for the created `MethodMetadataRequest`. It is
            not used by this implementation.
        method : str
            The name of the method.

        Returns
        -------
        method_request : MethodMetadataRequest
            The prepared request using the method's signature.
        """
        mmr = MethodMetadataRequest(owner=cls.__name__, method=method)
        # `isfunction` is the right check rather than `ismethod`: looking the
        # method up on the class, not on an instance, gives a plain function.
        target = getattr(cls, method, None)
        if not inspect.isfunction(target):
            return mmr

        data_arguments = {"X", "y", "Y", "Xt", "yt"}
        # the first parameter, which is usually "self", is dropped
        signature_params = list(inspect.signature(target).parameters.values())[1:]
        for param in signature_params:
            is_variadic = param.kind in {param.VAR_POSITIONAL, param.VAR_KEYWORD}
            if param.name in data_arguments or is_variadic:
                continue
            mmr.add_request(param=param.name, alias=None)
        return mmr

    @classmethod
    def _get_default_requests(cls):
        """Collect default request values.

        The defaults come from two sources: the method signatures, and the
        ``__metadata_request__*`` class attributes found along the MRO. Values
        from the class attributes are applied last and hence take precedence.
        """
        requests = MetadataRequest(owner=cls.__name__)

        # Start from what the method signatures tell us.
        for method_name in SIMPLE_METHODS:
            signature_request = cls._build_request_for_signature(
                router=requests, method=method_name
            )
            setattr(requests, method_name, signature_request)

        # Then apply ``__metadata_request__*`` attributes on top. These are
        # class attributes and ``vars`` does not report inherited ones, hence
        # the walk through the MRO. Walking it in reverse makes child classes
        # override their parents.
        # A substring test is used instead of ``startswith`` since python
        # prefixes attributes starting with ``__`` with ``_ClassName``.
        marker = "__metadata_request__"
        collected = {}
        for klass in reversed(inspect.getmro(cls)):
            for attr_name, attr_value in vars(klass).items():
                if marker in attr_name:
                    collected[attr_name] = attr_value

        for attr_name in sorted(collected):
            # whatever follows the marker is the name of the method
            method_name = attr_name.partition(marker)[2]
            for prop, alias in collected[attr_name].items():
                getattr(requests, method_name).add_request(param=prop, alias=alias)

        return requests

    def _get_metadata_request(self):
        """Get requested data properties.

        Please check :ref:`User Guide <metadata_routing>` on how the routing
        mechanism works.

        Returns
        -------
        request : MetadataRequest
            A :class:`~sklearn.utils.metadata_routing.MetadataRequest` instance:
            a copy of the one stored in ``_metadata_request`` if that
            attribute exists, the class defaults otherwise.
        """
        if hasattr(self, "_metadata_request"):
            return get_routing_for_object(self._metadata_request)
        return self._get_default_requests()

    def get_metadata_routing(self):
        """Get metadata routing of this object.

        Please check :ref:`User Guide <metadata_routing>` on how the routing
        mechanism works.

        Returns
        -------
        routing : MetadataRequest
            A :class:`~sklearn.utils.metadata_routing.MetadataRequest` encapsulating
            routing information.
        """
        routing = self._get_metadata_request()
        return routing
1479
1480
1481# Process Routing in Routers
1482# ==========================
1483# This is almost always the only method used in routers to process and route
1484# given metadata. This is to minimize the boilerplate required in routers.
1485
1486
1487# Here the first two arguments are positional only which makes everything
1488# passed as keyword argument a metadata. The first two args also have an `_`
1489# prefix to reduce the chances of name collisions with the passed metadata, and
1490# since they're positional only, users will never type those underscores.
def process_routing(_obj, _method, /, **kwargs):
    """Validate and route input parameters.

    This function is used inside a router's method, e.g. :term:`fit`,
    to validate the metadata and handle the routing.

    Assuming this signature: ``fit(self, X, y, sample_weight=None, **fit_params)``,
    a call to this function would be:
    ``process_routing(self, "fit", sample_weight=sample_weight, **fit_params)``.

    Note that if routing is not enabled and ``kwargs`` is empty, then it
    returns an empty routing where ``process_routing(...).ANYTHING.ANY_METHOD``
    is always an empty dictionary.

    .. versionadded:: 1.3

    Parameters
    ----------
    _obj : object
        An object implementing ``get_metadata_routing``. Typically a
        meta-estimator.

    _method : str
        The name of the router's method in which this function is called.

    **kwargs : dict
        Metadata to be routed.

    Returns
    -------
    routed_params : Bunch
        A :class:`~sklearn.utils.Bunch` of the form ``{"object_name": {"method_name":
        {prop: value}}}`` which can be used to pass the required metadata to
        corresponding methods or corresponding child objects. The object names
        are those defined in `obj.get_metadata_routing()`.

    Raises
    ------
    AttributeError
        If ``_obj`` neither implements ``get_metadata_routing`` nor is a
        ``MetadataRouter`` instance.
    TypeError
        If ``_method`` is not one of ``METHODS``, or if validation of the
        passed metadata fails.
    """
    if not (_routing_enabled() or kwargs):
        # Routing is disabled and there is nothing to route: skip the routing
        # machinery and return an object for which
        # routed_params.ANYTHING.ANY_METHOD is an empty dict.
        def _no_metadata():
            return Bunch(**{method: dict() for method in METHODS})

        class EmptyRequest:
            def get(self, name, default=None):
                return default or {}

            def __getitem__(self, name):
                return _no_metadata()

            def __getattr__(self, name):
                return _no_metadata()

        return EmptyRequest()

    can_route = hasattr(_obj, "get_metadata_routing") or isinstance(
        _obj, MetadataRouter
    )
    if not can_route:
        raise AttributeError(
            f"The given object ({_obj.__class__.__name__!r}) needs to either"
            " implement the routing method `get_metadata_routing` or be a"
            " `MetadataRouter` instance."
        )
    if _method not in METHODS:
        raise TypeError(
            f"Can only route and process input on these methods: {METHODS}, "
            f"while the passed method is: {_method}."
        )

    # Validate first, so unknown metadata is rejected before anything is routed.
    routing = get_routing_for_object(_obj)
    routing.validate_metadata(params=kwargs, method=_method)
    return routing.route_params(params=kwargs, caller=_method)