-
-
Notifications
You must be signed in to change notification settings - Fork 25.8k
ENH Array API support for confusion_matrix #30440
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
78d2a65
770e638
af440ca
b45646e
3db7054
abab5ea
abc3981
09cec5d
fdb25f6
49f75b7
914bb63
a939c80
3b2eb04
78b9612
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
- :func:`sklearn.metrics.confusion_matrix` now supports Array API compatible inputs. | ||
by :user:`Stefanie Senger <StefanieSenger>` |
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -10,6 +10,7 @@ | |||||||||||||||||||||||
from scipy.stats import bernoulli | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
from sklearn import datasets, svm | ||||||||||||||||||||||||
from sklearn.base import config_context | ||||||||||||||||||||||||
from sklearn.datasets import make_multilabel_classification | ||||||||||||||||||||||||
from sklearn.exceptions import UndefinedMetricWarning | ||||||||||||||||||||||||
from sklearn.metrics import ( | ||||||||||||||||||||||||
|
@@ -39,8 +40,10 @@ | |||||||||||||||||||||||
from sklearn.model_selection import cross_val_score | ||||||||||||||||||||||||
from sklearn.preprocessing import LabelBinarizer, label_binarize | ||||||||||||||||||||||||
from sklearn.tree import DecisionTreeClassifier | ||||||||||||||||||||||||
from sklearn.utils._array_api import yield_namespace_device_dtype_combinations | ||||||||||||||||||||||||
from sklearn.utils._mocking import MockDataFrame | ||||||||||||||||||||||||
from sklearn.utils._testing import ( | ||||||||||||||||||||||||
_array_api_for_tests, | ||||||||||||||||||||||||
assert_allclose, | ||||||||||||||||||||||||
assert_almost_equal, | ||||||||||||||||||||||||
assert_array_almost_equal, | ||||||||||||||||||||||||
|
@@ -1142,7 +1145,7 @@ def test_confusion_matrix_multiclass_subset_labels(): | |||||||||||||||||||||||
@pytest.mark.parametrize( | ||||||||||||||||||||||||
"labels, err_msg", | ||||||||||||||||||||||||
[ | ||||||||||||||||||||||||
([], "'labels' should contains at least one label."), | ||||||||||||||||||||||||
([], "'labels' should contain at least one label."), | ||||||||||||||||||||||||
([3, 4], "At least one label specified must be in y_true"), | ||||||||||||||||||||||||
], | ||||||||||||||||||||||||
ids=["empty list", "unknown labels"], | ||||||||||||||||||||||||
|
@@ -3095,3 +3098,19 @@ def test_d2_log_loss_score_raises(): | |||||||||||||||||||||||
err = "The labels array needs to contain at least two" | ||||||||||||||||||||||||
with pytest.raises(ValueError, match=err): | ||||||||||||||||||||||||
d2_log_loss_score(y_true, y_pred, labels=labels) | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
|
||||||||||||||||||||||||
@pytest.mark.parametrize( | ||||||||||||||||||||||||
"array_namespace, device, _", yield_namespace_device_dtype_combinations() | ||||||||||||||||||||||||
) | ||||||||||||||||||||||||
def test_confusion_matrix_array_api(array_namespace, device, _): | ||||||||||||||||||||||||
"""Test that `confusion_matrix` works for all array types if need_index_conversion | ||||||||||||||||||||||||
evaluates to `True`and that it raises if not at least one label from `y_pred` is in | ||||||||||||||||||||||||
`y_true`.""" | ||||||||||||||||||||||||
xp = _array_api_for_tests(array_namespace, device) | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
y_true = xp.asarray([1, 2, 3], device=device) | ||||||||||||||||||||||||
y_pred = xp.asarray([4, 5, 6], device=device) | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
with config_context(array_api_dispatch=True): | ||||||||||||||||||||||||
confusion_matrix(y_true, y_pred) | ||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe add a check about the result to assert that the namespace of the result array is The device is not necessarily the same, though (because we do not move back to the input device). I think it's fine to keep the result array on a CPU device even if the input arrays are GPU-allocated. I don't think there is a library-agnostic way to check that a device object is a CPU device: https://data-apis.org/array-api/latest/design_topics/device_support.html#semantics Maybe we could check that the output
Suggested change
If the last assertion does not work for any reason, I think it's fine not to test the result array device. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The fact that we don't move back the data to the input arrays' device can be a bit surprising. Maybe we should document that somewhere, but I am not sure how. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What about a table here-ish - we could also include info on whether support is 'surface' only and we actually do all computation in numpy, like what we've decided for There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @lucyleeow you may know this already but my current understanding is that the I am a bit unsure in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks, don't worry I saw that, just continuing the discussion here as it was started here. |
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -861,7 +861,7 @@ def _ravel(array, xp=None): | |||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
def _convert_to_numpy(array, xp): | ||||||||||||||||||||||||||||
"""Convert X into a NumPy ndarray on the CPU.""" | ||||||||||||||||||||||||||||
"""Convert array into a NumPy ndarray on the CPU.""" | ||||||||||||||||||||||||||||
xp_name = xp.__name__ | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
if xp_name in {"array_api_compat.torch", "torch"}: | ||||||||||||||||||||||||||||
|
@@ -1108,3 +1108,20 @@ def _tolist(array, xp=None): | |||||||||||||||||||||||||||
return array.tolist() | ||||||||||||||||||||||||||||
array_np = _convert_to_numpy(array, xp=xp) | ||||||||||||||||||||||||||||
return [element.item() for element in array_np] | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
def _nan_to_num(array, xp=None): | ||||||||||||||||||||||||||||
"""Substitutes NaN values of an array with 0 and inf values with the maximum or | ||||||||||||||||||||||||||||
minimum numbers available for the dtype respectively; like np.nan_to_num.""" | ||||||||||||||||||||||||||||
xp, _ = get_namespace(array, xp=xp) | ||||||||||||||||||||||||||||
try: | ||||||||||||||||||||||||||||
array = xp.nan_to_num(array) | ||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||||
except AttributeError: # currently catching exceptions from array_api_strict | ||||||||||||||||||||||||||||
array[xp.isnan(array)] = 0 | ||||||||||||||||||||||||||||
if xp.isdtype(array.dtype, "real floating"): | ||||||||||||||||||||||||||||
< 9D88 /td> | array[xp.isinf(array) & (array > 0)] = xp.finfo(array.dtype).max | |||||||||||||||||||||||||||
array[xp.isinf(array) & (array < 0)] = xp.finfo(array.dtype).min | ||||||||||||||||||||||||||||
else: # xp.isdtype(array.dtype, "integral") | ||||||||||||||||||||||||||||
array[xp.isinf(array) & (array > 0)] = xp.iinfo(array.dtype).max | ||||||||||||||||||||||||||||
array[xp.isinf(array) & (array < 0)] = xp.iinfo(array.dtype).min | ||||||||||||||||||||||||||||
Comment on lines
+1121
to
+1126
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since
Suggested change
|
||||||||||||||||||||||||||||
return array |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We could move this conversion to the last line (just before the
return
statement) and this would side-step the problems related tonp.errstate
-style error control not being part of the array API spec.Since we do not move the cm back to the input device, using the array namespace for the normalization step below is not expected to yield any computational advantages.
This would also get rid of maintaining a fallback implementation of
nan_to_num
since this is not (yet?) part of the spec.