@@ -660,6 +660,8 @@ def fit(self, X, Y, **fit_params):
660
660
661
661
del Y_pred_chain
662
662
663
+ routed_params = process_routing (obj = self , method = "fit" , other_params = fit_params )
664
+
663
665
for chain_idx , estimator in enumerate (self .estimators_ ):
664
666
message = self ._log_message (
665
667
estimator_idx = chain_idx + 1 ,
@@ -668,7 +670,12 @@ def fit(self, X, Y, **fit_params):
668
670
)
669
671
y = Y [:, self .order_ [chain_idx ]]
670
672
with _print_elapsed_time ("Chain" , message ):
671
- estimator .fit (X_aug [:, : (X .shape [1 ] + chain_idx )], y , ** fit_params )
673
+ estimator .fit (
674
+ X_aug [:, : (X .shape [1 ] + chain_idx )],
675
+ y ,
676
+ ** routed_params .estimator .fit ,
677
+ )
678
+
672
679
if self .cv is not None and chain_idx < len (self .estimators_ ) - 1 :
673
680
col_idx = X .shape [1 ] + chain_idx
674
681
cv_result = cross_val_predict (
@@ -831,7 +838,7 @@ class labels for each estimator in the chain.
831
838
[0.0321..., 0.9935..., 0.0625...]])
832
839
"""
833
840
834
- def fit (self , X , Y ):
841
+ def fit (self , X , Y , ** fit_params ):
835
842
"""Fit the model to data matrix X and targets Y.
836
843
837
844
Parameters
@@ -842,14 +849,19 @@ def fit(self, X, Y):
842
849
Y : array-like of shape (n_samples, n_classes)
843
850
The target values.
844
851
852
+ **fit_params : dict of string -> object
853
+ Parameters passed to the `fit` method of each step.
854
+
855
+ .. versionadded:: 1.2
856
+
845
857
Returns
846
858
-------
847
859
self : object
848
860
Class instance.
849
861
"""
850
862
self ._validate_params ()
851
863
852
- super ().fit (X , Y )
864
+ super ().fit (X , Y , ** fit_params )
853
865
self .classes_ = [
854
866
estimator .classes_ for chain_idx , estimator in enumerate (self .estimators_ )
855
867
]
@@ -919,6 +931,24 @@ def decision_function(self, X):
919
931
920
932
return Y_decision
921
933
934
+ def get_metadata_routing (self ):
935
+ """Get metadata routing of this object.
936
+
937
+ Please check :ref:`User Guide <metadata_routing>` on how the routing
938
+ mechanism works.
939
+
940
+ Returns
941
+ -------
942
+ routing : MetadataRouter
943
+ A :class:`~utils.metadata_routing.MetadataRouter` encapsulating
944
+ routing information.
945
+ """
946
+ router = MetadataRouter (owner = self .__class__ .__name__ ).add (
947
+ estimator = self .base_estimator ,
948
+ method_mapping = MethodMapping ().add (callee = "fit" , caller = "fit" ),
949
+ )
950
+ return router
951
+
922
952
def _more_tags (self ):
923
953
return {"_skip_test" : True , "multioutput_only" : True }
924
954
@@ -1046,5 +1076,27 @@ def fit(self, X, Y, **fit_params):
1046
1076
super ().fit (X , Y , ** fit_params )
1047
1077
return self
1048
1078
1079
+ def get_metadata_routing (self ):
1080
+ """Get metadata routing of this object.
1081
+
1082
+ Please check :ref:`User Guide <metadata_routing>` on how the routing
1083
+ mechanism works.
1084
+
1085
+ Returns
1086
+ -------
1087
+ routing : MetadataRouter
1088
+ A :class:`~utils.metadata_routing.MetadataRouter` encapsulating
1089
+ routing information.
1090
+ """
1091
+ router = (
1092
+ MetadataRouter (owner = self .__class__ .__name__ )
1093
+ .add (
1094
+ estimator = self .base_estimator ,
1095
+ method_mapping = MethodMapping ().add (callee = "fit" , caller = "fit" ),
1096
+ )
1097
+ .warn_on (child = "estimator" , method = "fit" , params = None )
1098
+ )
1099
+ return router
1100
+
1049
1101
def _more_tags (self ):
1050
1102
return {"multioutput_only" : True }
0 commit comments