8000 FIX OvR/OvO classifier decision_function shape fixes (#9100) · scikit-learn/scikit-learn@b43c791 · GitHub
[go: up one dir, main page]

Skip to content

Commit b43c791

Browse files
amuellervene
authored andcommitted
FIX OvR/OvO classifier decision_function shape fixes (#9100)
* fix OVR classifier edgecase bugs * add regression tests for OVO and OVR decision function shapes
1 parent 56a21ea commit b43c791

File tree

3 files changed

+20
-1
lines changed

3 files changed

+20
-1
lines changed

doc/whats_new.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -418,6 +418,11 @@ API changes summary
418418
:func:`model_selection.cross_val_predict`.
419419
:issue:`2879` by :user:`Stephen Hoover <stephen-hoover>`.
420420

421+
- The ``decision_function`` output shape for binary classification in
422+
:class:`multi_class.OneVsRestClassifier` and
423+
:class:`multi_class.OneVsOneClassifier` is now ``(n_samples,)`` to conform
424+
to scikit-learn conventions. :issue:`9100` by `Andreas Müller`_.
425+
421426
- Gradient boosting base models are no longer estimators. By `Andreas Müller`_.
422427

423428
- :class:`feature_selection.SelectFromModel` now validates the ``threshold``

sklearn/multiclass.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,8 @@ def decision_function(self, X):
368368
T : array-like, shape = [n_samples, n_classes]
369369
"""
370370
check_is_fitted(self, 'estimators_')
371+
if len(self.estimators_) == 1:
372+
return self.estimators_[0].decision_function(X)
371373
return np.array([est.decision_function(X).ravel()
372374
for est in self.estimators_]).T
373375

@@ -574,6 +576,8 @@ def predict(self, X):
574576
Predicted multi-class targets.
575577
"""
576578
Y = self.decision_function(X)
579+
if self.n_classes_ == 2:
580+
return self.classes_[(Y > 0).astype(np.int)]
577581
return self.classes_[Y.argmax(axis=1)]
578582

579583
def decision_function(self, X):
@@ -606,7 +610,8 @@ def decision_function(self, X):
606610
for est, Xi in zip(self.estimators_, Xs)]).T
607611
Y = _ovr_decision_function(predictions,
608612
confidences, len(self.classes_))
609-
613+
if self.n_classes_ == 2:
614+
return Y[:, 1]
610615
return Y
611616

612617
@property

sklearn/tests/test_multiclass.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,9 @@ def conduct_test(base_clf, test_predict_proba=False):
251251
assert_equal(set(clf.classes_), classes)
252252
y_pred = clf.predict(np.array([[0, 0, 4]]))[0]
253253
assert_equal(set(y_pred), set("eggs"))
254+
if hasattr(base_clf, 'decision_function'):
255+
dec = clf.decision_function(X)
256+
assert_equal(dec.shape, (5,))
254257

255258
if test_predict_proba:
256259
X_test = np.array([[0, 0, 4]])
@@ -524,6 +527,12 @@ def test_ovo_decision_function():
524527
n_samples = iris.data.shape[0]
525528

526529
ovo_clf = OneVsOneClassifier(LinearSVC(random_state=0))
530+
# first binary
531+
ovo_clf.fit(iris.data, iris.target == 0)
532+
decisions = ovo_clf.decision_function(iris.data)
533+
assert_equal(decisions.shape, (n_samples,))
534+
535+
# then multi-class
527536
ovo_clf.fit(iris.data, iris.target)
528537
decisions = ovo_clf.decision_function(iris.data)
529538

0 commit comments

Comments
 (0)
0