diff --git a/doc/whats_new.rst b/doc/whats_new.rst index c0c470f625f6d..a88a69c9b80b8 100644 --- a/doc/whats_new.rst +++ b/doc/whats_new.rst @@ -413,6 +413,11 @@ API changes summary :func:`model_selection.cross_val_predict`. :issue:`2879` by :user:`Stephen Hoover <stephen-hoover>`. + - The ``decision_function`` output shape for binary classification in + :class:`multiclass.OneVsRestClassifier` and + :class:`multiclass.OneVsOneClassifier` is now ``(n_samples,)`` to conform + to scikit-learn conventions. :issue:`9100` by `Andreas Müller`_. + - Gradient boosting base models are no longer estimators. By `Andreas Müller`_. - :class:`feature_selection.SelectFromModel` now validates the ``threshold`` diff --git a/sklearn/multiclass.py b/sklearn/multiclass.py index 712e8573fa469..59a17dddda538 100644 --- a/sklearn/multiclass.py +++ b/sklearn/multiclass.py @@ -368,6 +368,8 @@ def decision_function(self, X): T : array-like, shape = [n_samples, n_classes] """ check_is_fitted(self, 'estimators_') + if len(self.estimators_) == 1: + return self.estimators_[0].decision_function(X) return np.array([est.decision_function(X).ravel() for est in self.estimators_]).T @@ -574,6 +576,8 @@ def predict(self, X): Predicted multi-class targets. 
""" Y = self.decision_function(X) + if self.n_classes_ == 2: + return self.classes_[(Y > 0).astype(np.int)] return self.classes_[Y.argmax(axis=1)] def decision_function(self, X): @@ -606,7 +610,8 @@ def decision_function(self, X): for est, Xi in zip(self.estimators_, Xs)]).T Y = _ovr_decision_function(predictions, confidences, len(self.classes_)) - + if self.n_classes_ == 2: + return Y[:, 1] return Y @property diff --git a/sklearn/tests/test_multiclass.py b/sklearn/tests/test_multiclass.py index 20ec4b132fc7f..8e1c760555542 100644 --- a/sklearn/tests/test_multiclass.py +++ b/sklearn/tests/test_multiclass.py @@ -251,6 +251,9 @@ def conduct_test(base_clf, test_predict_proba=False): assert_equal(set(clf.classes_), classes) y_pred = clf.predict(np.array([[0, 0, 4]]))[0] assert_equal(set(y_pred), set("eggs")) + if hasattr(base_clf, 'decision_function'): + dec = clf.decision_function(X) + assert_equal(dec.shape, (5,)) if test_predict_proba: X_test = np.array([[0, 0, 4]]) @@ -524,6 +527,12 @@ def test_ovo_decision_function(): n_samples = iris.data.shape[0] ovo_clf = OneVsOneClassifier(LinearSVC(random_state=0)) + # first binary + ovo_clf.fit(iris.data, iris.target == 0) + decisions = ovo_clf.decision_function(iris.data) + assert_equal(decisions.shape, (n_samples,)) + + # then multi-class ovo_clf.fit(iris.data, iris.target) decisions = ovo_clf.decision_function(iris.data)