[MRG] Improve error message with implicit pos_label in _binary_clf_curve by ogrisel · Pull Request #15562 · scikit-learn/scikit-learn · GitHub
[go: up one dir, main page]

Skip to content

[MRG] Improve error message with implicit pos_label in _binary_clf_curve #15562

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Nov 20, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 16 additions & 7 deletions sklearn/metrics/_ranking.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,14 +525,23 @@ def _binary_clf_curve(y_true, y_score, pos_label=None, sample_weight=None):
sample_weight = column_or_1d(sample_weight)

# ensure binary classification if pos_label is not specified
# classes.dtype.kind in ('O', 'U', 'S') is required to avoid
# triggering a FutureWarning by calling np.array_equal(a, b)
# when elements in the two arrays are not comparable.
classes = np.unique(y_true)
if (pos_label is None and
not (np.array_equal(classes, [0, 1]) or
np.array_equal(classes, [-1, 1]) or
np.array_equal(classes, [0]) or
np.array_equal(classes, [-1]) or
np.array_equal(classes, [1]))):
raise ValueError("Data is not binary and pos_label is not specified")
if (pos_label is None and (
classes.dtype.kind in ('O', 'U', 'S') or
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are you sure that we should include O?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think so: object dtype is the default dtype used by pandas when parsing a CSV file with string columns that typically encode the target variable in a classification problem for instance.

not (np.array_equal(classes, [0, 1]) or
np.array_equal(classes, [-1, 1]) or
np.array_equal(classes, [0]) or
np.array_equal(classes, [-1]) or
np.array_equal(classes, [1])))):
classes_repr = ", ".join(repr(c) for c in classes)
raise ValueError("y_true takes value in {{{classes_repr}}} and "
"pos_label is not specified: either make y_true "
"take integer value in {{0, 1}} or {{-1, 1}} or "
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"integer value" is not correct @ogrisel

"pass pos_label explicitly.".format(
classes_repr=classes_repr))
elif pos_label is None:
pos_label = 1.

Expand Down
45 changes: 42 additions & 3 deletions sklearn/metrics/tests/test_ranking.py
Original file line number Diff line number Diff line change
Expand Up @@ -662,14 +662,53 @@ def test_auc_score_non_binary_class():
roc_auc_score(y_true, y_pred)


def test_binary_clf_curve():
def test_binary_clf_curve_multiclass_error():
    """Check that both curve functions reject a non-binary integer target.

    Labels are drawn from ``randint(0, 3)`` and both
    ``precision_recall_curve`` and ``roc_curve`` are expected to raise a
    ``ValueError`` mentioning that the multiclass format is not supported.
    """
    random_state = check_random_state(404)
    # Draw the labels before the scores so the random stream is consumed
    # in the same order on every run.
    labels = random_state.randint(0, 3, size=10)
    scores = random_state.rand(10)
    expected = "multiclass format is not supported"

    for curve in (precision_recall_curve, roc_curve):
        with pytest.raises(ValueError, match=expected):
            curve(labels, scores)


@pytest.mark.parametrize("curve_func", [
    precision_recall_curve,
    roc_curve,
])
def test_binary_clf_curve_implicit_pos_label(curve_func):
    """Check the error raised when ``pos_label`` cannot be inferred.

    Parameters
    ----------
    curve_func : callable
        The curve function under test, supplied by the parametrization
        (``precision_recall_curve`` or ``roc_curve``). Every call below goes
        through it so that each parametrized case exercises its own function.
    """
    # Local import keeps this block self-contained; the expected messages
    # contain regex metacharacters ('{', '}', '.') and ``match`` is a regex
    # search, so they are escaped to be compared literally.
    import re

    # Check that using string class labels raises an informative
    # error for any supported string dtype:
    msg = ("y_true takes value in {'a', 'b'} and pos_label is "
           "not specified: either make y_true take integer "
           "value in {0, 1} or {-1, 1} or pass pos_label "
           "explicitly.")
    with pytest.raises(ValueError, match=re.escape(msg)):
        curve_func(np.array(["a", "b"], dtype='<U1'), [0., 1.])

    with pytest.raises(ValueError, match=re.escape(msg)):
        curve_func(np.array(["a", "b"], dtype=object), [0., 1.])

    # The error message is slightly different for bytes-encoded
    # class labels, but otherwise the behavior is the same:
    msg = ("y_true takes value in {b'a', b'b'} and pos_label is "
           "not specified: either make y_true take integer "
           "value in {0, 1} or {-1, 1} or pass pos_label "
           "explicitly.")
    with pytest.raises(ValueError, match=re.escape(msg)):
        curve_func(np.array([b"a", b"b"], dtype='<S1'), [0., 1.])

    # Check that it is possible to use floating point class labels
    # that are interpreted similarly to integer class labels:
    y_pred = [0., 1., 0.2, 0.42]
    int_curve = curve_func([0, 1, 1, 0], y_pred)
    float_curve = curve_func([0., 1., 1., 0.], y_pred)
    for int_curve_part, float_curve_part in zip(int_curve, float_curve):
        np.testing.assert_allclose(int_curve_part, float_curve_part)


def test_precision_recall_curve():
y_true, _, probas_pred = make_prediction(binary=True)
Expand Down Expand Up @@ -1077,8 +1116,8 @@ def check_alternative_lrap_implementation(lrap_score, n_classes=5,

# Score with ties
y_score = _sparse_random_matrix(n_components=y_true.shape[0],
n_features=y_true.shape[1],
random_state=random_state)
n_features=y_true.shape[1],
random_state=random_state)

if hasattr(y_score, "toarray"):
y_score = y_score.toarray()
Expand Down
0