8000 [MRG+1] FIX Add missing mixins to ClassifierChain (#9473) · AishwaryaRK/scikit-learn@515fadd · GitHub
[go: up one dir, main page]

Skip to content

Commit 515fadd

jnothmanAishwaryaRK
authored andcommitted
[MRG+1] FIX Add missing mixins to ClassifierChain (scikit-learn#9473)
* Add missing mixins to ClassifierChain * Fix import in test
1 parent 76b0479 commit 515fadd

File tree

2 files changed

+4
-1
lines changed

2 files changed

+4
-1
lines changed

sklearn/multioutput.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -368,7 +368,7 @@ def score(self, X, y):
368368
return np.mean(np.all(y == y_pred, axis=1))
369369

370370

371-
class ClassifierChain(BaseEstimator):
371+
class ClassifierChain(BaseEstimator, ClassifierMixin, MetaEstimatorMixin):
372372
"""A multi-label model that arranges binary classifiers into a chain.
373373
374374
Each model makes a prediction in the order specified by the chain using

sklearn/tests/test_multioutput.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from sklearn.multioutput import MultiOutputClassifier
3030
from sklearn.multioutput import MultiOutputRegressor
3131
from sklearn.svm import LinearSVC
32+
from sklearn.base import ClassifierMixin
3233
from sklearn.utils import shuffle
3334

3435

@@ -380,6 +381,8 @@ def test_classifier_chain_fit_and_predict_with_logistic_regression():
380381
assert_equal([c.coef_.size for c in classifier_chain.estimators_],
381382
list(range(X.shape[1], X.shape[1] + Y.shape[1])))
382383

384+
assert isinstance(classifier_chain, ClassifierMixin)
385+
383386

384387
def test_classifier_chain_fit_and_predict_with_linear_svc():
385388
# Fit classifier chain and verify predict performance using LinearSVC

0 commit comments

Comments
 (0)
0