8000 ENH: Add Array API support to hamming_loss (#30838) · scikit-learn/scikit-learn@8f167d2 · GitHub
[go: up one dir, main page]

Skip to content

Commit 8f167d2

Browse files
lithomas1virchanOmarManzoor
authored
ENH: Add Array API support to hamming_loss (#30838)
Co-authored-by: Virgil Chan <virchan.math@gmail.com> Co-authored-by: Omar Salman <omar.salman2007@gmail.com>
1 parent 774316c commit 8f167d2

File tree

4 files changed

+18
-11
lines changed

4 files changed

+18
-11
lines changed

doc/modules/array_api.rst

+1
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,7 @@ Metrics
136136
- :func:`sklearn.metrics.explained_variance_score`
137137
- :func:`sklearn.metrics.f1_score`
138138
- :func:`sklearn.metrics.fbeta_score`
139+
- :func:`sklearn.metrics.hamming_loss`
139140
- :func:`sklearn.metrics.max_error`
140141
- :func:`sklearn.metrics.mean_absolute_error`
141142
- :func:`sklearn.metrics.mean_absolute_percentage_error`
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
- :func:`sklearn.metrics.hamming_loss` now support Array API compatible inputs.
2+
By :user:`Thomas Li <lithomas1>`

sklearn/metrics/_classification.py

+10-11
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,6 @@
5151
from ..utils._unique import attach_unique
5252
from ..utils.extmath import _nanaverage
5353
from ..utils.multiclass import type_of_target, unique_labels
54-
from ..utils.sparsefuncs import count_nonzero
5554
from ..utils.validation import (
5655
_check_pos_label_consistency,
5756
_check_sample_weight,
@@ -229,12 +228,7 @@ def accuracy_score(y_true, y_pred, *, normalize=True, sample_weight=None):
229228
check_consistent_length(y_true, y_pred, sample_weight)
230229

231230
if y_type.startswith("multilabel"):
232-
if _is_numpy_namespace(xp):
233-
differing_labels = count_nonzero(y_true - y_pred, axis=1)
234-
else:
235-
differing_labels = _count_nonzero(
236-
y_true - y_pred, xp=xp, device=device, axis=1
237-
)
231+
differing_labels = _count_nonzero(y_true - y_pred, xp=xp, device=device, axis=1)
238232
score = xp.asarray(differing_labels == 0, device=device)
239233
else:
240234
score = y_true == y_pred
@@ -2997,15 +2991,20 @@ def hamming_loss(y_true, y_pred, *, sample_weight=None):
29972991
y_type, y_true, y_pred = _check_targets(y_true, y_pred)
29982992
check_consistent_length(y_true, y_pred, sample_weight)
29992993

2994+
xp, _, device = get_namespace_and_device(y_true, y_pred, sample_weight)
2995+
30002996
if sample_weight is None:
30012997
weight_average = 1.0
30022998
else:
3003-
weight_average = np.mean(sample_weight)
2999+
sample_weight = xp.asarray(sample_weight, device=device)
3000+
weight_average = _average(sample_weight, xp=xp)
30043001

30053002
if y_type.startswith("multilabel"):
3006-
n_differences = count_nonzero(y_true - y_pred, sample_weight=sample_weight)
3007-
return float(
3008-
n_differences / (y_true.shape[0] * y_true.shape[1] * weight_average)
3003+
n_differences = _count_nonzero(
3004+
y_true - y_pred, xp=xp, device=device, sample_weight=sample_weight
3005+
)
3006+
return float(n_differences) / (
3007+
y_true.shape[0] * y_true.shape[1] * weight_average
30093008
)
30103009

30113010
elif y_type in ["binary", "multiclass"]:

sklearn/metrics/tests/test_common.py

+5
Original file line numberDiff line numberDiff line change
@@ -2139,6 +2139,11 @@ def check_array_api_metric_pairwise(metric, array_namespace, device, dtype_name)
21392139
check_array_api_multiclass_classification_metric,
21402140
check_array_api_multilabel_classification_metric,
21412141
],
2142+
hamming_loss: [
2143+
check_array_api_binary_classification_metric,
2144+
check_array_api_multiclass_classification_metric,
2145+
check_array_api_multilabel_classification_metric,
2146+
],
21422147
mean_tweedie_deviance: [check_array_api_regression_metric],
21432148
partial(mean_tweedie_deviance, power=-0.5): [check_array_api_regression_metric],
21442149
partial(mean_tweedie_deviance, power=1.5): [check_array_api_regression_metric],

0 commit comments

Comments
 (0)
0