8000 Implement binary/multiclass classification metric - Spherical Payoff by KaikeWesleyReis · Pull Request #18970 · scikit-learn/scikit-learn · GitHub
Implement binary/multiclass classification metric - Spherical Payoff #18970

Open: KaikeWesleyReis wants to merge 3 commits into base: main
1 change: 1 addition & 0 deletions sklearn/metrics/__init__.py
@@ -35,6 +35,7 @@
from ._classification import zero_one_loss
from ._classification import brier_score_loss
from ._classification import multilabel_confusion_matrix
from ._classification import spherical_payoff_score

from . import cluster
from .cluster import adjusted_mutual_info_score
67 changes: 67 additions & 0 deletions sklearn/metrics/_classification.py
@@ -19,6 +19,7 @@
# Saurabh Jha <saurabh.jhaa@gmail.com>
# Bernardo Stein <bernardovstein@gmail.com>
# Shangwu Yao <shangwuyao@gmail.com>
# Kaike Wesley Reis <kaikewesley@hotmail.com>
# License: BSD 3 clause


@@ -2504,3 +2505,69 @@ def brier_score_loss(y_true, y_prob, *, sample_weight=None, pos_label=None):
raise
y_true = np.array(y_true == pos_label, int)
return np.average((y_true - y_prob) ** 2, weights=sample_weight)


def spherical_payoff_score(y_true, y_prob, *, sample_weight=None):
"""Compute the Spherical Payoff.

The Spherical Payoff works for binary and multiclass classification
and measures the model's confidence in predicting the correct category.
The Spherical Payoff has a defined interval: [0, 1]. The best possible
score is 1.0 and the worst is 0.0. It is calculated as the average
over all samples.

Parameters
----------
y_true : array, shape (n_samples,)
True targets.
y_prob : array, shape (n_samples, n_classes)
Predicted probability for each class, with columns ordered by
sorted class label (as returned by ``predict_proba``).
sample_weight : array-like of shape (n_samples,), default=None
Sample weights.

Returns
-------
score : float
Spherical Payoff score.

Examples
--------
>>> import numpy as np
>>> from sklearn.datasets import load_iris
>>> from sklearn.metrics import spherical_payoff_score
>>> from sklearn.ensemble import RandomForestClassifier
>>> from sklearn.model_selection import train_test_split
>>> X, y = load_iris(return_X_y=True)
>>> x_train, x_test, y_train, y_test = train_test_split(X, y, test_size=0.10, random_state=1206)
>>> rfc = RandomForestClassifier(random_state=1206).fit(x_train, y_train)
>>> y_prob = rfc.predict_proba(x_test)
>>> spherical_payoff_score(y_test, y_prob)
0.983...

References
----------
.. [1] `Wikipedia entry for Scoring rule, including spherical payoff
<https://en.wikipedia.org/wiki/Scoring_rule>`_
.. [2] `Netica tutorial for scoring rules
<https://www.norsys.com/tutorials/netica/secD/tut_D2.htm>`_
.. [3] `Guidelines for developing and updating Bayesian belief networks
applied to ecological modelling and conservation
<https://doi.org/10.1139/x06-135>`_
"""
y_true = column_or_1d(y_true)
y_prob = check_array(y_prob)
assert_all_finite(y_true)
assert_all_finite(y_prob)
check_consistent_length(y_true, y_prob, sample_weight)

if y_prob.max() > 1:
raise ValueError("y_prob contains values greater than 1.")
if y_prob.min() < 0:
raise ValueError("y_prob contains values less than 0.")

# Map arbitrary class labels to column indices 0..n_classes-1.
# np.unique returns sorted labels, matching predict_proba's column order;
# np.searchsorted avoids the collisions an in-place relabeling loop can
# produce (e.g. with labels [-1, 0], relabeling -1 -> 0 merges the classes).
categories = np.unique(y_true)
y_true = np.searchsorted(categories, y_true)

# Spherical Payoff: probability of the true class over the L2 norm
# of the full probability vector, averaged across samples.
correct_prob = y_prob[np.arange(y_prob.shape[0]), y_true]
sqrt_all_prob = np.sqrt(np.power(y_prob, 2).sum(axis=1))

return np.average(correct_prob / sqrt_all_prob, weights=sample_weight)
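For reviewers, the per-sample payoff is the probability assigned to the true class divided by the Euclidean norm of the whole probability vector, and the score is the (weighted) mean over samples. A minimal standalone sketch of that computation (the `spherical_payoff` name is hypothetical, for illustration only, not part of this PR):

```python
import numpy as np

def spherical_payoff(y_true, y_prob, sample_weight=None):
    """Mean of p(true class) / ||p||_2 over samples, in [0, 1]."""
    y_true = np.asarray(y_true)
    y_prob = np.asarray(y_prob, dtype=float)
    # Map arbitrary class labels to column indices 0..n_classes-1.
    classes = np.unique(y_true)
    idx = np.searchsorted(classes, y_true)
    correct_prob = y_prob[np.arange(len(y_prob)), idx]
    norms = np.sqrt((y_prob ** 2).sum(axis=1))
    return np.average(correct_prob / norms, weights=sample_weight)

# A perfectly confident, correct model scores 1.0.
perfect = spherical_payoff([0, 1], [[1.0, 0.0], [0.0, 1.0]])
# A maximally uncertain binary model scores 0.5 / sqrt(0.5) ~= 0.707.
uncertain = spherical_payoff([0, 1], [[0.5, 0.5], [0.5, 0.5]])
```

Note that, unlike the Brier score, a uniform prediction does not sit at the midpoint of the range: the norm in the denominator rewards sharp, correct distributions and only drives the score to 0 when the model is confidently wrong.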