ENH Array API support for confusion_matrix by StefanieSenger · Pull Request #30440 · scikit-learn/scikit-learn


Open · wants to merge 14 commits into base: main
1 change: 1 addition & 0 deletions doc/modules/array_api.rst
@@ -132,6 +132,7 @@ Metrics

- :func:`sklearn.metrics.cluster.entropy`
- :func:`sklearn.metrics.accuracy_score`
- :func:`sklearn.metrics.confusion_matrix`
- :func:`sklearn.metrics.d2_tweedie_score`
- :func:`sklearn.metrics.explained_variance_score`
- :func:`sklearn.metrics.f1_score`
2 changes: 2 additions & 0 deletions doc/whats_new/upcoming_changes/array-api/30440.feature.rst
@@ -0,0 +1,2 @@
- :func:`sklearn.metrics.confusion_matrix` now supports Array API compatible inputs.
by :user:`Stefanie Senger <StefanieSenger>`
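
For context, a minimal usage sketch of what this entry enables, with this PR applied and assuming `array_api_strict` and `array-api-compat` are installed (any array API compatible namespace would work similarly):

import array_api_strict as xp

from sklearn import config_context
from sklearn.metrics import confusion_matrix

y_true = xp.asarray([0, 1, 2, 2])
y_pred = xp.asarray([0, 2, 2, 1])

# With dispatch enabled, confusion_matrix accepts array API inputs
# directly instead of requiring NumPy arrays.
with config_context(array_api_dispatch=True):
    cm = confusion_matrix(y_true, y_pred)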
63 changes: 41 additions & 22 deletions sklearn/metrics/_classification.py
@@ -29,17 +29,21 @@
from ..utils._array_api import (
_average,
_bincount,
_convert_to_numpy,
_count_nonzero,
_find_matching_floating_dtype,
_is_numpy_namespace,
_isin,
_max_precision_float_dtype,
_nan_to_num,
_searchsorted,
_setdiff1d,
_tolist,
_union1d,
device,
get_namespace,
get_namespace_and_device,
size,
)
from ..utils._param_validation import (
Hidden,
@@ -275,7 +279,7 @@ def confusion_matrix(
y_pred : array-like of shape (n_samples,)
Estimated targets as returned by a classifier.

labels : array-like of shape (n_classes), default=None
labels : array-like of shape (n_classes,), default=None
List of labels to index the matrix. This may be used to reorder
or select a subset of labels.
If ``None`` is given, those that appear at least once
@@ -293,7 +297,7 @@

Returns
-------
C : ndarray of shape (n_classes, n_classes)
C : array of shape (n_classes, n_classes)
Confusion matrix whose i-th row and j-th
column entry indicates the number of
samples with true label being i-th class
@@ -338,62 +342,77 @@ def confusion_matrix(
(np.int64(0), np.int64(2), np.int64(1), np.int64(1))
"""
y_true, y_pred = attach_unique(y_true, y_pred)
xp, _, device_ = get_namespace_and_device(y_true, y_pred, labels, sample_weight)
y_type, y_true, y_pred = _check_targets(y_true, y_pred)
if y_type not in ("binary", "multiclass"):
raise ValueError("%s is not supported" % y_type)

if labels is None:
labels = unique_labels(y_true, y_pred)
else:
labels = np.asarray(labels)
labels = xp.asarray(labels)
n_labels = labels.size
if n_labels == 0:
raise ValueError("'labels' should contains at least one label.")
raise ValueError("'labels' should contain at least one label.")
elif y_true.size == 0:
return np.zeros((n_labels, n_labels), dtype=int)
elif len(np.intersect1d(y_true, labels)) == 0:
return xp.zeros((n_labels, n_labels), dtype=xp.int64, device=device_)
elif not _isin(labels, y_true, xp=xp).any():
raise ValueError("At least one label specified must be in y_true")

if sample_weight is None:
sample_weight = np.ones(y_true.shape[0], dtype=np.int64)
sample_weight = xp.ones(y_true.shape[0], dtype=xp.int64, device=device_)
else:
sample_weight = np.asarray(sample_weight)
sample_weight = xp.asarray(sample_weight, device=device_)

check_consistent_length(y_true, y_pred, sample_weight)

n_labels = labels.size
n_labels = size(labels)
# If labels are not consecutive integers starting from zero, then
# y_true and y_pred must be converted into index form
need_index_conversion = not (
labels.dtype.kind in {"i", "u", "b"}
and np.all(labels == np.arange(n_labels))
and y_true.min() >= 0
and y_pred.min() >= 0
xp.isdtype(labels.dtype, ("signed integer", "unsigned integer", "bool"))
and xp.all(labels == xp.arange(n_labels, device=device_))
and xp.min(y_true) >= 0
and xp.min(y_pred) >= 0
)
if need_index_conversion:
label_to_ind = {y: x for x, y in enumerate(labels)}
y_pred = np.array([label_to_ind.get(x, n_labels + 1) for x in y_pred])
y_true = np.array([label_to_ind.get(x, n_labels + 1) for x in y_true])
# convert 0D array into scalar type, see https://github.com/data-apis/array-api-strict/issues/109:
if xp.isdtype(labels.dtype, ("real floating")):
scalar_dtype = float
else:
scalar_dtype = str
label_to_ind = {scalar_dtype(entry): idx for idx, entry in enumerate(labels)}
y_pred = xp.asarray(
[label_to_ind.get(scalar_dtype(x), n_labels + 1) for x in y_pred],
device=device_,
)
y_true = xp.asarray(
[label_to_ind.get(scalar_dtype(x), n_labels + 1) for x in y_true],
device=device_,
)

# intersect y_pred, y_true with labels, eliminate items not in labels
ind = np.logical_and(y_pred < n_labels, y_true < n_labels)
if not np.all(ind):
ind = xp.logical_and(y_pred < n_labels, y_true < n_labels)
if not xp.all(ind):
y_pred = y_pred[ind]
y_true = y_true[ind]
# also eliminate weights of eliminated items
sample_weight = sample_weight[ind]

# Choose the accumulator dtype to always have high precision
if sample_weight.dtype.kind in {"i", "u", "b"}:
if xp.isdtype(sample_weight.dtype, ("signed integer", "unsigned integer", "bool")):
dtype = np.int64
else:
dtype = np.float64

cm = coo_matrix(
(sample_weight, (y_true, y_pred)),
(
_convert_to_numpy(sample_weight, xp=xp),
(_convert_to_numpy(y_true, xp=xp), _convert_to_numpy(y_pred, xp=xp)),
),
shape=(n_labels, n_labels),
dtype=dtype,
).toarray()
cm = xp.asarray(cm)
Reviewer comment (Member):

We could move this conversion to the last line (just before the return statement) and this would side-step the problems related to np.errstate-style error control not being part of the array API spec.

Since we do not move the cm back to the input device, using the array namespace for the normalization step below is not expected to yield any computational advantages.

This would also get rid of maintaining a fallback implementation of nan_to_num since this is not (yet?) part of the spec.
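
A rough sketch (hypothetical, not part of this diff) of the reordering described above: keep `cm` as a NumPy array through the normalization step, so `np.errstate` and `np.nan_to_num` apply unchanged, and convert to the input namespace only at the end:

import numpy as np
import array_api_strict as xp  # stand-in for the input namespace

# Toy confusion matrix with an all-zero row to exercise the NaN path.
cm = np.asarray([[2.0, 0.0], [0.0, 0.0]])
normalize = "true"

with np.errstate(all="ignore"):
    if normalize == "true":
        cm = cm / cm.sum(axis=1, keepdims=True)
    elif normalize == "pred":
        cm = cm / cm.sum(axis=0, keepdims=True)
    elif normalize == "all":
        cm = cm / cm.sum()
    cm = np.nan_to_num(cm)  # the all-zero row becomes zeros, not NaNs

cm = xp.asarray(cm)  # single conversion just before returning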


with np.errstate(all="ignore"):
if normalize == "true":
@@ -402,7 +421,7 @@
cm = cm / cm.sum(axis=0, keepdims=True)
elif normalize == "all":
cm = cm / cm.sum()
cm = np.nan_to_num(cm)
cm = _nan_to_num(cm)

if cm.shape == (1, 1):
warnings.warn(
21 changes: 20 additions & 1 deletion sklearn/metrics/tests/test_classification.py
@@ -10,6 +10,7 @@
from scipy.stats import bernoulli

from sklearn import datasets, svm
from sklearn.base import config_context
from sklearn.datasets import make_multilabel_classification
from sklearn.exceptions import UndefinedMetricWarning
from sklearn.metrics import (
@@ -39,8 +40,10 @@
from sklearn.model_selection import cross_val_score
from sklearn.preprocessing import LabelBinarizer, label_binarize
from sklearn.tree import DecisionTreeClassifier
from sklearn.utils._array_api import yield_namespace_device_dtype_combinations
from sklearn.utils._mocking import MockDataFrame
from sklearn.utils._testing import (
_array_api_for_tests,
assert_allclose,
assert_almost_equal,
assert_array_almost_equal,
@@ -1142,7 +1145,7 @@ def test_confusion_matrix_multiclass_subset_labels():
@pytest.mark.parametrize(
"labels, err_msg",
[
([], "'labels' should contains at least one label."),
([], "'labels' should contain at least one label."),
([3, 4], "At least one label specified must be in y_true"),
],
ids=["empty list", "unknown labels"],
@@ -3095,3 +3098,19 @@ def test_d2_log_loss_score_raises():
err = "The labels array needs to contain at least two"
with pytest.raises(ValueError, match=err):
d2_log_loss_score(y_true, y_pred, labels=labels)


@pytest.mark.parametrize(
"array_namespace, device, _", yield_namespace_device_dtype_combinations()
)
def test_confusion_matrix_array_api(array_namespace, device, _):
"""Test that `confusion_matrix` works for all array types if need_index_conversion
evaluates to `True`and that it raises if not at least one label from `y_pred` is in
`y_true`."""
xp = _array_api_for_tests(array_namespace, device)

y_true = xp.asarray([1, 2, 3], device=device)
y_pred = xp.asarray([4, 5, 6], device=device)

with config_context(array_api_dispatch=True):
confusion_matrix(y_true, y_pred)
Reviewer comment (Member):

Maybe add a check about the result to assert that the namespace of the result array is array_namespace.

The device is not necessarily the same, though (because we do not move back to the input device). I think it's fine to keep the result array on a CPU device even if the input arrays are GPU-allocated.

I don't think there is a library-agnostic way to check that a device object is a CPU device:

https://data-apis.org/array-api/latest/design_topics/device_support.html#semantics

Maybe we could check the output as follows:

Suggested change
confusion_matrix(y_true, y_pred)
result = confusion_matrix(y_true, y_pred)
xp_result, device_result = get_namespace_and_device(result)
assert xp_result is xp
# Since the final computation always happens with NumPy / SciPy on
# the CPU, this function is expected to return an array allocated
# on the default device even when it does not match the input array's
# device.
default_device = device(xp.zeros(0))
assert device_result == default_device

If the last assertion does not work for any reason, I think it's fine not to test the result array device.

Reviewer comment (Member):

The fact that we don't move back the data to the input arrays' device can be a bit surprising. Maybe we should document that somewhere, but I am not sure how.

Reviewer comment (@lucyleeow, Feb 7, 2025):

What about a table here-ish? We could also include info on whether support is 'surface' only and we actually do all the computation in NumPy, as we've decided for confusion_matrix. Though I appreciate this isn't binary: in some functions part of the compute will be array API compliant, or will convert to NumPy only for compiled functions, etc.

Reviewer comment (Member):

@lucyleeow you may know this already but my current understanding is that the confusion_matrix PR that is more likely to be merged is #30562.

I am a bit unsure in confusion_matrix about moving back to the original array namespace. I am slightly leaning towards doing it for consistency's sake, even if I am not entirely convinced it is that useful in practice.

Reviewer comment (Member):

Thanks, don't worry I saw that, just continuing the discussion here as it was started here.

4 changes: 4 additions & 0 deletions sklearn/metrics/tests/test_common.py
@@ -2097,6 +2097,10 @@ def check_array_api_metric_pairwise(metric, array_namespace, device, dtype_name)
check_array_api_multiclass_classification_metric,
check_array_api_multilabel_classification_metric,
],
confusion_matrix: [
check_array_api_binary_classification_metric,
check_array_api_multiclass_classification_metric,
],
f1_score: [
check_array_api_binary_classification_metric,
check_array_api_multiclass_classification_metric,
19 changes: 18 additions & 1 deletion sklearn/utils/_array_api.py
@@ -861,7 +861,7 @@ def _ravel(array, xp=None):


def _convert_to_numpy(array, xp):
"""Convert X into a NumPy ndarray on the CPU."""
"""Convert array into a NumPy ndarray on the CPU."""
xp_name = xp.__name__

if xp_name in {"array_api_compat.torch", "torch"}:
@@ -1108,3 +1108,20 @@ def _tolist(array, xp=None):
return array.tolist()
array_np = _convert_to_numpy(array, xp=xp)
return [element.item() for element in array_np]


def _nan_to_num(array, xp=None):
"""Substitutes NaN values of an array with 0 and inf values with the maximum or
minimum numbers available for the dtype respectively; like np.nan_to_num."""
xp, _ = get_namespace(array, xp=xp)
try:
array = xp.nan_to_num(array)
Reviewer comment (Member):

Suggested change
array = xp.nan_to_num(array)
# nan_to_num is not part of the array API spec at this time but is generally
# consistently well adopted, so we anticipate a future inclusion in the spec:
# https://github.com/data-apis/array-api/issues/878
array = xp.nan_to_num(array)

except AttributeError: # currently catching exceptions from array_api_strict
array[xp.isnan(array)] = 0
if xp.isdtype(array.dtype, "real floating"):
array[xp.isinf(array) & (array > 0)] = xp.finfo(array.dtype).max
array[xp.isinf(array) & (array < 0)] = xp.finfo(array.dtype).min
else: # xp.isdtype(array.dtype, "integral")
array[xp.isinf(array) & (array > 0)] = xp.iinfo(array.dtype).max
array[xp.isinf(array) & (array < 0)] = xp.iinfo(array.dtype).min
Reviewer comment (Member) on lines +1121 to +1126:

Since xp.isinf(array) is always called twice in a row, let's reuse the result of the first call.

Suggested change
if xp.isdtype(array.dtype, "real floating"):
array[xp.isinf(array) & (array > 0)] = xp.finfo(array.dtype).max
array[xp.isinf(array) & (array < 0)] = xp.finfo(array.dtype).min
else: # xp.isdtype(array.dtype, "integral")
array[xp.isinf(array) & (array > 0)] = xp.iinfo(array.dtype).max
array[xp.isinf(array) & (array < 0)] = xp.iinfo(array.dtype).min
isinf_mask = xp.isinf(array)
if xp.isdtype(array.dtype, "real floating"):
array[isinf_mask & (array > 0)] = xp.finfo(array.dtype).max
array[isinf_mask & (array < 0)] = xp.finfo(array.dtype).min
else: # xp.isdtype(array.dtype, "integral")
array[isinf_mask & (array > 0)] = xp.iinfo(array.dtype).max
array[isinf_mask & (array < 0)] = xp.iinfo(array.dtype).min

return array
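
For illustration, a sketch of the helper's intended semantics on this branch using a plain NumPy input (the private `_nan_to_num` is only available with this PR applied):

import numpy as np

from sklearn.utils._array_api import _nan_to_num

arr = np.asarray([np.nan, np.inf, -np.inf, 1.5])
print(_nan_to_num(arr))
# NaN -> 0.0, +inf -> float64 max (~1.8e308), -inf -> float64 min, 1.5 kept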