FEA add binary_classification_curve by SuccessMoses · Pull Request #30134 · scikit-learn/scikit-learn · GitHub

FEA add binary_classification_curve #30134


Open

SuccessMoses wants to merge 27 commits into main from SuccessMoses:feature

Changes from all commits · 27 commits
75a8512
Changed _binary_clf_curve to binary_clf_curve
SuccessMoses Oct 22, 2024
8b26c82
Changed binary_clf_curve to binary_classification_curve
SuccessMoses Oct 22, 2024
4e2b276
DOC Added examples for binary_classification_curve
SuccessMoses Oct 22, 2024
ad7ff13
Merge branch 'main' into feature
SuccessMoses Oct 22, 2024
97d3a92
Reformatted with black
SuccessMoses Oct 22, 2024
8f8c41c
Merge branch 'scikit-learn:main' into feature
SuccessMoses Oct 23, 2024
48c80cc
Merge branch 'feature' of https://github.com/SuccessMoses/scikit-lear…
SuccessMoses Oct 23, 2024
c6079b7
update documentation
SuccessMoses Nov 6, 2024
c37f479
update documentation
SuccessMoses Nov 6, 2024
bba7958
update documentation
SuccessMoses Nov 6, 2024
761221f
Merge branch 'main' into feature
SuccessMoses Nov 6, 2024
8c89cbe
add new api to api_reference
SuccessMoses Nov 6, 2024
f4be0b0
add new api to __init__.py
SuccessMoses Nov 6, 2024
50f1a01
add validate_parameters
SuccessMoses Nov 6, 2024
fbf0172
add changelog
SuccessMoses Nov 6, 2024
4477d6d
update changelog
SuccessMoses Nov 7, 2024
0d7ff48
fix doctest error
SuccessMoses Nov 7, 2024
ac58b10
Merge branch 'feature' of https://github.com/SuccessMoses/scikit-lear…
SuccessMoses Nov 7, 2024
26b5ab9
add -
SuccessMoses Nov 8, 2024
47baa3f
Merge branch 'main' into feature
SuccessMoses Nov 8, 2024
5b40023
fix docstring
SuccessMoses Nov 8, 2024
2bb2d4b
Merge branch 'feature' of https://github.com/SuccessMoses/scikit-lear…
SuccessMoses Nov 8, 2024
f9105e2
fix docstring
SuccessMoses Nov 8, 2024
06228bf
update changelog message
SuccessMoses Nov 16, 2024
3fd686d
Improve documentation for binary_classification_curve
SuccessMoses Nov 19, 2024
3b864be
fix doc
SuccessMoses Nov 19, 2024
3094eca
fix CI
SuccessMoses Nov 19, 2024
1 change: 1 addition & 0 deletions doc/api_reference.py
@@ -727,6 +727,7 @@ def _get_submodule(module_name, submodule_name):
"auc",
"average_precision_score",
"balanced_accuracy_score",
"binary_classification_curve",
"brier_score_loss",
"class_likelihood_ratios",
"classification_report",
23 changes: 23 additions & 0 deletions doc/modules/model_evaluation.rst
@@ -339,6 +339,7 @@ Some of these are restricted to the binary classification case:
roc_curve
class_likelihood_ratios
det_curve
binary_classification_curve


Others also work in the multiclass case:
@@ -674,6 +675,28 @@ false negatives and true positives as follows::
>>> tn, fp, fn, tp
(2, 1, 2, 3)

With :func:`binary_classification_curve` we can obtain the counts of true positives
and false positives for every decision threshold, and derive the counts of true
negatives and false negatives from them::

>>> import numpy as np
>>> from sklearn.metrics import binary_classification_curve
>>> y_true = np.array([0., 0., 1., 1.])
>>> y_score = np.array([0.1, 0.4, 0.35, 0.8])
>>> fps, tps, thresholds = binary_classification_curve(y_true, y_score)
>>> fps
array([0., 1., 1., 2.])
>>> tps
array([1., 1., 2., 2.])
>>> thresholds
array([0.8 , 0.4 , 0.35, 0.1 ])
>>> # True negatives can be calculated using:
>>> fps[-1] - fps
array([2., 1., 1., 0.])
>>> # False negatives can be calculated using:
>>> tps[-1] - tps
array([1., 1., 0., 0.])
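
From these counts, the rates used by :func:`roc_curve` follow directly, for
instance (a small sketch; ``tpr`` and ``fpr`` are just illustrative local names)::

>>> tpr = tps / tps[-1]
>>> fpr = fps / fps[-1]
>>> tpr
array([0.5, 0.5, 1. , 1. ])
>>> fpr
array([0. , 0.5, 0.5, 1. ])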


.. rubric:: Examples

* See :ref:`sphx_glr_auto_examples_model_selection_plot_confusion_matrix.py`
3 changes: 3 additions & 0 deletions doc/whats_new/upcoming_changes/sklearn.metrics/30134.api.rst
@@ -0,0 +1,3 @@
- Add :func:`metrics.binary_classification_curve`, which returns the number of
true and false positives per threshold.
By :user:`Success Moses <SuccessMoses>`
55 changes: 52 additions & 3 deletions examples/model_selection/plot_confusion_matrix.py
@@ -1,7 +1,7 @@
"""
================
Confusion matrix
================
==============================================================
Evaluate the performance of a classifier with Confusion Matrix
==============================================================

Example of confusion matrix usage to evaluate the quality
of the output of a classifier on the iris data set. The
@@ -69,3 +69,52 @@
print(disp.confusion_matrix)

plt.show()

# %%
# Binary Classification
# =====================
#
# For binary problems, the confusion matrix returned by
# :func:`sklearn.metrics.confusion_matrix` can be raveled to obtain the counts of
# true negatives, false positives, false negatives and true positives.
#
# :func:`sklearn.metrics.binary_classification_curve` counts true positives and
# false positives for every decision threshold (true negatives and false negatives
# follow from these counts). It is the building block of binary classification
# metrics such as :func:`sklearn.metrics.roc_curve` and
# :func:`sklearn.metrics.det_curve`; the sketch after the plot below shows the
# connection.

from sklearn.datasets import make_classification
from sklearn.metrics import binary_classification_curve

X, y = make_classification(
n_samples=100,
n_features=20,
n_informative=20,
n_redundant=0,
n_classes=2,
random_state=42,
)

X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.3, random_state=42
)

classifier = svm.SVC(kernel="linear", C=0.01, probability=True)
classifier.fit(X_train, y_train)

y_score = classifier.predict_proba(X_test)[:, 1]

fps, tps, thresholds = binary_classification_curve(y_test, y_score)

# Plot the counts of true and false positives against the decision thresholds
plt.figure(figsize=(10, 6))

plt.plot(thresholds, tps, label="True Positives (TPs)", color="blue")
plt.plot(thresholds, fps, label="False Positives (FPs)", color="red")
plt.xlabel("Thresholds")
plt.ylabel("Count")
plt.title("TPs and FPs vs Thresholds")
plt.legend()
plt.grid()

plt.show()
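
# %%
# A minimal sketch of the connection to the DET curve mentioned above: dividing the
# counts by the total number of negatives and positives gives the false positive
# and false negative rates returned by :func:`sklearn.metrics.det_curve` (assuming
# ``y_test`` contains both classes; ``fpr`` and ``fnr`` are illustrative names).

fpr = fps / fps[-1]
fnr = (tps[-1] - tps) / tps[-1]

plt.figure(figsize=(6, 6))
plt.plot(fpr, fnr, color="purple")
plt.xlabel("False Positive Rate")
plt.ylabel("False Negative Rate")
plt.title("DET-style curve derived from binary_classification_curve")
plt.show()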
2 changes: 2 additions & 0 deletions sklearn/metrics/__init__.py
@@ -35,6 +35,7 @@
from ._ranking import (
auc,
average_precision_score,
binary_classification_curve,
coverage_error,
dcg_score,
det_curve,
@@ -101,6 +102,7 @@
"auc",
"average_precision_score",
"balanced_accuracy_score",
"binary_classification_curve",
"calinski_harabasz_score",
"check_scoring",
"class_likelihood_ratios",
2 changes: 2 additions & 0 deletions sklearn/metrics/_classification.py
@@ -321,6 +321,8 @@ def confusion_matrix(
ConfusionMatrixDisplay.from_predictions : Plot the confusion matrix
given the true and predicted labels.
ConfusionMatrixDisplay : Confusion Matrix visualization.
binary_classification_curve : Compute true and false positives per binary
classification threshold.

References
----------
53 changes: 48 additions & 5 deletions sklearn/metrics/_ranking.py
@@ -326,6 +326,8 @@ def det_curve(y_true, y_score, pos_label=None, sample_weight=None):
DetCurveDisplay : DET curve visualization.
roc_curve : Compute Receiver operating characteristic (ROC) curve.
precision_recall_curve : Compute precision-recall curve.
binary_classification_curve : Compute true and false positives per binary
classification threshold.

Examples
--------
@@ -341,7 +343,7 @@
>>> thresholds
array([0.35, 0.4 , 0.8 ])
"""
fps, tps, thresholds = _binary_clf_curve(
fps, tps, thresholds = binary_classification_curve(
y_true, y_score, pos_label=pos_label, sample_weight=sample_weight
)

@@ -774,9 +776,22 @@ def _multiclass_roc_auc_score(
)


def _binary_clf_curve(y_true, y_score, pos_label=None, sample_weight=None):
@validate_params(
{
"y_true": ["array-like"],
"y_score": ["array-like"],
"pos_label": [Real, str, "boolean", None],
"sample_weight": ["array-like", None],
},
prefer_skip_nested_validation=True,
)
def binary_classification_curve(y_true, y_score, pos_label=None, sample_weight=None):
"""Calculate true and false positives per binary classification threshold.

Read more in the :ref:`User Guide <confusion_matrix>`.

.. versionadded:: 1.6

Parameters
----------
y_true : ndarray of shape (n_samples,)
@@ -807,6 +822,30 @@ def _binary_clf_curve(y_true, y_score, pos_label=None, sample_weight=None):

thresholds : ndarray of shape (n_thresholds,)
Decreasing score values.

See Also
--------
confusion_matrix : Compute confusion matrix to evaluate the accuracy of a
classification.
roc_curve : Compute Receiver operating characteristic (ROC) curve.
precision_recall_curve : Compute precision-recall curve.
det_curve : Compute Detection error tradeoff (DET) curve.

Examples
--------
>>> import numpy as np
>>> from sklearn.metrics import binary_classification_curve
>>> y_true = np.array([0., 0., 1., 1.])
>>> y_score = np.array([0.1, 0.4, 0.35, 0.8])
>>> fps, tps, thresholds = binary_classification_curve(y_true, y_score)
>>> fps
array([0., 1., 1., 2.])
>>> tps
array([1., 1., 2., 2.])
>>> thresholds
array([0.8 , 0.4 , 0.35, 0.1 ])
"""
# Check to make sure y_true is valid
y_type = type_of_target(y_true, input_name="y_true")
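# Illustrative sketch (not part of this module): the counting logic that
# binary_classification_curve is based on, assuming binary y_true encoded as 0/1
# and ignoring pos_label, sample_weight and input validation.
import numpy as np

def binary_counts_sketch(y_true, y_score):
    # Sort scores in decreasing order and reorder the labels accordingly.
    order = np.argsort(y_score, kind="stable")[::-1]
    y_score = np.asarray(y_score)[order]
    y_true = np.asarray(y_true)[order]

    # Keep the last index of every run of identical scores so that each distinct
    # score value yields exactly one threshold.
    distinct = np.where(np.diff(y_score))[0]
    threshold_idxs = np.r_[distinct, y_true.size - 1]

    # Cumulative number of positives (tps) and negatives (fps) scored at or
    # above each threshold.
    tps = np.cumsum(y_true)[threshold_idxs]
    fps = 1 + threshold_idxs - tps
    return fps, tps, y_score[threshold_idxs]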
@@ -962,6 +1001,8 @@ def precision_recall_curve(
average_precision_score : Compute average precision from prediction scores.
det_curve: Compute error rates for different probability thresholds.
roc_curve : Compute Receiver operating characteristic (ROC) curve.
binary_classification_curve : Compute true and false positives per binary
classification threshold.

Examples
--------
@@ -996,7 +1037,7 @@
)
y_score = probas_pred

fps, tps, thresholds = _binary_clf_curve(
fps, tps, thresholds = binary_classification_curve(
y_true, y_score, pos_label=pos_label, sample_weight=sample_weight
)

@@ -1106,6 +1147,8 @@ def roc_curve(
(ROC) curve given the true and predicted values.
det_curve: Compute error rates for different probability thresholds.
roc_auc_score : Compute the area under the ROC curve.
binary_classification_curve : Compute true and false positives per binary
classification threshold.

Notes
-----
@@ -1139,7 +1182,7 @@
>>> thresholds
array([ inf, 0.8 , 0.4 , 0.35, 0.1 ])
"""
fps, tps, thresholds = _binary_clf_curve(
fps, tps, thresholds = binary_classification_curve(
y_true, y_score, pos_label=pos_label, sample_weight=sample_weight
)

@@ -1149,7 +1192,7 @@
# Here np.diff(_, 2) is used as a "second derivative" to tell if there
# is a corner at the point. Both fps and tps must be tested to handle
# thresholds with multiple data points (which are combined in
# _binary_clf_curve). This keeps all cases where the point should be kept,
# binary_classification_curve). This keeps all cases where the point should be kept,
# but does not drop more complicated cases like fps = [1, 3, 7],
# tps = [1, 2, 4]; there is no harm in keeping too many thresholds.
if drop_intermediate and len(fps) > 2:
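# Illustrative sketch of the corner test described in the comment above (the mask
# mirrors the one computed inside roc_curve; the fps/tps values are made up for
# the example).
import numpy as np

fps = np.array([0, 1, 1, 2, 3])
tps = np.array([1, 1, 2, 2, 2])
# A point is dropped when both cumulative counts are locally linear there, i.e.
# both second differences are zero; the two endpoints are always kept.
optimal_idxs = np.where(
    np.r_[True, np.logical_or(np.diff(fps, 2), np.diff(tps, 2)), True]
)[0]
# optimal_idxs is [0, 1, 2, 4]: the point (fps=2, tps=2) lies on the segment
# between (1, 2) and (3, 2) and is dropped.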
8 changes: 4 additions & 4 deletions sklearn/metrics/tests/test_ranking.py
@@ -839,7 +839,7 @@ def test_auc_score_non_binary_class():


@pytest.mark.parametrize("curve_func", CURVE_FUNCS)
def test_binary_clf_curve_multiclass_error(curve_func):
def test_binary_classification_curve_multiclass_error(curve_func):
rng = check_random_state(404)
y_true = rng.randint(0, 3, size=10)
y_pred = rng.rand(10)
@@ -849,7 +849,7 @@ def test_binary_clf_curve_multiclass_error(curve_func):


@pytest.mark.parametrize("curve_func", CURVE_FUNCS)
def test_binary_clf_curve_implicit_pos_label(curve_func):
def test_binary_classification_curve_implicit_pos_label(curve_func):
# Check that using string class labels raises an informative
# error for any supported string dtype:
msg = (
@@ -877,7 +877,7 @@ def test_binary_clf_curve_implicit_pos_label(curve_func):
@pytest.mark.filterwarnings("ignore:Support for labels represented as bytes")
@pytest.mark.parametrize("curve_func", [precision_recall_curve, roc_curve])
@pytest.mark.parametrize("labels_type", ["list", "array"])
def test_binary_clf_curve_implicit_bytes_pos_label(curve_func, labels_type):
def test_binary_classification_curve_implicit_bytes_pos_label(curve_func, labels_type):
# Check that using bytes class labels raises an informative
# error for any supported string dtype:
labels = _convert_container([b"a", b"b"], labels_type)
@@ -891,7 +891,7 @@ def test_binary_clf_curve_implicit_bytes_pos_label(curve_func, labels_type):


@pytest.mark.parametrize("curve_func", CURVE_FUNCS)
def test_binary_clf_curve_zero_sample_weight(curve_func):
def test_binary_classification_curve_zero_sample_weight(curve_func):
y_true = [0, 0, 1, 1, 1]
y_score = [0.1, 0.2, 0.3, 0.4, 0.5]
sample_weight = [1, 1, 1, 0.5, 0]