8000 [MRG+1] Regressor chain tags (#13337) · scikit-learn/scikit-learn@bf6949b · GitHub
[go: up one dir, main page]

Skip to content

Commit bf6949b

Browse files
amuellerGaelVaroquaux
authored andcommitted
[MRG+1] Regressor chain tags (#13337)
* don't skip chain estimator in common tests * fix fitted check in chain estimators * actually still skip on classifier chain
1 parent ac79dff commit bf6949b

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

sklearn/multioutput.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -460,6 +460,7 @@ def predict(self, X):
460460
The predicted values.
461461
462462
"""
463+
check_is_fitted(self, 'estimators_')
463464
X = check_array(X, accept_sparse=True)
464465
Y_pred_chain = np.zeros((X.shape[0], len(self.estimators_)))
465466
for chain_idx, estimator in enumerate(self.estimators_):
@@ -636,7 +637,8 @@ def decision_function(self, X):
636637
return Y_decision
637638

638639
def _more_tags(self):
639-
return {'_skip_test': True}
640+
return {'_skip_test': True,
641+
'multioutput_only': True}
640642

641643

642644
class RegressorChain(_BaseChain, RegressorMixin, MetaEstimatorMixin):
@@ -722,5 +724,4 @@ def fit(self, X, Y):
722724
return self
723725

724726
def _more_tags(self):
725-
# FIXME
726-
return {'_skip_test': True}
727+
return {'multioutput_only': True}

0 commit comments

Comments
 (0)
0