From 5ef4b2e11ae494f30157b50f72d31802053dca46 Mon Sep 17 00:00:00 2001 From: Andreas Mueller Date: Sat, 10 Jun 2017 16:08:36 +0200 Subject: [PATCH 1/5] fix OVR classifier edgecase bugs --- sklearn/multiclass.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/sklearn/multiclass.py b/sklearn/multiclass.py index 712e8573fa469..59a17dddda538 100644 --- a/sklearn/multiclass.py +++ b/sklearn/multiclass.py @@ -368,6 +368,8 @@ def decision_function(self, X): T : array-like, shape = [n_samples, n_classes] """ check_is_fitted(self, 'estimators_') + if len(self.estimators_) == 1: + return self.estimators_[0].decision_function(X) return np.array([est.decision_function(X).ravel() for est in self.estimators_]).T @@ -574,6 +576,8 @@ def predict(self, X): Predicted multi-class targets. """ Y = self.decision_function(X) + if self.n_classes_ == 2: + return self.classes_[(Y > 0).astype(np.int)] return self.classes_[Y.argmax(axis=1)] def decision_function(self, X): @@ -606,7 +610,8 @@ def decision_function(self, X): for est, Xi in zip(self.estimators_, Xs)]).T Y = _ovr_decision_function(predictions, confidences, len(self.classes_)) - + if self.n_classes_ == 2: + return Y[:, 1] return Y @property From 3b675d111d527935ff989e60b8ac54dedb808bde Mon Sep 17 00:00:00 2001 From: Andreas Mueller Date: Sat, 10 Jun 2017 16:21:34 +0200 Subject: [PATCH 2/5] add regression tests for OVO and OVR decision function shapes --- sklearn/tests/test_multiclass.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/sklearn/tests/test_multiclass.py b/sklearn/tests/test_multiclass.py index 20ec4b132fc7f..b74a34ad63d77 100644 --- a/sklearn/tests/test_multiclass.py +++ b/sklearn/tests/test_multiclass.py @@ -251,6 +251,8 @@ def conduct_test(base_clf, test_predict_proba=False): assert_equal(set(clf.classes_), classes) y_pred = clf.predict(np.array([[0, 0, 4]]))[0] assert_equal(set(y_pred), set("eggs")) + dec = clf.decision_function(X) + assert_equal(dec.shape, (5,)) if 
test_predict_proba: X_test = np.array([[0, 0, 4]]) @@ -524,6 +526,12 @@ def test_ovo_decision_function(): n_samples = iris.data.shape[0] ovo_clf = OneVsOneClassifier(LinearSVC(random_state=0)) + # first binary + ovo_clf.fit(iris.data, iris.target == 0) + decisions = ovo_clf.decision_function(iris.data) + assert_equal(decisions.shape, (n_samples,)) + + # then multi-class ovo_clf.fit(iris.data, iris.target) decisions = ovo_clf.decision_function(iris.data) From 346aa4bd59783032de2328e13a6f6f55992953aa Mon Sep 17 00:00:00 2001 From: Andreas Mueller Date: Sat, 10 Jun 2017 16:25:43 +0200 Subject: [PATCH 3/5] add whatsnew entry --- doc/whats_new.rst | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/doc/whats_new.rst b/doc/whats_new.rst index c0c470f625f6d..a88a69c9b80b8 100644 --- a/doc/whats_new.rst +++ b/doc/whats_new.rst @@ -413,6 +413,11 @@ API changes summary :func:`model_selection.cross_val_predict`. :issue:`2879` by :user:`Stephen Hoover <stephen-hoover>`. + - The ``decision_function`` output shape for binary classification in + :class:`multiclass.OneVsRestClassifier` and + :class:`multiclass.OneVsOneClassifier` is now ``(n_samples,)`` to conform + to scikit-learn conventions. :issue:`9100` by `Andreas Müller`_. + - Gradient boosting base models are no longer estimators. By `Andreas Müller`_. 
- :class:`feature_selection.SelectFromModel` now validates the ``threshold`` From f031817d8574b2dc6f37f75285b16b9be1ec4f9d Mon Sep 17 00:00:00 2001 From: Andreas Mueller Date: Sat, 10 Jun 2017 16:51:07 +0200 Subject: [PATCH 4/5] make test of decision_function conditional on whether there's a decision_function --- sklearn/tests/test_multiclass.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sklearn/tests/test_multiclass.py b/sklearn/tests/test_multiclass.py index b74a34ad63d77..971b26ae99357 100644 --- a/sklearn/tests/test_multiclass.py +++ b/sklearn/tests/test_multiclass.py @@ -251,7 +251,8 @@ def conduct_test(base_clf, test_predict_proba=False): assert_equal(set(clf.classes_), classes) y_pred = clf.predict(np.array([[0, 0, 4]]))[0] assert_equal(set(y_pred), set("eggs")) - dec = clf.decision_function(X) + if hasattr(base_clf, 'decision_function'): + dec = clf.decision_function(X) assert_equal(dec.shape, (5,)) if test_predict_proba: From 49987779d51a49df10c97adc49191a5d2fb983b5 Mon Sep 17 00:00:00 2001 From: Andreas Mueller Date: Sat, 10 Jun 2017 17:11:10 +0200 Subject: [PATCH 5/5] gah fix indentation --- sklearn/tests/test_multiclass.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/tests/test_multiclass.py b/sklearn/tests/test_multiclass.py index 971b26ae99357..8e1c760555542 100644 --- a/sklearn/tests/test_multiclass.py +++ b/sklearn/tests/test_multiclass.py @@ -253,7 +253,7 @@ def conduct_test(base_clf, test_predict_proba=False): assert_equal(set(y_pred), set("eggs")) if hasattr(base_clf, 'decision_function'): dec = clf.decision_function(X) - assert_equal(dec.shape, (5,)) + assert_equal(dec.shape, (5,)) if test_predict_proba: X_test = np.array([[0, 0, 4]])