|
| 1 | +import numpy as np |
| 2 | +import pytest |
| 3 | + |
| 4 | +from sklearn.datasets import load_iris |
| 5 | +from sklearn.linear_model import LogisticRegression |
| 6 | +from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor |
| 7 | + |
| 8 | +from sklearn.metrics._plot.base import _get_response |
| 9 | + |
| 10 | + |
| 11 | +@pytest.mark.parametrize( |
| 12 | + "estimator, err_msg, params", |
| 13 | + [ |
| 14 | + ( |
| 15 | + DecisionTreeRegressor(), |
| 16 | + "Expected 'estimator' to be a binary classifier", |
| 17 | + {"response_method": "auto"}, |
| 18 | + ), |
| 19 | + ( |
| 20 | + DecisionTreeClassifier(), |
| 21 | + "The class provided by 'pos_label' is unknown.", |
| 22 | + {"response_method": "auto", "pos_label": "unknown"}, |
| 23 | + ), |
| 24 | + ( |
| 25 | + DecisionTreeClassifier(), |
| 26 | + "fit on multiclass", |
| 27 | + {"response_method": "predict_proba"}, |
| 28 | + ), |
| 29 | + ], |
| 30 | +) |
| 31 | +def test_get_response_error(estimator, err_msg, params): |
| 32 | + """Check that we raise the proper error messages in `_get_response`.""" |
| 33 | + X, y = load_iris(return_X_y=True) |
| 34 | + |
| 35 | + estimator.fit(X, y) |
| 36 | + with pytest.raises(ValueError, match=err_msg): |
| 37 | + _get_response(X, estimator, **params) |
| 38 | + |
| 39 | + |
| 40 | +def test_get_response_predict_proba(): |
| 41 | + """Check the behaviour of `_get_response` using `predict_proba`.""" |
| 42 | + X, y = load_iris(return_X_y=True) |
| 43 | + X_binary, y_binary = X[:100], y[:100] |
| 44 | + |
| 45 | + classifier = DecisionTreeClassifier().fit(X_binary, y_binary) |
| 46 | + y_proba, pos_label = _get_response( |
| 47 | + X_binary, classifier, response_method="predict_proba" |
| 48 | + ) |
| 49 | + np.testing.assert_allclose(y_proba, classifier.predict_proba(X_binary)[:, 1]) |
| 50 | + assert pos_label == 1 |
| 51 | + |
| 52 | + y_proba, pos_label = _get_response( |
| 53 | + X_binary, classifier, response_method="predict_proba", pos_label=0 |
| 54 | + ) |
| 55 | + np.testing.assert_allclose(y_proba, classifier.predict_proba(X_binary)[:, 0]) |
| 56 | + assert pos_label == 0 |
| 57 | + |
| 58 | + |
| 59 | +def test_get_response_decision_function(): |
| 60 | + """Check the behaviour of `get_response` using `decision_function`.""" |
| 61 | + X, y = load_iris(return_X_y=True) |
| 62 | + X_binary, y_binary = X[:100], y[:100] |
| 63 | + |
| 64 | + classifier = LogisticRegression().fit(X_binary, y_binary) |
| 65 | + y_score, pos_label = _get_response( |
| 66 | + X_binary, classifier, response_method="decision_function" |
| 67 | + ) |
| 68 | + np.testing.assert_allclose(y_score, classifier.decision_function(X_binary)) |
| 69 | + assert pos_label == 1 |
| 70 | + |
| 71 | + y_score, pos_label = _get_response( |
| 72 | + X_binary, classifier, response_method="decision_function", pos_label=0 |
| 73 | + ) |
| 74 | + np.testing.assert_allclose(y_score, classifier.decision_function(X_binary) * -1) |
| 75 | + assert pos_label == 0 |
0 commit comments