|
51 | 51 | from ..utils._unique import attach_unique
|
52 | 52 | from ..utils.extmath import _nanaverage
|
53 | 53 | from ..utils.multiclass import type_of_target, unique_labels
|
54 |
| -from ..utils.sparsefuncs import count_nonzero |
55 | 54 | from ..utils.validation import (
|
56 | 55 | _check_pos_label_consistency,
|
57 | 56 | _check_sample_weight,
|
@@ -229,12 +228,7 @@ def accuracy_score(y_true, y_pred, *, normalize=True, sample_weight=None):
|
229 | 228 | check_consistent_length(y_true, y_pred, sample_weight)
|
230 | 229 |
|
231 | 230 | if y_type.startswith("multilabel"):
|
232 |
| - if _is_numpy_namespace(xp): |
233 |
| - differing_labels = count_nonzero(y_true - y_pred, axis=1) |
234 |
| - else: |
235 |
| - differing_labels = _count_nonzero( |
236 |
| - y_true - y_pred, xp=xp, device=device, axis=1 |
237 |
| - ) |
| 231 | + differing_labels = _count_nonzero(y_true - y_pred, xp=xp, device=device, axis=1) |
238 | 232 | score = xp.asarray(differing_labels == 0, device=device)
|
239 | 233 | else:
|
240 | 234 | score = y_true == y_pred
|
@@ -2997,15 +2991,20 @@ def hamming_loss(y_true, y_pred, *, sample_weight=None):
|
2997 | 2991 | y_type, y_true, y_pred = _check_targets(y_true, y_pred)
|
2998 | 2992 | check_consistent_length(y_true, y_pred, sample_weight)
|
2999 | 2993 |
|
| 2994 | + xp, _, device = get_namespace_and_device(y_true, y_pred, sample_weight) |
| 2995 | + |
3000 | 2996 | if sample_weight is None:
|
3001 | 2997 | weight_average = 1.0
|
3002 | 2998 | else:
|
3003 |
| - weight_average = np.mean(sample_weight) |
| 2999 | + sample_weight = xp.asarray(sample_weight, device=device) |
| 3000 | + weight_average = _average(sample_weight, xp=xp) |
3004 | 3001 |
|
3005 | 3002 | if y_type.startswith("multilabel"):
|
3006 |
| - n_differences = count_nonzero(y_true - y_pred, sample_weight=sample_weight) |
3007 |
| - return float( |
3008 |
| - n_differences / (y_true.shape[0] * y_true.shape[1] * weight_average) |
| 3003 | + n_differences = _count_nonzero( |
| 3004 | + y_true - y_pred, xp=xp, device=device, sample_weight=sample_weight |
| 3005 | + ) |
| 3006 | + return float(n_differences) / ( |
| 3007 | + y_true.shape[0] * y_true.shape[1] * weight_average |
3009 | 3008 | )
|
3010 | 3009 |
|
3011 | 3010 | elif y_type in ["binary", "multiclass"]:
|
|
0 commit comments