From ec9bdf210d18739242f0d690d68e845ec2098500 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Tue, 14 Sep 2021 15:06:17 +0200 Subject: [PATCH 1/9] TST add unit tests for current _get_response --- sklearn/metrics/_plot/base.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/sklearn/metrics/_plot/base.py b/sklearn/metrics/_plot/base.py index 8f5552ffd6808..442b833f40345 100644 --- a/sklearn/metrics/_plot/base.py +++ b/sklearn/metrics/_plot/base.py @@ -94,7 +94,9 @@ def _get_response(X, estimator, response_method, pos_label=None): y_pred = prediction_method(X) - if pos_label is not None and pos_label not in estimator.classes_: + # Checking that a scalar is contained in a NumPy array will raise a FutureWarning. + # We need to convert it into a list. + if pos_label is not None and pos_label not in list(estimator.classes_): raise ValueError( "The class provided by 'pos_label' is unknown. Got " f"{pos_label} instead of one of {estimator.classes_}" @@ -111,7 +113,7 @@ def _get_response(X, estimator, response_method, pos_label=None): pos_label = estimator.classes_[1] y_pred = y_pred[:, 1] else: - class_idx = np.flatnonzero(estimator.classes_ == pos_label) + class_idx = np.flatnonzero(estimator.classes_ == pos_label)[0] y_pred = y_pred[:, class_idx] else: if pos_label is None: From a989c67e185794563b9d0d51d97be96b05cfa7c0 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Tue, 14 Sep 2021 15:06:33 +0200 Subject: [PATCH 2/9] TST add unit tests for current _get_response --- sklearn/metrics/_plot/tests/test_base.py | 75 ++++++++++++++++++++++++ 1 file changed, 75 insertions(+) create mode 100644 sklearn/metrics/_plot/tests/test_base.py diff --git a/sklearn/metrics/_plot/tests/test_base.py b/sklearn/metrics/_plot/tests/test_base.py new file mode 100644 index 0000000000000..2f67d7dd223f4 --- /dev/null +++ b/sklearn/metrics/_plot/tests/test_base.py @@ -0,0 +1,75 @@ +import numpy as np +import pytest + +from sklearn.datasets import load_iris 
+from sklearn.linear_model import LogisticRegression +from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor + +from sklearn.metrics._plot.base import _get_response + + +@pytest.mark.parametrize( + "estimator, err_msg, params", + [ + ( + DecisionTreeRegressor(), + "Expected 'estimator' to be a binary classifier", + {"response_method": "auto"}, + ), + ( + DecisionTreeClassifier(), + "The class provided by 'pos_label' is unknown.", + {"response_method": "auto", "pos_label": "unknown"}, + ), + ( + DecisionTreeClassifier(), + "fit on multiclass", + {"response_method": "predict_proba"}, + ), + ], +) +def test_get_response_error(estimator, err_msg, params): + """Check that we raise the proper error messages in `_get_response`.""" + X, y = load_iris(return_X_y=True) + + estimator.fit(X, y) + with pytest.raises(ValueError, match=err_msg): + _get_response(X, estimator, **params) + + +def test_get_response_predict_proba(): + """Check the behaviour of `_get_response` using `predict_proba`.""" + X, y = load_iris(return_X_y=True) + X_binary, y_binary = X[:100], y[:100] + + classifier = DecisionTreeClassifier().fit(X_binary, y_binary) + y_proba, pos_label = _get_response( + X_binary, classifier, response_method="predict_proba" + ) + np.testing.assert_allclose(y_proba, classifier.predict_proba(X_binary)[:, 1]) + assert pos_label == 1 + + y_proba, pos_label = _get_response( + X_binary, classifier, response_method="predict_proba", pos_label=0 + ) + np.testing.assert_allclose(y_proba, classifier.predict_proba(X_binary)[:, 0]) + assert pos_label == 0 + + +def test_get_response_decision_function(): + """Check the behaviour of `get_response` using `decision_function`.""" + X, y = load_iris(return_X_y=True) + X_binary, y_binary = X[:100], y[:100] + + classifier = LogisticRegression().fit(X_binary, y_binary) + y_score, pos_label = _get_response( + X_binary, classifier, response_method="decision_function" + ) + np.testing.assert_allclose(y_score, 
classifier.decision_function(X_binary)) + assert pos_label == 1 + + y_score, pos_label = _get_response( + X_binary, classifier, response_method="decision_function", pos_label=0 + ) + np.testing.assert_allclose(y_score, classifier.decision_function(X_binary) * -1) + assert pos_label == 0 From bc6efdaac3f162bd09f21840cdb47929114fad5c Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Tue, 14 Sep 2021 15:20:33 +0200 Subject: [PATCH 3/9] add a proper way to check the warning raised --- setup.cfg | 2 +- sklearn/metrics/_plot/base.py | 4 ++-- sklearn/metrics/_plot/tests/test_base.py | 19 +++++++++++++++++++ 3 files changed, 22 insertions(+), 3 deletions(-) diff --git a/setup.cfg b/setup.cfg index 050045072f428..3150bcb1ef5ad 100644 --- a/setup.cfg +++ b/setup.cfg @@ -13,7 +13,7 @@ addopts = --ignore maint_tools --ignore asv_benchmarks --doctest-modules - --disable-pytest-warnings + # --disable-pytest-warnings --color=yes -rxXs diff --git a/sklearn/metrics/_plot/base.py b/sklearn/metrics/_plot/base.py index 442b833f40345..817a82ee5d7ab 100644 --- a/sklearn/metrics/_plot/base.py +++ b/sklearn/metrics/_plot/base.py @@ -94,8 +94,8 @@ def _get_response(X, estimator, response_method, pos_label=None): y_pred = prediction_method(X) - # Checking that a scalar is contained in a NumPy array will raise a FutureWarning. - # We need to convert it into a list. + # `not in` between a `str` and a NumPy array will raise a FutureWarning; + # thus we convert the array of classes into a Python list. if pos_label is not None and pos_label not in list(estimator.classes_): raise ValueError( "The class provided by 'pos_label' is unknown. 
Got " diff --git a/sklearn/metrics/_plot/tests/test_base.py b/sklearn/metrics/_plot/tests/test_base.py index 2f67d7dd223f4..75dd4edd50f71 100644 --- a/sklearn/metrics/_plot/tests/test_base.py +++ b/sklearn/metrics/_plot/tests/test_base.py @@ -56,6 +56,25 @@ def test_get_response_predict_proba(): assert pos_label == 0 +def test_get_response_warning(): + """Check that we don't raise a FutureWarning issued by NumPy.""" + X, y = load_iris(return_X_y=True) + X_binary, y_binary = X[:100], y[:100] + + classifier = DecisionTreeClassifier().fit(X_binary, y_binary) + with pytest.warns(None) as record: + try: + _get_response( + X_binary, + classifier, + response_method="predict_proba", + pos_label="unknown", + ) + except ValueError: + pass + assert len(record) == 0 + + def test_get_response_decision_function(): """Check the behaviour of `get_response` using `decision_function`.""" X, y = load_iris(return_X_y=True) From cf4c2a4b8ae88981f73c4fd6bce5c84d3f02ab9d Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Tue, 14 Sep 2021 15:22:43 +0200 Subject: [PATCH 4/9] revert hiding warning --- setup.cfg | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index 3150bcb1ef5ad..050045072f428 100644 --- a/setup.cfg +++ b/setup.cfg @@ -13,7 +13,7 @@ addopts = --ignore maint_tools --ignore asv_benchmarks --doctest-modules - # --disable-pytest-warnings + --disable-pytest-warnings --color=yes -rxXs From bbf8413686f7b8579584ab89320015a241e0a19c Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Thu, 16 Sep 2021 11:48:07 +0200 Subject: [PATCH 5/9] address ogrisel comments --- sklearn/metrics/_plot/base.py | 27 ++++++++++++++------------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/sklearn/metrics/_plot/base.py b/sklearn/metrics/_plot/base.py index 817a82ee5d7ab..5a3524f0061ad 100644 --- a/sklearn/metrics/_plot/base.py +++ b/sklearn/metrics/_plot/base.py @@ -1,5 +1,3 @@ -import numpy as np - from ...base import is_classifier @@ 
-101,6 +99,17 @@ def _get_response(X, estimator, response_method, pos_label=None): "The class provided by 'pos_label' is unknown. Got " f"{pos_label} instead of one of {estimator.classes_}" ) + if pos_label is not None: + try: + class_idx = estimator.classes_.tolist().index(pos_label) + except ValueError as e: + raise ValueError( + "The class provided by 'pos_label' is unknown. Got " + f"{pos_label} instead of one of {estimator.classes_}" + ) from e + else: + pos_label = estimator.classes_[1] + class_idx = 1 if y_pred.ndim != 1: # `predict_proba` y_pred_shape = y_pred.shape[1] @@ -109,16 +118,8 @@ def _get_response(X, estimator, response_method, pos_label=None): f"{classification_error} fit on multiclass ({y_pred_shape} classes)" " data" ) - if pos_label is None: - pos_label = estimator.classes_[1] - y_pred = y_pred[:, 1] - else: - class_idx = np.flatnonzero(estimator.classes_ == pos_label)[0] - y_pred = y_pred[:, class_idx] - else: - if pos_label is None: - pos_label = estimator.classes_[1] - elif pos_label == estimator.classes_[0]: - y_pred *= -1 + y_pred = y_pred[:, class_idx] + elif pos_label == estimator.classes_[0]: # `decision_function` + y_pred *= -1 return y_pred, pos_label From 71efd3ce9a039e0f996dc2da956bbe71061377ce Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Thu, 16 Sep 2021 11:50:15 +0200 Subject: [PATCH 6/9] remove useless section --- sklearn/metrics/_plot/base.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/sklearn/metrics/_plot/base.py b/sklearn/metrics/_plot/base.py index 5a3524f0061ad..cff0526625f9a 100644 --- a/sklearn/metrics/_plot/base.py +++ b/sklearn/metrics/_plot/base.py @@ -89,16 +89,7 @@ def _get_response(X, estimator, response_method, pos_label=None): raise ValueError(classification_error) prediction_method = _check_classifier_response_method(estimator, response_method) - y_pred = prediction_method(X) - - # `not in` between a `str` and a NumPy array will raise a FutureWarning; - # thus we convert the array of 
classes into a Python list. - if pos_label is not None and pos_label not in list(estimator.classes_): - raise ValueError( - "The class provided by 'pos_label' is unknown. Got " - f"{pos_label} instead of one of {estimator.classes_}" - ) if pos_label is not None: try: class_idx = estimator.classes_.tolist().index(pos_label) From 823032acdbc3ef02c2a59991b39cdeaf6b7783ef Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Thu, 16 Sep 2021 11:52:22 +0200 Subject: [PATCH 7/9] nitpick --- sklearn/metrics/_plot/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/metrics/_plot/base.py b/sklearn/metrics/_plot/base.py index cff0526625f9a..82834e8833ca7 100644 --- a/sklearn/metrics/_plot/base.py +++ b/sklearn/metrics/_plot/base.py @@ -99,8 +99,8 @@ def _get_response(X, estimator, response_method, pos_label=None): f"{pos_label} instead of one of {estimator.classes_}" ) from e else: - pos_label = estimator.classes_[1] class_idx = 1 + pos_label = estimator.classes_[class_idx] if y_pred.ndim != 1: # `predict_proba` y_pred_shape = y_pred.shape[1] From 0294154d16da2f161bab4697b15f8026db8ac596 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Fri, 17 Sep 2021 11:12:25 +0200 Subject: [PATCH 8/9] Apply suggestions from code review Co-authored-by: Olivier Grisel --- sklearn/metrics/_plot/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/metrics/_plot/base.py b/sklearn/metrics/_plot/base.py index 82834e8833ca7..60377e3b10f66 100644 --- a/sklearn/metrics/_plot/base.py +++ b/sklearn/metrics/_plot/base.py @@ -96,7 +96,7 @@ def _get_response(X, estimator, response_method, pos_label=None): except ValueError as e: raise ValueError( "The class provided by 'pos_label' is unknown. 
Got " - f"{pos_label} instead of one of {estimator.classes_}" + f"{pos_label} instead of one of {set(estimator.classes_)}" ) from e else: class_idx = 1 From e1c8d33f0bfdfecffc308abe6d0b15bb37848481 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Fri, 17 Sep 2021 11:35:34 +0200 Subject: [PATCH 9/9] iter --- sklearn/metrics/_plot/tests/test_base.py | 19 ------------------- 1 file changed, 19 deletions(-) diff --git a/sklearn/metrics/_plot/tests/test_base.py b/sklearn/metrics/_plot/tests/test_base.py index 75dd4edd50f71..2f67d7dd223f4 100644 --- a/sklearn/metrics/_plot/tests/test_base.py +++ b/sklearn/metrics/_plot/tests/test_base.py @@ -56,25 +56,6 @@ def test_get_response_predict_proba(): assert pos_label == 0 -def test_get_response_warning(): - """Check that we don't raise a FutureWarning issued by NumPy.""" - X, y = load_iris(return_X_y=True) - X_binary, y_binary = X[:100], y[:100] - - classifier = DecisionTreeClassifier().fit(X_binary, y_binary) - with pytest.warns(None) as record: - try: - _get_response( - X_binary, - classifier, - response_method="predict_proba", - pos_label="unknown", - ) - except ValueError: - pass - assert len(record) == 0 - - def test_get_response_decision_function(): """Check the behaviour of `get_response` using `decision_function`.""" X, y = load_iris(return_X_y=True)