Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/scikit_learn-1.4.dev0-py3.8-linux-x86_64.egg/sklearn/utils/_metadata_requests.py: 37%

355 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-12-12 06:31 +0000

1""" 

2Metadata Routing Utility 

3 

4In order to better understand the components implemented in this file, one 

5needs to understand their relationship to one another. 

6 

7The only relevant public API for end users are the ``set_{method}_request``, 

8e.g. ``estimator.set_fit_request(sample_weight=True)``. However, third-party 

9developers and users who implement custom meta-estimators, need to deal with 

10the objects implemented in this file. 

11 

12All estimators (should) implement a ``get_metadata_routing`` method, returning 

13the routing requests set for the estimator. This method is automatically 

14implemented via ``BaseEstimator`` for all simple estimators, but needs a custom 

15implementation for meta-estimators. 

16 

17In non-routing consumers, i.e. the simplest case, e.g. ``SVM``, 

18``get_metadata_routing`` returns a ``MetadataRequest`` object. 

19 

20In routers, e.g. meta-estimators and a multi metric scorer, 

21``get_metadata_routing`` returns a ``MetadataRouter`` object. 

22 

23An object which is both a router and a consumer, e.g. a meta-estimator which 

24consumes ``sample_weight`` and routes ``sample_weight`` to its sub-estimators, 

25has routing information that includes both information about the object itself (added 

26via ``MetadataRouter.add_self_request``), as well as the routing information 

27for its sub-estimators. 

28 

29A ``MetadataRequest`` instance includes one ``MethodMetadataRequest`` per 

30method in ``METHODS``, which includes ``fit``, ``score``, etc. 

31 

32Request values are added to the routing mechanism by adding them to 

33``MethodMetadataRequest`` instances, e.g. 

34``metadatarequest.fit.add_request(param="sample_weight", alias="my_weights")``. This is 

35used in ``set_{method}_request`` which are automatically generated, so users 

36and developers almost never need to directly call methods on a 

37``MethodMetadataRequest``. 

38 

39The ``alias`` above in the ``add_request`` method has to be either a string (an alias), 

40or one of ``{True (requested), False (unrequested), None (error if passed)}``. There 

41are some other special values such as ``UNUSED`` and ``WARN`` which are used 

42for purposes such as warning of removing a metadata in a child class, but not 

43used by the end users. 

44 

45``MetadataRouter`` includes information about sub-objects' routing and how 

46methods are mapped together. For instance, the information about which methods 

47of a sub-estimator are called in which methods of the meta-estimator are all 

48stored here. Conceptually, this information looks like: 

49 

50``` 

51{ 

52 "sub_estimator1": ( 

53 mapping=[(caller="fit", callee="transform"), ...], 

54 router=MetadataRequest(...), # or another MetadataRouter 

55 ), 

56 ... 

57} 

58``` 

59 

60To give the above representation some structure, we use the following objects: 

61 

62- ``(caller, callee)`` is a namedtuple called ``MethodPair`` 

63 

64- The list of ``MethodPair`` stored in the ``mapping`` field is a 

65 ``MethodMapping`` object 

66 

67- ``(mapping=..., router=...)`` is a namedtuple called ``RouterMappingPair`` 

68 

69The ``set_{method}_request`` methods are dynamically generated for estimators 

70which inherit from the ``BaseEstimator``. This is done by attaching instances 

71of the ``RequestMethod`` descriptor to classes, which is done in the 

72``_MetadataRequester`` class, and ``BaseEstimator`` inherits from this mixin. 

73This mixin also implements the ``get_metadata_routing``, which meta-estimators 

74need to override, but it works for simple consumers as is. 

75""" 

76 

77# Author: Adrin Jalali <adrin.jalali@gmail.com> 

78# License: BSD 3 clause 

79 

80import inspect 

81from collections import namedtuple 

82from copy import deepcopy 

83from typing import TYPE_CHECKING, Optional, Union 

84from warnings import warn 

85 

86from .. import get_config 

87from ..exceptions import UnsetMetadataPassedError 

88from ._bunch import Bunch 

89 

# The routing mechanism supports only the methods listed below. Adding a new
# method currently means monkeypatching this list. Whenever the list is
# changed or monkeypatched, the corresponding method must also be declared
# under a ``TYPE_CHECKING`` condition, as is done in ``_MetadataRequester``.
SIMPLE_METHODS = [
    "fit",
    "partial_fit",
    "predict",
    "predict_proba",
    "predict_log_proba",
    "decision_function",
    "score",
    "split",
    "transform",
    "inverse_transform",
]

# Composite methods are made of several simple methods. Their requests cannot
# be set directly; they follow from the requests of the simple methods listed
# for each of them (see ``MetadataRequest.__getattr__``).
COMPOSITE_METHODS = dict(
    fit_transform=["fit", "transform"],
    fit_predict=["fit", "predict"],
)

# Every method name known to the routing mechanism: the simple methods first,
# followed by the composite ones in definition order.
METHODS = [*SIMPLE_METHODS, *COMPOSITE_METHODS]

117 

118 

def _routing_enabled():
    """Tell whether metadata routing is currently switched on.

    .. versionadded:: 1.3

    Returns
    -------
    enabled : bool
        The ``enable_metadata_routing`` entry of the global configuration
        returned by ``get_config()``, or False when that entry is missing.
    """
    config = get_config()
    return config.get("enable_metadata_routing", False)

131 

132 

def _raise_for_params(params, owner, method):
    """Raise an error if metadata routing is not enabled and params are passed.

    .. versionadded:: 1.4

    Parameters
    ----------
    params : dict
        The metadata passed to a method. An empty (falsy) value never
        triggers the error.

    owner : object
        The object to which the method belongs. Only its class name is used,
        in the error message.

    method : str
        The name of the method, e.g. "fit". If falsy, the error message
        mentions only the class name of ``owner``.

    Raises
    ------
    ValueError
        If metadata routing is not enabled and params are passed.
    """
    # Nothing to report when routing is on, or when no metadata was given.
    if _routing_enabled() or not params:
        return

    owner_name = owner.__class__.__name__
    caller = f"{owner_name}.{method}" if method else owner_name
    raise ValueError(
        f"Passing extra keyword arguments to {caller} is only supported if"
        " enable_metadata_routing=True, which you can set using"
        " `sklearn.set_config`. See the User Guide"
        " <https://scikit-learn.org/stable/metadata_routing.html> for more"
        f" details. Extra parameters passed are: {set(params)}"
    )

165 

166 

def _raise_for_unsupported_routing(obj, method, **kwargs):
    """Raise when metadata routing is enabled and metadata is passed.

    Meta-estimators which have not implemented metadata routing yet call this
    to prevent silent bugs. Meta-estimators that accept no metadata at all
    do not need it. If a meta-estimator accepts any metadata, it does so in
    `fit` as well, so checking in `fit` is the important case.

    Parameters
    ----------
    obj : estimator
        The estimator for which we're raising the error.

    method : str
        The method where the error is raised.

    **kwargs : dict
        The metadata passed to the method. Entries whose value is None are
        treated as not passed.

    Raises
    ------
    NotImplementedError
        If routing is enabled and at least one non-None metadata is passed.
    """
    # Metadata explicitly passed as None counts as "not passed".
    kwargs = {name: val for name, val in kwargs.items() if val is not None}
    if not (_routing_enabled() and kwargs):
        return

    cls_name = obj.__class__.__name__
    raise NotImplementedError(
        f"{cls_name}.{method} cannot accept given metadata ({set(kwargs.keys())})"
        f" since metadata routing is not yet implemented for {cls_name}."
    )

195 

class _RoutingNotSupportedMixin:
    """Mixin replacing the default `get_metadata_routing` with one that raises.

    Meta-estimators for which metadata routing is not implemented yet inherit
    from this mixin. Calling `get_metadata_routing` on them then fails
    loudly, and the rendered documentation shows that the method cannot be
    used.
    """

    def get_metadata_routing(self):
        """Raise `NotImplementedError`.

        This estimator does not support metadata routing yet."""
        class_name = self.__class__.__name__
        message = f"{class_name} has not implemented metadata routing yet."
        raise NotImplementedError(message)

213 

214 

# Request values
# ==============
# Each request value needs to be one of the following values, or an alias.

# this is used in `__metadata_request__*` attributes to indicate that a
# metadata is not present even though it may be present in the
# corresponding method's signature.
UNUSED = "$UNUSED$"

# this is used whenever a default value is changed, and therefore the user
# should explicitly set the value, otherwise a warning is shown. An example
# is when a meta-estimator is only a router, but then becomes also a
# consumer in a new release.
WARN = "$WARN$"

# this is the default used in `set_{method}_request` methods to indicate no
# change requested by the user.
UNCHANGED = "$UNCHANGED$"

# Use `request_is_valid` to test membership in this list. A plain `in` test
# relies on `==`, under which 0/1 (and NumPy booleans) compare equal to
# False/True.
VALID_REQUEST_VALUES = [False, True, None, UNUSED, WARN]


def request_is_alias(item):
    """Check if an item is a valid alias.

    Only a string can be an alias. It must be a valid Python identifier and
    must not be one of the string request values in ``VALID_REQUEST_VALUES``
    (``UNUSED``, ``WARN``).

    Parameters
    ----------
    item : object
        The given item to be checked if it can be an alias.

    Returns
    -------
    result : bool
        Whether the given item is a valid alias.
    """
    # Check the type first, so that arbitrary objects (e.g. NumPy arrays,
    # whose `==` is element-wise) are never compared with the request values.
    if not isinstance(item, str):
        return False

    # The string sentinels are request values, not aliases.
    if request_is_valid(item):
        return False

    return item.isidentifier()


def request_is_valid(item):
    """Check if an item is a valid request value (and not an alias).

    The valid request values are listed in ``VALID_REQUEST_VALUES``. Its
    non-string entries (``True``, ``False``, ``None``) are matched by
    identity, and its string entries (``UNUSED``, ``WARN``) by equality.
    Identity matters because the routing code checks requests with
    ``is True`` / ``is False`` / ``is None``. Values that merely compare
    equal to those singletons (``0``, ``1``, ``numpy.bool_``) would otherwise
    pass validation here, and then be silently ignored when routing.

    Parameters
    ----------
    item : object
        The given item to be checked.

    Returns
    -------
    result : bool
        Whether the given item is valid.
    """
    for value in VALID_REQUEST_VALUES:
        if isinstance(value, str):
            # string sentinels such as UNUSED and WARN
            if isinstance(item, str) and item == value:
                return True
        elif item is value:
            return True
    return False

274 

275 

276# Metadata Request for Simple Consumers 

277# ===================================== 

278# This section includes MethodMetadataRequest and MetadataRequest which are 

279# used in simple consumers. 

280 

281 

class MethodMetadataRequest:
    """Metadata requests of a single method of a consumer.

    :class:`MetadataRequest` holds one instance of this class per method;
    refer to it for how this class is used.

    .. versionadded:: 1.3

    Parameters
    ----------
    owner : str
        A display name for the object owning these requests.

    method : str
        The name of the method to which these requests belong.

    requests : dict of {str: bool, None or str}, default=None
        The initial requests for this method. A non-empty dict is used as is
        (it is not copied); any falsy value is replaced by a new empty dict.
    """

    def __init__(self, owner, method, requests=None):
        self._requests = requests if requests else {}
        self.owner = owner
        self.method = method

    @property
    def requests(self):
        """Dictionary of the form: ``{key: alias}``."""
        return self._requests

    def add_request(self, *, param, alias):
        """Set how the metadata ``param`` is requested by this method.

        Parameters
        ----------
        param : str
            The property for which a request is set.

        alias : str, or {True, False, None}
            Specifies which metadata should be routed to `param`

            - str: the name (or alias) of metadata given to a meta-estimator that
              should be routed to this parameter.

            - True: requested

            - False: not requested

            - None: error if passed

            An alias equal to ``param`` is stored as ``True``. The special
            value ``UNUSED`` removes the existing request for ``param``.

        Returns
        -------
        self : MethodMetadataRequest
            The instance itself, which allows chaining.

        Raises
        ------
        ValueError
            If ``alias`` is neither an alias nor a valid request value, or if
            ``UNUSED`` is given for a ``param`` that has no request.
        """
        if not (request_is_alias(alias) or request_is_valid(alias)):
            raise ValueError(
                f"The alias you're setting for `{param}` should be either a "
                "valid identifier or one of {None, True, False}, but given "
                f"value is: `{alias}`"
            )

        # Aliasing a parameter to its own name is the same as requesting it.
        if alias == param:
            alias = True

        if alias == UNUSED:
            # UNUSED can only remove a request that already exists.
            if param not in self._requests:
                raise ValueError(
                    f"Trying to remove parameter {param} with UNUSED which doesn't"
                    " exist."
                )
            del self._requests[param]
        else:
            self._requests[param] = alias

        return self

    def _get_param_names(self, return_alias):
        """Get the names of the metadata this method consumes or routes.

        Entries whose request value is ``False`` are left out. All others
        (``True``, ``None``, ``WARN`` and aliases) are included.

        Parameters
        ----------
        return_alias : bool
            Controls whether original or aliased names should be returned. If
            ``False``, aliases are ignored and original names are returned.

        Returns
        -------
        names : set of str
            A set of strings with the names of the selected parameters.
        """
        names = set()
        for prop, alias in self._requests.items():
            is_request_value = request_is_valid(alias)
            if is_request_value and alias is False:
                continue
            names.add(alias if (return_alias and not is_request_value) else prop)
        return names

    def _check_warnings(self, *, params):
        """Warn about passed metadata whose request is ``WARN``.

        Parameters
        ----------
        params : dict or None
            The metadata passed to a method. None is treated as no metadata.
        """
        passed = params if params is not None else {}
        flagged = {
            prop
            for prop, alias in self._requests.items()
            if alias == WARN and prop in passed
        }
        for prop in flagged:
            warn(
                f"Support for {prop} has recently been added to this class. "
                "To maintain backward compatibility, it is ignored now. "
                "You can set the request value to False to silence this "
                "warning, or to True to consume and use the metadata."
            )

    def _route_params(self, params):
        """Select the metadata to hand over to this method.

        The output can be passed directly to the corresponding method as
        extra keyword arguments.

        Parameters
        ----------
        params : dict
            A dictionary of provided metadata. Entries whose value is None are
            treated as not provided.

        Returns
        -------
        params : Bunch
            A :class:`~sklearn.utils.Bunch` of {prop: value} which can be given to the
            corresponding method.

        Raises
        ------
        UnsetMetadataPassedError
            If metadata whose request is ``None`` (i.e. never set) is provided.
        """
        self._check_warnings(params=params)
        provided = {key: val for key, val in params.items() if val is not None}
        routed = Bunch()
        unexpected = {}
        for prop, alias in self._requests.items():
            if alias is False or alias == WARN:
                # not requested, or ignored for backward compatibility
                continue
            if alias is True and prop in provided:
                routed[prop] = provided[prop]
            elif alias is None and prop in provided:
                # passed although no request was ever set for it
                unexpected[prop] = provided[prop]
            elif alias in provided:
                # the value is provided under an alias
                routed[prop] = provided[alias]

        if not unexpected:
            return routed

        unexpected_names = ", ".join(unexpected)
        raise UnsetMetadataPassedError(
            message=(
                f"[{unexpected_names}] are passed but are"
                " not explicitly set as requested or not for"
                f" {self.owner}.{self.method}"
            ),
            unrequested_params=unexpected,
            routed_params=routed,
        )

    def _consumes(self, params):
        """Check whether the given parameters are consumed by this method.

        A name is consumed if it is requested with ``True``, or if it is the
        alias of a requested parameter.

        Parameters
        ----------
        params : iterable of str
            An iterable of parameters to check.

        Returns
        -------
        consumed : set of str
            A set of parameters which are consumed by this method.
        """
        wanted = set(params)
        consumed = {
            prop
            for prop, alias in self._requests.items()
            if alias is True and prop in wanted
        }
        consumed.update(
            alias
            for alias in self._requests.values()
            if isinstance(alias, str) and alias in wanted
        )
        return consumed

    def _serialize(self):
        """Serialize the object.

        Returns
        -------
        obj : dict
            The underlying ``{param: alias}`` dictionary itself (not a copy).
        """
        return self._requests

    def __repr__(self):
        return str(self._serialize())

    def __str__(self):
        return str(repr(self))

485 

486 

class MetadataRequest:
    """Contains the metadata request info of a consumer.

    Each method in ``SIMPLE_METHODS`` gets its own
    :class:`MethodMetadataRequest`, available as ``metadatarequest.{method}``.
    Methods in ``COMPOSITE_METHODS`` (e.g. ``fit_transform``) are not stored.
    Each access builds a new :class:`MethodMetadataRequest` by merging the
    requests of their simple parts.

    Consumer-only classes such as simple estimators use this class for the
    output of `get_metadata_routing()`.

    .. versionadded:: 1.3

    Parameters
    ----------
    owner : str
        The name of the object to which these requests belong.
    """

    # Checked instead of `isinstance`, so that objects created by vendored
    # copies of this file are recognized too.
    _type = "metadata_request"

    def __init__(self, owner):
        self.owner = owner
        for method_name in SIMPLE_METHODS:
            method_request = MethodMetadataRequest(owner=owner, method=method_name)
            setattr(self, method_name, method_request)

    def consumes(self, method, params):
        """Check whether the given parameters are consumed by the given method.

        .. versionadded:: 1.4

        Parameters
        ----------
        method : str
            The name of the method to check.

        params : iterable of str
            An iterable of parameters to check.

        Returns
        -------
        consumed : set of str
            A set of parameters which are consumed by the given method.
        """
        method_request = getattr(self, method)
        return method_request._consumes(params=params)

    def __getattr__(self, name):
        # Only invoked when regular attribute lookup fails, i.e. for names that
        # are neither instance attributes (like the simple methods set in
        # __init__) nor found in the class tree. Composite methods are built
        # here; any other name must raise AttributeError, as the data model
        # requires:
        # https://docs.python.org/3/reference/datamodel.html#object.__getattr__
        if name not in COMPOSITE_METHODS:
            raise AttributeError(
                f"'{self.__class__.__name__}' object has no attribute '{name}'"
            )

        merged = {}
        for part in COMPOSITE_METHODS[name]:
            part_requests = getattr(self, part).requests
            shared = set(merged.keys()) & set(part_requests.keys())
            # A metadata shared by several parts must have the same request.
            conflicts = [key for key in shared if merged[key] != part_requests[key]]
            if conflicts:
                raise ValueError(
                    f"Conflicting metadata requests for {', '.join(conflicts)} while"
                    f" composing the requests for {name}. Metadata with the same name"
                    f" for methods {', '.join(COMPOSITE_METHODS[name])} should have the"
                    " same request value."
                )
            merged.update(part_requests)
        return MethodMetadataRequest(owner=self.owner, method=name, requests=merged)

    def _get_param_names(self, method, return_alias, ignore_self_request=None):
        """Get the names of the metadata consumed or routed by ``method``.

        This delegates to the method's :class:`MethodMetadataRequest`, which
        leaves out entries whose request value is ``False``.

        Parameters
        ----------
        method : str
            The name of the method for which metadata names are requested.

        return_alias : bool
            Controls whether original or aliased names should be returned. If
            ``False``, aliases are ignored and original names are returned.

        ignore_self_request : bool
            Ignored. Present for API compatibility.

        Returns
        -------
        names : set of str
            A set of strings with the names of the selected parameters.
        """
        method_request = getattr(self, method)
        return method_request._get_param_names(return_alias=return_alias)

    def _route_params(self, *, method, params):
        """Select the metadata to hand over to ``method``.

        The output can be passed directly to the corresponding method as
        extra keyword arguments.

        Parameters
        ----------
        method : str
            The name of the method for which the parameters are requested and
            routed.

        params : dict
            A dictionary of provided metadata.

        Returns
        -------
        params : Bunch
            A :class:`~sklearn.utils.Bunch` of {prop: value} which can be given to the
            corresponding method.
        """
        method_request = getattr(self, method)
        return method_request._route_params(params=params)

    def _check_warnings(self, *, method, params):
        """Warn about passed metadata whose request for ``method`` is WARN.

        Parameters
        ----------
        method : str
            The name of the method for which the warnings should be checked.

        params : dict
            The metadata passed to a method.
        """
        method_request = getattr(self, method)
        method_request._check_warnings(params=params)

    def _serialize(self):
        """Serialize the object.

        Returns
        -------
        obj : dict
            ``{method: requests}`` for every simple method with at least one
            request.
        """
        pairs = ((name, getattr(self, name)) for name in SIMPLE_METHODS)
        return {name: mmr._serialize() for name, mmr in pairs if len(mmr.requests)}

    def __repr__(self):
        return str(self._serialize())

    def __str__(self):
        return str(repr(self))

652 

653 

# Metadata Request for Routers
# ============================
# This section includes all objects required for MetadataRouter, which is used
# in routers and returned by their ``get_metadata_routing``.

# Stores a (mapping, router) pair for one child object of a router. ``mapping``
# is a MethodMapping object. ``router`` is the child's routing information: a
# MetadataRequest or a MetadataRouter. MetadataRouter keeps one of these per
# child object, keyed by name.
RouterMappingPair = namedtuple("RouterMappingPair", "mapping router")

# Stores a single method route: which method of the child (``callee``) is
# called from which method of the router (``caller``). A MethodMapping holds a
# collection of these namedtuples.
MethodPair = namedtuple("MethodPair", "callee caller")

667 

668 

class MethodMapping:
    """Stores the mapping between callee and caller methods for a router.

    A router uses this class in its ``get_metadata_routing()`` to define
    which methods of a sub-object (a sub-estimator or a scorer) are called
    from which of the router's own methods. It stores a collection of
    ``MethodPair`` namedtuples.

    Iterating through an instance of this class will yield named
    ``MethodPair(callee, caller)`` tuples.

    .. versionadded:: 1.3
    """

    def __init__(self):
        self._routes = []

    def __iter__(self):
        return iter(self._routes)

    def add(self, *, callee, caller):
        """Add a method mapping.

        Parameters
        ----------
        callee : str
            Child object's method name. This method is called in ``caller``.

        caller : str
            Parent estimator's method name in which the ``callee`` is called.

        Returns
        -------
        self : MethodMapping
            Returns self.

        Raises
        ------
        ValueError
            If ``callee`` or ``caller`` is not in ``METHODS`` (``callee`` is
            checked first).
        """
        for role, method_name in (("callee", callee), ("caller", caller)):
            if method_name not in METHODS:
                raise ValueError(
                    f"Given {role}:{method_name} is not a valid method. Valid"
                    f" methods are: {METHODS}"
                )
        self._routes.append(MethodPair(callee=callee, caller=caller))
        return self

    def _serialize(self):
        """Serialize the object.

        Returns
        -------
        obj : list
            One ``{"callee": ..., "caller": ...}`` dict per stored route, in
            insertion order.
        """
        return [
            {"callee": pair.callee, "caller": pair.caller} for pair in self._routes
        ]

    @classmethod
    def from_str(cls, route):
        """Construct an instance from a string.

        Parameters
        ----------
        route : str
            A string representing the mapping, it can be:

              - `"one-to-one"`: a one to one mapping for all methods.
              - `"method"`: the name of a single method, such as ``fit``,
                ``transform``, ``score``, etc.

        Returns
        -------
        obj : MethodMapping
            A :class:`~sklearn.utils.metadata_routing.MethodMapping` instance
            constructed from the given string.

        Raises
        ------
        ValueError
            If ``route`` is neither ``"one-to-one"`` nor a known method name.
        """
        routing = cls()
        if route == "one-to-one":
            method_names = METHODS
        elif route in METHODS:
            method_names = [route]
        else:
            raise ValueError("route should be 'one-to-one' or a single method!")
        for method_name in method_names:
            routing.add(callee=method_name, caller=method_name)
        return routing

    def __repr__(self):
        return str(self._serialize())

    def __str__(self):
        return str(repr(self))

765 

766 

767class MetadataRouter: 

768 """Stores and handles metadata routing for a router object. 

769 

770 This class is used by router objects to store and handle metadata routing. 

771 Routing information is stored as a dictionary of the form ``{"object_name": 

772 RouterMappingPair(method_mapping, routing_info)}``, where ``method_mapping`` 

773 is an instance of :class:`~sklearn.utils.metadata_routing.MethodMapping` and 

774 ``routing_info`` is either a 

775 :class:`~sklearn.utils.metadata_routing.MetadataRequest` or a 

776 :class:`~sklearn.utils.metadata_routing.MetadataRouter` instance. 

777 

778 .. versionadded:: 1.3 

779 

780 Parameters 

781 ---------- 

782 owner : str 

783 The name of the object to which these requests belong. 

784 """ 

785 

786 # this is here for us to use this attribute's value instead of doing 

787 # `isinstance` in our checks, so that we avoid issues when people vendor 

788 # this file instead of using it directly from scikit-learn. 

789 _type = "metadata_router" 

790 

def __init__(self, owner):
    # Name of the router owning this routing information.
    self.owner = owner
    # Routing information of the child objects: name -> RouterMappingPair,
    # filled by `add()`.
    self._route_mappings = dict()
    # Request of the router itself when it is also a consumer. It is set by
    # `add_self_request()` and kept apart from `_route_mappings`, because it
    # is treated differently from the child objects.
    self._self_request = None

799 

def add_self_request(self, obj):
    """Add `self` (as a consumer) to the routing.

    A router that is also a consumer uses this method to include its own
    requests in the routing. It must use this method rather than `add`,
    because its own requests are treated differently from those of the
    objects to which it routes metadata. A deep copy of the request is
    stored.

    Parameters
    ----------
    obj : object
        Typically the router instance, i.e. `self` in a
        ``get_metadata_routing()`` implementation, which must provide
        ``_get_metadata_request()``. It can also be a ``MetadataRequest``
        instance (recognized by its ``_type`` attribute).

    Returns
    -------
    self : MetadataRouter
        Returns `self`.

    Raises
    ------
    ValueError
        If `obj` is neither a `MetadataRequest` nor provides
        ``_get_metadata_request``.
    """
    if getattr(obj, "_type", None) == "metadata_request":
        request = obj
    elif hasattr(obj, "_get_metadata_request"):
        request = obj._get_metadata_request()
    else:
        raise ValueError(
            "Given `obj` is neither a `MetadataRequest` nor does it implement the"
            " required API. Inheriting from `BaseEstimator` implements the required"
            " API."
        )
    self._self_request = deepcopy(request)
    return self

835 

def add(self, *, method_mapping, **objs):
    """Add named objects with their corresponding method mapping.

    All objects given in one call share the same (copied) method mapping.
    An object added under an existing name replaces the previous entry.

    Parameters
    ----------
    method_mapping : MethodMapping or str
        The mapping between the child and the parent's methods. If str, the
        output of :func:`~sklearn.utils.metadata_routing.MethodMapping.from_str`
        is used. Otherwise a deep copy of the given mapping is stored.

    **objs : dict
        A dictionary of objects from which metadata is extracted by calling
        :func:`~sklearn.utils.metadata_routing.get_routing_for_object` on them.

    Returns
    -------
    self : MetadataRouter
        Returns `self`.
    """
    if isinstance(method_mapping, str):
        mapping = MethodMapping.from_str(method_mapping)
    else:
        mapping = deepcopy(method_mapping)

    for child_name, child in objs.items():
        routing = get_routing_for_object(child)
        self._route_mappings[child_name] = RouterMappingPair(
            mapping=mapping, router=routing
        )
    return self

865 

def consumes(self, method, params):
    """Check whether the given parameters are consumed by the given method.

    The result combines what the router itself consumes (if it was added via
    `add_self_request`) and what each child consumes in every callee method
    mapped to ``method``.

    .. versionadded:: 1.4

    Parameters
    ----------
    method : str
        The name of the method to check.

    params : iterable of str
        An iterable of parameters to check.

    Returns
    -------
    consumed : set of str
        A set of parameters which are consumed by the given method.
    """
    consumed = set()
    if self._self_request:
        consumed |= self._self_request.consumes(method=method, params=params)

    for route_mapping in self._route_mappings.values():
        child_routing = route_mapping.router
        for callee, caller in route_mapping.mapping:
            if caller != method:
                continue
            consumed |= child_routing.consumes(method=callee, params=params)

    return consumed

896 

def _get_param_names(self, *, method, return_alias, ignore_self_request):
    """Get the names of the metadata consumed or routed by ``method``.

    The result is the union of the names reported by the router's own
    request (unless ignored) and by every child whose callee method is
    mapped to ``method``. Children are always queried with aliases returned.

    Parameters
    ----------
    method : str
        The name of the method for which metadata names are requested.

    return_alias : bool
        Controls whether original or aliased names should be returned,
        which only applies to the stored `self`. If no `self` routing
        object is stored, this parameter has no effect.

    ignore_self_request : bool
        If `self._self_request` should be ignored. This is used in `_route_params`.
        If ``True``, ``return_alias`` has no effect.

    Returns
    -------
    names : set of str
        A set of strings with the names of the selected parameters.
    """
    names = set()
    if self._self_request and not ignore_self_request:
        names.update(
            self._self_request._get_param_names(
                method=method, return_alias=return_alias
            )
        )

    for route_mapping in self._route_mappings.values():
        for callee, caller in route_mapping.mapping:
            if caller != method:
                continue
            names.update(
                route_mapping.router._get_param_names(
                    method=callee, return_alias=True, ignore_self_request=False
                )
            )
    return names

940 

def _route_params(self, *, params, method):
    """Prepare the given parameters to be passed to the method.

    This is used when a router is itself a child object of another router.
    The parent passes all parameters understood by this router to it and
    delegates their validation to it.

    The output of this method can be used directly as the input to the
    corresponding method as extra props.

    Parameters
    ----------
    method : str
        The name of the method for which the parameters are requested and
        routed.

    params : dict
        A dictionary of provided metadata.

    Returns
    -------
    params : Bunch
        A :class:`~sklearn.utils.Bunch` of {prop: value} which can be given to the
        corresponding method.

    Raises
    ------
    ValueError
        If a key routed to the router itself is also needed by a child, and
        the two values are not the same object.
    """
    routed = Bunch()
    if self._self_request:
        routed.update(self._self_request._route_params(params=params, method=method))

    # Names the children understand for `method`, aliases included.
    child_names = self._get_param_names(
        method=method, return_alias=True, ignore_self_request=True
    )
    child_params = {key: val for key, val in params.items() if key in child_names}

    for key in set(routed.keys()).intersection(child_params.keys()):
        # Sharing a key is fine as long as it is the very same object.
        if child_params[key] is routed[key]:
            continue
        raise ValueError(
            f"In {self.owner}, there is a conflict on {key} between what is"
            " requested for this estimator and what is requested by its"
            " children. You can resolve this conflict by using an alias for"
            " the child estimator(s) requested metadata."
        )

    routed.update(child_params)
    return routed

989 

def route_params(self, *, caller, params):
    """Return, per child object, the inputs requested for each of its methods.

    Only child methods mapped to this router's ``caller`` method are
    included. When this router also has its own (``self``) request, that
    request's warnings are checked against ``params`` first.

    Parameters
    ----------
    caller : str
        The name of the method for which the parameters are requested and
        routed. If called inside the :term:`fit` method of a router, it
        would be `"fit"`.

    params : dict
        A dictionary of provided metadata.

    Returns
    -------
    params : Bunch
        A :class:`~sklearn.utils.Bunch` of the form
        ``{"object_name": {"method_name": {prop: value}}}`` which can be
        used to pass the required metadata to corresponding methods or
        corresponding child objects. Every child gets an entry, which may be
        empty if none of its methods is mapped to ``caller``.
    """
    if self._self_request:
        self._self_request._check_warnings(params=params, method=caller)

    routed = Bunch()
    for child_name, pair in self._route_mappings.items():
        child_router, method_mapping = pair.router, pair.mapping

        per_method = Bunch()
        routed[child_name] = per_method
        for callee, mapped_caller in method_mapping:
            if mapped_caller != caller:
                continue
            per_method[callee] = child_router._route_params(
                params=params, method=callee
            )
    return routed

1032 

def validate_metadata(self, *, method, params):
    """Validate given metadata for a method.

    Every key of ``params`` must be consumed by this object's own (``self``)
    request or requested by at least one child object for ``method``.

    Parameters
    ----------
    method : str
        The name of the method for which the parameters are requested and
        routed. If called inside the :term:`fit` method of a router, it
        would be `"fit"`.

    params : dict
        A dictionary of provided metadata.

    Raises
    ------
    TypeError
        If ``params`` contains keys that no object requests for ``method``.
    """
    # With ``ignore_self_request=False`` this set already contains the
    # original (non-aliased, ``return_alias=False``) names of ``self``'s own
    # request, so no separate lookup on ``self._self_request`` is needed.
    param_names = self._get_param_names(
        method=method, return_alias=False, ignore_self_request=False
    )
    extra_keys = set(params.keys()) - param_names
    if extra_keys:
        raise TypeError(
            f"{self.owner}.{method} got unexpected argument(s) {extra_keys}, which"
            " are not requested metadata in any object."
        )

1064 

1065 def _serialize(self): 

1066 """Serialize the object. 

1067 

1068 Returns 

1069 ------- 

1070 obj : dict 

1071 A serialized version of the instance in the form of a dictionary. 

1072 """ 

1073 res = dict() 

1074 if self._self_request: 

1075 res["$self_request"] = self._self_request._serialize() 

1076 for name, route_mapping in self._route_mappings.items(): 

1077 res[name] = dict() 

1078 res[name]["mapping"] = route_mapping.mapping._serialize() 

1079 res[name]["router"] = route_mapping.router._serialize() 

1080 

1081 return res 

1082 

1083 def __iter__(self): 

1084 if self._self_request: 

1085 yield "$self_request", RouterMappingPair( 

1086 mapping=MethodMapping.from_str("one-to-one"), router=self._self_request 

1087 ) 

1088 for name, route_mapping in self._route_mappings.items(): 

1089 yield (name, route_mapping) 

1090 

1091 def __repr__(self): 

1092 return str(self._serialize()) 

1093 

1094 def __str__(self): 

1095 return str(repr(self)) 

1096 

1097 

def get_routing_for_object(obj=None):
    """Get a ``Metadata{Router, Request}`` instance from the given object.

    This function returns a
    :class:`~sklearn.utils.metadata_routing.MetadataRouter` or a
    :class:`~sklearn.utils.metadata_routing.MetadataRequest` from the given input.

    This function always returns a copy or an instance constructed from the
    input, such that changing the output of this function will not change the
    original object.

    .. versionadded:: 1.3

    Parameters
    ----------
    obj : object
        - If the object provides a `get_metadata_routing` method, return a deep
          copy of the output of that method.
        - If the object is already a
          :class:`~sklearn.utils.metadata_routing.MetadataRequest` or a
          :class:`~sklearn.utils.metadata_routing.MetadataRouter` (recognized by
          its ``_type`` attribute), return a deep copy of it.
        - Returns an empty :class:`~sklearn.utils.metadata_routing.MetadataRequest`
          otherwise.

    Returns
    -------
    obj : MetadataRequest or MetadataRouter
        A ``MetadataRequest`` or a ``MetadataRouter`` taken or created from
        the given object.
    """
    # ``hasattr`` is used instead of wrapping the call in try/except, since an
    # ``AttributeError`` could be raised from inside ``get_metadata_routing``
    # for unrelated reasons and must not be masked.
    if hasattr(obj, "get_metadata_routing"):
        return deepcopy(obj.get_metadata_routing())

    if getattr(obj, "_type", None) in ("metadata_request", "metadata_router"):
        return deepcopy(obj)

    return MetadataRequest(owner=None)

1138 

1139 

1140# Request method 

1141# ============== 

1142# This section includes what's needed for the request method descriptor and 

1143# their dynamic generation in ``_MetadataRequester.__init_subclass__``. 

1144 

1145# These strings are used to dynamically generate the docstrings for 

1146# set_{method}_request methods. 

# Header of every generated ``set_{method}_request`` docstring; filled in via
# ``REQUESTER_DOC.format(method=...)``.
1147REQUESTER_DOC = """ Request metadata passed to the ``{method}`` method. 

1148 

1149 Note that this method is only relevant if 

1150 ``enable_metadata_routing=True`` (see :func:`sklearn.set_config`). 

1151 Please see :ref:`User Guide <metadata_routing>` on how the routing 

1152 mechanism works. 

1153 

1154 The options for each parameter are: 

1155 

1156 - ``True``: metadata is requested, and \ 

1157passed to ``{method}`` if provided. The request is ignored if \ 

1158metadata is not provided. 

1159 

1160 - ``False``: metadata is not requested and the meta-estimator \ 

1161will not pass it to ``{method}``. 

1162 

1163 - ``None``: metadata is not requested, and the meta-estimator \ 

1164will raise an error if the user provides it. 

1165 

1166 - ``str``: metadata should be passed to the meta-estimator with \ 

1167this given alias instead of the original name. 

1168 

1169 The default (``sklearn.utils.metadata_routing.UNCHANGED``) retains the 

1170 existing request. This allows you to change the request for some 

1171 parameters and not others. 

1172 

1173 .. versionadded:: 1.3 

1174 

1175 .. note:: 

1176 This method is only relevant if this estimator is used as a 

1177 sub-estimator of a meta-estimator, e.g. used inside a 

1178 :class:`~sklearn.pipeline.Pipeline`. Otherwise it has no effect. 

1179 

1180 Parameters 

1181 ---------- 

1182""" 

# One "Parameters" entry, appended once per metadata key; filled in via
# ``REQUESTER_DOC_PARAM.format(metadata=..., method=...)``.
1183REQUESTER_DOC_PARAM = """ {metadata} : str, True, False, or None, \ 

1184 default=sklearn.utils.metadata_routing.UNCHANGED 

1185 Metadata routing for ``{metadata}`` parameter in ``{method}``. 

1186 

1187""" 

# "Returns" section appended after all parameter entries (no placeholders).
1188REQUESTER_DOC_RETURN = """ Returns 

1189 ------- 

1190 self : object 

1191 The updated object. 

1192""" 

1193 

1194 

class RequestMethod:
    """
    A descriptor for request methods.

    .. versionadded:: 1.3

    Parameters
    ----------
    name : str
        The name of the method for which the request function should be
        created, e.g. ``"fit"`` would create a ``set_fit_request`` function.

    keys : list of str
        A list of strings which are accepted parameters by the created
        function, e.g. ``["sample_weight"]`` if the corresponding method
        accepts it as a metadata.

    validate_keys : bool, default=True
        Whether to check if the requested parameters fit the actual parameters
        of the method.

    Notes
    -----
    This class is a descriptor [1]_ and uses PEP-362 to set the signature of
    the returned function [2]_.

    The generated function works both when looked up on an instance
    (``est.set_fit_request(sample_weight=True)``) and when looked up on the
    class, in which case the target object is passed as the first positional
    argument (``Estimator.set_fit_request(est, sample_weight=True)``), e.g.
    when monkeypatching.

    References
    ----------
    .. [1] https://docs.python.org/3/howto/descriptor.html

    .. [2] https://www.python.org/dev/peps/pep-0362/
    """

    def __init__(self, name, keys, validate_keys=True):
        self.name = name
        self.keys = keys
        self.validate_keys = validate_keys

    def __get__(self, instance, owner):
        # we would want to have a method which accepts only the expected args
        def func(*args, **kw):
            """Updates the request for provided parameters

            This docstring is overwritten below.
            See REQUESTER_DOC for expected functionality
            """
            # Resolve the object whose requests are updated. On instance
            # access it is bound through the closure. On class access
            # (``instance is None``) it is the first positional argument,
            # matching the ``self`` parameter in ``__signature__`` below.
            if instance is None and args:
                target, args = args[0], args[1:]
            else:
                target = instance

            # Besides ``self`` only keyword arguments are accepted; mirror the
            # TypeError Python raises for surplus positional arguments.
            if args:
                raise TypeError(
                    f"set_{self.name}_request() takes 0 positional arguments but"
                    f" {len(args)} positional argument(s) were given"
                )

            if not _routing_enabled():
                raise RuntimeError(
                    "This method is only available when metadata routing is enabled."
                    " You can enable it using"
                    " sklearn.set_config(enable_metadata_routing=True)."
                )

            if self.validate_keys:
                unexpected = set(kw) - set(self.keys)
                if unexpected:
                    raise TypeError(
                        f"Unexpected args: {unexpected}. Accepted arguments"
                        f" are: {set(self.keys)}"
                    )

            if target is None:
                raise TypeError(
                    f"set_{self.name}_request() missing 1 required positional"
                    " argument: 'self'"
                )

            requests = target._get_metadata_request()
            method_metadata_request = getattr(requests, self.name)

            # UNCHANGED is the default of every key and leaves the existing
            # request untouched.
            for prop, alias in kw.items():
                if alias is not UNCHANGED:
                    method_metadata_request.add_request(param=prop, alias=alias)
            target._metadata_request = requests

            return target

        # Now we set the relevant attributes of the function so that it seems
        # like a normal method to the end user, with known expected arguments.
        func.__name__ = f"set_{self.name}_request"
        params = [
            inspect.Parameter(
                name="self",
                kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
                annotation=owner,
            )
        ]
        params.extend(
            [
                inspect.Parameter(
                    k,
                    inspect.Parameter.KEYWORD_ONLY,
                    default=UNCHANGED,
                    annotation=Optional[Union[bool, None, str]],
                )
                for k in self.keys
            ]
        )
        func.__signature__ = inspect.Signature(
            params,
            return_annotation=owner,
        )
        doc = REQUESTER_DOC.format(method=self.name)
        for metadata in self.keys:
            doc += REQUESTER_DOC_PARAM.format(metadata=metadata, method=self.name)
        doc += REQUESTER_DOC_RETURN
        func.__doc__ = doc
        return func

1295 

1296 

class _MetadataRequester:
    """Mixin providing metadata request functionality.

    ``BaseEstimator`` inherits from this Mixin.

    .. versionadded:: 1.3
    """

    if TYPE_CHECKING:  # pragma: no cover
        # Never executed at runtime. ``set_{method}_request`` methods are
        # attached dynamically (see ``__init_subclass__``), which static type
        # checkers cannot follow, so stubs are declared here for them; type
        # checkers treat this condition as True. Keep this list in sync with
        # SIMPLE_METHODS.
        # fmt: off
        def set_fit_request(self, **kwargs): ...
        def set_partial_fit_request(self, **kwargs): ...
        def set_predict_request(self, **kwargs): ...
        def set_predict_proba_request(self, **kwargs): ...
        def set_predict_log_proba_request(self, **kwargs): ...
        def set_decision_function_request(self, **kwargs): ...
        def set_score_request(self, **kwargs): ...
        def set_split_request(self, **kwargs): ...
        def set_transform_request(self, **kwargs): ...
        def set_inverse_transform_request(self, **kwargs): ...
        # fmt: on

    def __init_subclass__(cls, **kwargs):
        """Attach the ``set_{method}_request`` methods to the new subclass.

        Relies on PEP-487 [1]_. For each method in ``SIMPLE_METHODS`` with at
        least one default request, a ``RequestMethod`` descriptor is set as
        ``set_{method}_request``. Default requests come from method signatures
        and from ``__metadata_request__*`` class attributes.

        ``__metadata_request__*`` attributes cover metadata a method does not
        take explicitly as an argument, or request values other than the
        default ``None``.

        References
        ----------
        .. [1] https://www.python.org/dev/peps/pep-0487
        """
        try:
            requests = cls._get_default_requests()
        except Exception:
            # Deliberately ignored here: invalid defaults are reported later,
            # when ``get_metadata_routing`` is called.
            pass
        else:
            for method in SIMPLE_METHODS:
                method_request = getattr(requests, method)
                if not len(method_request.requests):
                    # nothing can be requested -> no setter is generated
                    continue
                accepted_keys = sorted(method_request.requests.keys())
                setattr(
                    cls, f"set_{method}_request", RequestMethod(method, accepted_keys)
                )
        super().__init_subclass__(**kwargs)

    @classmethod
    def _build_request_for_signature(cls, router, method):
        """Derive a ``MethodMetadataRequest`` from the signature of ``method``.

        Every parameter after the first (usually ``self``) is requested with
        the value ``None``, except ``X``, ``y``, ``Y``, ``Xt``, ``yt``,
        ``*args`` and ``**kwargs``.

        Parameters
        ----------
        router : MetadataRequest
            The parent object for the created `MethodMetadataRequest`.
        method : str
            The name of the method.

        Returns
        -------
        method_request : MethodMetadataRequest
            The prepared request using the method's signature. It is empty if
            the class has no plain function called ``method``.
        """
        method_request = MethodMetadataRequest(owner=cls.__name__, method=method)
        # On the class (not an instance) a method is an unbound function, hence
        # ``isfunction`` rather than ``ismethod``.
        candidate = getattr(cls, method, None)
        if not inspect.isfunction(candidate):
            return method_request

        skipped_names = {"X", "y", "Y", "Xt", "yt"}
        variadic_kinds = {inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD}
        # the leading parameter (normally ``self``) is dropped
        for param in list(inspect.signature(candidate).parameters.values())[1:]:
            if param.name in skipped_names or param.kind in variadic_kinds:
                continue
            method_request.add_request(param=param.name, alias=None)
        return method_request

    @classmethod
    def _get_default_requests(cls):
        """Collect the default request values of this class.

        Signature-derived requests are built first and then overridden by the
        ``__metadata_request__*`` class attributes found along the MRO.
        """
        requests = MetadataRequest(owner=cls.__name__)
        for method in SIMPLE_METHODS:
            signature_request = cls._build_request_for_signature(
                router=requests, method=method
            )
            setattr(requests, method, signature_request)

        # ``vars`` does not report inherited class attributes, so walk the MRO
        # explicitly, from the most generic base to ``cls`` itself, letting
        # subclasses override their parents.
        marker = "__metadata_request__"
        collected = {}
        for klass in reversed(inspect.getmro(cls)):
            for attr_name, attr_value in vars(klass).items():
                if marker in attr_name:
                    collected[attr_name] = attr_value

        for attr_name, attr_value in sorted(collected.items()):
            # A substring search rather than ``startswith``, because Python
            # name-mangles ``__x`` attributes into ``_ClassName__x``.
            target_method = attr_name[attr_name.index(marker) + len(marker) :]
            for prop, alias in attr_value.items():
                getattr(requests, target_method).add_request(param=prop, alias=alias)

        return requests

    def _get_metadata_request(self):
        """Return the requested metadata of this object.

        Please check :ref:`User Guide <metadata_routing>` on how the routing
        mechanism works.

        Returns
        -------
        request : MetadataRequest
            A copy of ``self._metadata_request`` if it has been set, otherwise
            the class defaults from ``_get_default_requests``.
        """
        if not hasattr(self, "_metadata_request"):
            return self._get_default_requests()
        return get_routing_for_object(self._metadata_request)

    def get_metadata_routing(self):
        """Get metadata routing of this object.

        Please check :ref:`User Guide <metadata_routing>` on how the routing
        mechanism works.

        Returns
        -------
        routing : MetadataRequest
            A :class:`~sklearn.utils.metadata_routing.MetadataRequest` encapsulating
            routing information.
        """
        routing = self._get_metadata_request()
        return routing

1479 

1480 

# Process Routing in Routers
# ==========================
# This is almost always the only function used in routers to process and route
# given metadata. This is to minimize the boilerplate required in routers.


# Here the first two arguments are positional only which makes everything
# passed as keyword argument a metadata. The first two args also have an `_`
# prefix to reduce the chances of name collisions with the passed metadata, and
# since they're positional only, users will never type those underscores.
def process_routing(_obj, _method, /, **kwargs):
    """Validate and route input parameters.

    This function is used inside a router's method, e.g. :term:`fit`,
    to validate the metadata and handle the routing.

    Assuming this signature of a router's method:
    ``fit(self, X, y, sample_weight=None, **fit_params)``,
    a call to this function would be:
    ``process_routing(self, "fit", sample_weight=sample_weight, **fit_params)``.

    Note that if routing is not enabled and ``kwargs`` is empty, then it
    returns an empty routing where ``process_routing(...).ANYTHING.ANY_METHOD``
    is always an empty dictionary. The same holds for
    ``process_routing(...)[ANYTHING].ANY_METHOD`` and, when no default is
    given, ``process_routing(...).get(ANYTHING).ANY_METHOD``.

    .. versionadded:: 1.3

    Parameters
    ----------
    _obj : object
        An object implementing ``get_metadata_routing``. Typically a
        meta-estimator.

    _method : str
        The name of the router's method in which this function is called.

    **kwargs : dict
        Metadata to be routed.

    Returns
    -------
    routed_params : Bunch
        A :class:`~sklearn.utils.Bunch` of the form ``{"object_name": {"method_name":
        {prop: value}}}`` which can be used to pass the required metadata to
        corresponding methods or corresponding child objects. The object names
        are those defined in `obj.get_metadata_routing()`.

    Raises
    ------
    AttributeError
        If ``_obj`` neither implements ``get_metadata_routing`` nor is a
        ``MetadataRouter``.
    TypeError
        If ``_method`` is not one of ``METHODS``. Errors raised by the routing
        object's ``validate_metadata`` propagate as well.
    """
    if not _routing_enabled() and not kwargs:
        # If routing is not enabled and kwargs are empty, then we don't have to
        # try doing any routing, we can simply return a structure which returns
        # an empty dict on routed_params.ANYTHING.ANY_METHOD.
        def _empty_routing():
            # one fresh, empty params dict per known method
            return Bunch(**{method: dict() for method in METHODS})

        class EmptyRequest:
            def get(self, name, default=None):
                # Without an explicit default, behave like ``[name]`` so that
                # ``.get(name).ANY_METHOD`` also yields an empty dict.
                return _empty_routing() if default is None else default

            def __getitem__(self, name):
                return _empty_routing()

            def __getattr__(self, name):
                return _empty_routing()

        return EmptyRequest()

    if not (hasattr(_obj, "get_metadata_routing") or isinstance(_obj, MetadataRouter)):
        raise AttributeError(
            f"The given object ({repr(_obj.__class__.__name__)}) needs to either"
            " implement the routing method `get_metadata_routing` or be a"
            " `MetadataRouter` instance."
        )
    if _method not in METHODS:
        raise TypeError(
            f"Can only route and process input on these methods: {METHODS}, "
            f"while the passed method is: {_method}."
        )

    request_routing = get_routing_for_object(_obj)
    request_routing.validate_metadata(params=kwargs, method=_method)
    routed_params = request_routing.route_params(params=kwargs, caller=_method)

    return routed_params