From 78d2a657f79b71fa05140b23144e2bbb71f0cb6b Mon Sep 17 00:00:00 2001 From: Stefanie Senger Date: Sat, 7 Dec 2024 23:51:34 +0100 Subject: [PATCH 01/12] ENH Array API for confusion_matrix --- doc/modules/array_api.rst | 1 + sklearn/metrics/_classification.py | 76 +++++++++++++++++++--------- sklearn/metrics/tests/test_common.py | 4 ++ 3 files changed, 56 insertions(+), 25 deletions(-) diff --git a/doc/modules/array_api.rst b/doc/modules/array_api.rst index 82eb64dec08c6..171230d64d12f 100644 --- a/doc/modules/array_api.rst +++ b/doc/modules/array_api.rst @@ -115,6 +115,7 @@ Metrics - :func:`sklearn.metrics.cluster.entropy` - :func:`sklearn.metrics.accuracy_score` +- :func:`sklearn.metrics.confusion_matrix` - :func:`sklearn.metrics.d2_tweedie_score` - :func:`sklearn.metrics.explained_variance_score` - :func:`sklearn.metrics.f1_score` diff --git a/sklearn/metrics/_classification.py b/sklearn/metrics/_classification.py index dc9252c2c9fda..7c0d60fdbc4db 100644 --- a/sklearn/metrics/_classification.py +++ b/sklearn/metrics/_classification.py @@ -292,7 +292,7 @@ def confusion_matrix( Returns ------- - C : ndarray of shape (n_classes, n_classes) + C : array of shape (n_classes, n_classes) Confusion matrix whose i-th row and j-th column entry indicates the number of samples with true label being i-th class @@ -337,6 +337,8 @@ def confusion_matrix( (np.int64(0), np.int64(2), np.int64(1), np.int64(1)) """ y_true, y_pred = attach_unique(y_true, y_pred) + xp, _ = get_namespace(y_true, y_pred, labels, sample_weight) + device_ = device(y_true, y_pred, labels, sample_weight) y_type, y_true, y_pred = _check_targets(y_true, y_pred) if y_type not in ("binary", "multiclass"): raise ValueError("%s is not supported" % y_type) @@ -344,56 +346,70 @@ def confusion_matrix( if labels is None: labels = unique_labels(y_true, y_pred) else: - labels = np.asarray(labels) + labels = xp.asarray(labels) n_labels = labels.size if n_labels == 0: raise ValueError("'labels' should contains at least 
one label.") elif y_true.size == 0: - return np.zeros((n_labels, n_labels), dtype=int) - elif len(np.intersect1d(y_true, labels)) == 0: + return xp.zeros((n_labels, n_labels), dtype=int, device=device_) + # This is not tested other than for numpy; it seems xp.isin is not existing in + # array_api_compat: + elif not xp.isin(labels, y_true).any(): raise ValueError("At least one label specified must be in y_true") if sample_weight is None: - sample_weight = np.ones(y_true.shape[0], dtype=np.int64) + sample_weight = xp.ones(y_true.shape[0], dtype=xp.int64) else: - sample_weight = np.asarray(sample_weight) + sample_weight = xp.asarray(sample_weight) check_consistent_length(y_true, y_pred, sample_weight) - n_labels = labels.size + # TODO: remove condition when torch supports the size attribute + if xp.__name__ == "array_api_compat.torch": + n_labels = xp.size(labels) + else: + n_labels = labels.size # If labels are not consecutive integers starting from zero, then # y_true and y_pred must be converted into index form need_index_conversion = not ( - labels.dtype.kind in {"i", "u", "b"} - and np.all(labels == np.arange(n_labels)) - and y_true.min() >= 0 - and y_pred.min() >= 0 + xp.isdtype(labels.dtype, ("signed integer", "unsigned integer", "bool")) + and xp.all(labels == xp.arange(n_labels, device=device_)) + and xp.min(y_true) >= 0 + and xp.min(y_pred) >= 0 ) if need_index_conversion: label_to_ind = {y: x for x, y in enumerate(labels)} - y_pred = np.array([label_to_ind.get(x, n_labels + 1) for x in y_pred]) - y_true = np.array([label_to_ind.get(x, n_labels + 1) for x in y_true]) + y_pred = xp.array([label_to_ind.get(x, n_labels + 1) for x in y_pred]) + y_true = xp.array([label_to_ind.get(x, n_labels + 1) for x in y_true]) # intersect y_pred, y_true with labels, eliminate items not in labels - ind = np.logical_and(y_pred < n_labels, y_true < n_labels) - if not np.all(ind): + ind = xp.logical_and(y_pred < n_labels, y_true < n_labels) + if not xp.all(ind): y_pred = 
y_pred[ind] y_true = y_true[ind] # also eliminate weights of eliminated items sample_weight = sample_weight[ind] # Choose the accumulator dtype to always have high precision - if sample_weight.dtype.kind in {"i", "u", "b"}: - dtype = np.int64 + if xp.isdtype(sample_weight.dtype, ("signed integer", "unsigned integer", "bool")): + dtype = xp.int64 else: - dtype = np.float64 - - cm = coo_matrix( - (sample_weight, (y_true, y_pred)), - shape=(n_labels, n_labels), - dtype=dtype, - ).toarray() + dtype = xp.float64 + + if _is_numpy_namespace(xp): + cm = coo_matrix( + (sample_weight, (y_true, y_pred)), + shape=(n_labels, n_labels), + dtype=dtype, + ).toarray() + else: + cm = xp.zeros((n_labels, n_labels), dtype=dtype) + # that is probably not very performant? + for true, pred, weight in zip(y_true, y_pred, sample_weight): + cm[true, pred] += weight + # does only numpy warn for divisions by zero or do we have to handle warnings from + # other libraries as well? with np.errstate(all="ignore"): if normalize == "true": cm = cm / cm.sum(axis=1, keepdims=True) @@ -401,7 +417,17 @@ def confusion_matrix( cm = cm / cm.sum(axis=0, keepdims=True) elif normalize == "all": cm = cm / cm.sum() - cm = np.nan_to_num(cm) + + if xp.__name__ == "array_api_strict": + cm[xp.isnan(cm)] = 0 + if isinstance(cm.dtype, float): # type checking not working properly !!!! 
+ cm[xp.isinf(cm) & (cm > 0)] = xp.finfo(cm.dtype).max + cm[xp.isinf(cm) & (cm < 0)] = xp.finfo(cm.dtype).min + elif isinstance(cm.dtype, int): + cm[xp.isinf(cm) & (cm > 0)] = xp.iinfo(cm.dtype).max + cm[xp.isinf(cm) & (cm < 0)] = xp.iinfo(cm.dtype).min + else: + cm = xp.nan_to_num(cm) if cm.shape == (1, 1): warnings.warn( diff --git a/sklearn/metrics/tests/test_common.py b/sklearn/metrics/tests/test_common.py index 0b7a47b0f12da..fcc711f1d3124 100644 --- a/sklearn/metrics/tests/test_common.py +++ b/sklearn/metrics/tests/test_common.py @@ -2061,6 +2061,10 @@ def check_array_api_metric_pairwise(metric, array_namespace, device, dtype_name) check_array_api_multiclass_classification_metric, check_array_api_multilabel_classification_metric, ], + confusion_matrix: [ + check_array_api_binary_classification_metric, + check_array_api_multiclass_classification_metric, + ], f1_score: [ check_array_api_binary_classification_metric, check_array_api_multiclass_classification_metric, From 770e638ec2dda4f090c1a92f80178390b96fde73 Mon Sep 17 00:00:00 2001 From: Stefanie Senger Date: Sun, 8 Dec 2024 09:14:07 +0100 Subject: [PATCH 02/12] fix dtype checking --- sklearn/metrics/_classification.py | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/sklearn/metrics/_classification.py b/sklearn/metrics/_classification.py index 7c0d60fdbc4db..f5966af603188 100644 --- a/sklearn/metrics/_classification.py +++ b/sklearn/metrics/_classification.py @@ -351,16 +351,15 @@ def confusion_matrix( if n_labels == 0: raise ValueError("'labels' should contains at least one label.") elif y_true.size == 0: - return xp.zeros((n_labels, n_labels), dtype=int, device=device_) - # This is not tested other than for numpy; it seems xp.isin is not existing in - # array_api_compat: + return xp.zeros((n_labels, n_labels), dtype=xp.int64, device=device_) + # xp.isin is not existing in array_api_strict; not tested other than for numpy: elif not xp.isin(labels, y_true).any(): 
raise ValueError("At least one label specified must be in y_true") if sample_weight is None: - sample_weight = xp.ones(y_true.shape[0], dtype=xp.int64) + sample_weight = xp.ones(y_true.shape[0], dtype=xp.int64, device=device_) else: - sample_weight = xp.asarray(sample_weight) + sample_weight = xp.asarray(sample_weight, device=device_) check_consistent_length(y_true, y_pred, sample_weight) @@ -378,9 +377,14 @@ def confusion_matrix( and xp.min(y_pred) >= 0 ) if need_index_conversion: + # only tested for numpy so far: label_to_ind = {y: x for x, y in enumerate(labels)} - y_pred = xp.array([label_to_ind.get(x, n_labels + 1) for x in y_pred]) - y_true = xp.array([label_to_ind.get(x, n_labels + 1) for x in y_true]) + y_pred = xp.asarray( + [label_to_ind.get(x, n_labels + 1) for x in y_pred], device=device_ + ) + y_true = xp.asarray( + [label_to_ind.get(x, n_labels + 1) for x in y_true], device=device_ + ) # intersect y_pred, y_true with labels, eliminate items not in labels ind = xp.logical_and(y_pred < n_labels, y_true < n_labels) @@ -403,7 +407,7 @@ def confusion_matrix( dtype=dtype, ).toarray() else: - cm = xp.zeros((n_labels, n_labels), dtype=dtype) + cm = xp.zeros((n_labels, n_labels), dtype=dtype, device=device_) # that is probably not very performant? for true, pred, weight in zip(y_true, y_pred, sample_weight): cm[true, pred] += weight @@ -420,10 +424,10 @@ def confusion_matrix( if xp.__name__ == "array_api_strict": cm[xp.isnan(cm)] = 0 - if isinstance(cm.dtype, float): # type checking not working properly !!!! 
+ if xp.isdtype(cm.dtype, "real floating"): cm[xp.isinf(cm) & (cm > 0)] = xp.finfo(cm.dtype).max cm[xp.isinf(cm) & (cm < 0)] = xp.finfo(cm.dtype).min - elif isinstance(cm.dtype, int): + else: # xp.isdtype(cm.dtype, "integral") cm[xp.isinf(cm) & (cm > 0)] = xp.iinfo(cm.dtype).max cm[xp.isinf(cm) & (cm < 0)] = xp.iinfo(cm.dtype).min else: From af440cab4e1143755edc54788f77695aba332206 Mon Sep 17 00:00:00 2001 From: Stefanie Senger Date: Mon, 9 Dec 2024 10:25:26 +0100 Subject: [PATCH 03/12] prepare for PR --- sklearn/metrics/_classification.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/sklearn/metrics/_classification.py b/sklearn/metrics/_classification.py index f5966af603188..35b776685187c 100644 --- a/sklearn/metrics/_classification.py +++ b/sklearn/metrics/_classification.py @@ -337,8 +337,7 @@ def confusion_matrix( (np.int64(0), np.int64(2), np.int64(1), np.int64(1)) """ y_true, y_pred = attach_unique(y_true, y_pred) - xp, _ = get_namespace(y_true, y_pred, labels, sample_weight) - device_ = device(y_true, y_pred, labels, sample_weight) + xp, _, device_ = get_namespace_and_device(y_true, y_pred, labels, sample_weight) y_type, y_true, y_pred = _check_targets(y_true, y_pred) if y_type not in ("binary", "multiclass"): raise ValueError("%s is not supported" % y_type) @@ -352,7 +351,6 @@ def confusion_matrix( raise ValueError("'labels' should contains at least one label.") elif y_true.size == 0: return xp.zeros((n_labels, n_labels), dtype=xp.int64, device=device_) - # xp.isin is not existing in array_api_strict; not tested other than for numpy: elif not xp.isin(labels, y_true).any(): raise ValueError("At least one label specified must be in y_true") @@ -377,7 +375,6 @@ def confusion_matrix( and xp.min(y_pred) >= 0 ) if need_index_conversion: - # only tested for numpy so far: label_to_ind = {y: x for x, y in enumerate(labels)} y_pred = xp.asarray( [label_to_ind.get(x, n_labels + 1) for x in y_pred], device=device_ @@ -408,12 +405,9 @@ def 
confusion_matrix( ).toarray() else: cm = xp.zeros((n_labels, n_labels), dtype=dtype, device=device_) - # that is probably not very performant? for true, pred, weight in zip(y_true, y_pred, sample_weight): cm[true, pred] += weight - # does only numpy warn for divisions by zero or do we have to handle warnings from - # other libraries as well? with np.errstate(all="ignore"): if normalize == "true": cm = cm / cm.sum(axis=1, keepdims=True) From b45646e39b183134ed3da75213397f880575fbad Mon Sep 17 00:00:00 2001 From: Stefanie Senger Date: Mon, 9 Dec 2024 10:49:51 +0100 Subject: [PATCH 04/12] change log --- doc/whats_new/upcoming_changes/array-api/30440.feature.rst | 2 ++ 1 file changed, 2 insertions(+) create mode 100644 doc/whats_new/upcoming_changes/array-api/30440.feature.rst diff --git a/doc/whats_new/upcoming_changes/array-api/30440.feature.rst b/doc/whats_new/upcoming_changes/array-api/30440.feature.rst new file mode 100644 index 0000000000000..d1f1374f28577 --- /dev/null +++ b/doc/whats_new/upcoming_changes/array-api/30440.feature.rst @@ -0,0 +1,2 @@ +- :func:`sklearn.metrics.confusion_matrix` now supports Array API compatible inputs. 
+ by :user:`Stefanie Senger <StefanieSenger>` From 3db7054e0a66540efd65c7dd5cf9b5977f09f43d Mon Sep 17 00:00:00 2001 From: Stefanie Senger Date: Mon, 9 Dec 2024 11:50:51 +0100 Subject: [PATCH 05/12] use our _isin --- sklearn/metrics/_classification.py | 3 ++- sklearn/utils/_array_api.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/sklearn/metrics/_classification.py b/sklearn/metrics/_classification.py index 35b776685187c..14cb175d2c1a3 100644 --- a/sklearn/metrics/_classification.py +++ b/sklearn/metrics/_classification.py @@ -32,6 +32,7 @@ _count_nonzero, _find_matching_floating_dtype, _is_numpy_namespace, + _isin, _searchsorted, _setdiff1d, _tolist, @@ -351,7 +352,7 @@ def confusion_matrix( raise ValueError("'labels' should contains at least one label.") elif y_true.size == 0: return xp.zeros((n_labels, n_labels), dtype=xp.int64, device=device_) - elif not xp.isin(labels, y_true).any(): + elif not _isin(labels, y_true, xp=xp).any(): raise ValueError("At least one label specified must be in y_true") if sample_weight is None: diff --git a/sklearn/utils/_array_api.py b/sklearn/utils/_array_api.py index b2b4f88fa218f..48e7959d0fb63 100644 --- a/sklearn/utils/_array_api.py +++ b/sklearn/utils/_array_api.py @@ -854,7 +854,7 @@ def _ravel(array, xp=None): def _convert_to_numpy(array, xp): - """Convert X into a NumPy ndarray on the CPU.""" + """Convert array into a NumPy ndarray on the CPU.""" xp_name = xp.__name__ if xp_name in {"array_api_compat.torch", "torch"}: From abab5ead6d77d52d2c17e7820d1503ccf0fa40db Mon Sep 17 00:00:00 2001 From: Stefanie Senger Date: Tue, 10 Dec 2024 14:40:35 +0100 Subject: [PATCH 06/12] changes after review --- sklearn/metrics/_classification.py | 20 ++++---------------- sklearn/utils/_array_api.py | 18 ++++++++++++++++++ 2 files changed, 22 insertions(+), 16 deletions(-) diff --git a/sklearn/metrics/_classification.py b/sklearn/metrics/_classification.py index 14cb175d2c1a3..e81aadee9af12 100644 ---
a/sklearn/metrics/_classification.py +++ b/sklearn/metrics/_classification.py @@ -40,6 +40,7 @@ device, get_namespace, get_namespace_and_device, + size, ) from ..utils._param_validation import ( Hidden, @@ -362,11 +363,7 @@ def confusion_matrix( check_consistent_length(y_true, y_pred, sample_weight) - # TODO: remove condition when torch supports the size attribute - if xp.__name__ == "array_api_compat.torch": - n_labels = xp.size(labels) - else: - n_labels = labels.size + n_labels = size(labels) # If labels are not consecutive integers starting from zero, then # y_true and y_pred must be converted into index form need_index_conversion = not ( @@ -416,17 +413,8 @@ def confusion_matrix( cm = cm / cm.sum(axis=0, keepdims=True) elif normalize == "all": cm = cm / cm.sum() - - if xp.__name__ == "array_api_strict": - cm[xp.isnan(cm)] = 0 - if xp.isdtype(cm.dtype, "real floating"): - cm[xp.isinf(cm) & (cm > 0)] = xp.finfo(cm.dtype).max - cm[xp.isinf(cm) & (cm < 0)] = xp.finfo(cm.dtype).min - else: # xp.isdtype(cm.dtype, "integral") - cm[xp.isinf(cm) & (cm > 0)] = xp.iinfo(cm.dtype).max - cm[xp.isinf(cm) & (cm < 0)] = xp.iinfo(cm.dtype).min - else: - cm = xp.nan_to_num(cm) + # cm = _nan_to_num(cm) + cm = xp.nan_to_num(cm) if cm.shape == (1, 1): warnings.warn( diff --git a/sklearn/utils/_array_api.py b/sklearn/utils/_array_api.py index 48e7959d0fb63..acdfb5e98eb8e 100644 --- a/sklearn/utils/_array_api.py +++ b/sklearn/utils/_array_api.py @@ -1101,3 +1101,21 @@ def _tolist(array, xp=None): return array.tolist() array_np = _convert_to_numpy(array, xp=xp) return [element.item() for element in array_np] + + +def _nan_to_num(array, xp=None): + """Substitutes NaN values with 0 and inf values with the maximum or minimum + numbers available for the dtype respectively; like np.nan_to_num.""" + if xp is None: + xp, _ = get_namespace(array, xp=xp) + try: + array = xp.nan_to_num(array) + except AttributeError: # currently catching exceptions from array_api_strict + 
array[xp.isnan(array)] = 0 + if xp.isdtype(array.dtype, "real floating"): + array[xp.isinf(array) & (array > 0)] = xp.finfo(array.dtype).max + array[xp.isinf(array) & (array < 0)] = xp.finfo(array.dtype).min + else: # xp.isdtype(array.dtype, "integral") + array[xp.isinf(array) & (array > 0)] = xp.iinfo(array.dtype).max + array[xp.isinf(array) & (array < 0)] = xp.iinfo(array.dtype).min + return array From abc39818383897bef2c8c4247b582ada3c865845 Mon Sep 17 00:00:00 2001 From: Stefanie Senger Date: Tue, 10 Dec 2024 14:46:11 +0100 Subject: [PATCH 07/12] forgot to push that before --- sklearn/metrics/_classification.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/metrics/_classification.py b/sklearn/metrics/_classification.py index e81aadee9af12..dcbb3b3db0e72 100644 --- a/sklearn/metrics/_classification.py +++ b/sklearn/metrics/_classification.py @@ -33,6 +33,7 @@ _find_matching_floating_dtype, _is_numpy_namespace, _isin, + _nan_to_num, _searchsorted, _setdiff1d, _tolist, @@ -413,8 +414,7 @@ def confusion_matrix( cm = cm / cm.sum(axis=0, keepdims=True) elif normalize == "all": cm = cm / cm.sum() - # cm = _nan_to_num(cm) - cm = xp.nan_to_num(cm) + cm = _nan_to_num(cm) if cm.shape == (1, 1): warnings.warn( From 09cec5d72b6286530c957f771902ba6953d30fa6 Mon Sep 17 00:00:00 2001 From: Stefanie Senger Date: Wed, 11 Dec 2024 15:42:10 +0100 Subject: [PATCH 08/12] add test --- sklearn/metrics/_classification.py | 4 ++-- sklearn/metrics/tests/test_classification.py | 18 ++++++++++++++++++ sklearn/utils/_array_api.py | 7 +++---- 3 files changed, 23 insertions(+), 6 deletions(-) diff --git a/sklearn/metrics/_classification.py b/sklearn/metrics/_classification.py index dcbb3b3db0e72..5df10234805c2 100644 --- a/sklearn/metrics/_classification.py +++ b/sklearn/metrics/_classification.py @@ -277,7 +277,7 @@ def confusion_matrix( y_pred : array-like of shape (n_samples,) Estimated targets as returned by a classifier. 
- labels : array-like of shape (n_classes), default=None + labels : array-like of shape (n_classes,), default=None List of labels to index the matrix. This may be used to reorder or select a subset of labels. If ``None`` is given, those that appear at least once @@ -374,7 +374,7 @@ def confusion_matrix( and xp.min(y_pred) >= 0 ) if need_index_conversion: - label_to_ind = {y: x for x, y in enumerate(labels)} + label_to_ind = {entry: idx for idx, entry in enumerate(labels)} y_pred = xp.asarray( [label_to_ind.get(x, n_labels + 1) for x in y_pred], device=device_ ) diff --git a/sklearn/metrics/tests/test_classification.py b/sklearn/metrics/tests/test_classification.py index 0e69719da1445..bb245d5b33d71 100644 --- a/sklearn/metrics/tests/test_classification.py +++ b/sklearn/metrics/tests/test_classification.py @@ -10,6 +10,7 @@ from scipy.stats import bernoulli from sklearn import datasets, svm +from sklearn.base import config_context from sklearn.datasets import make_multilabel_classification from sklearn.exceptions import UndefinedMetricWarning from sklearn.metrics import ( @@ -39,8 +40,10 @@ from sklearn.model_selection import cross_val_score from sklearn.preprocessing import LabelBinarizer, label_binarize from sklearn.tree import DecisionTreeClassifier +from sklearn.utils._array_api import yield_namespace_device_dtype_combinations from sklearn.utils._mocking import MockDataFrame from sklearn.utils._testing import ( + _array_api_for_tests, assert_allclose, assert_almost_equal, assert_array_almost_equal, @@ -3095,3 +3098,18 @@ def test_d2_log_loss_score_raises(): err = "The labels array needs to contain at least two" with pytest.raises(ValueError, match=err): d2_log_loss_score(y_true, y_pred, labels=labels) + + +@pytest.mark.parametrize( + "array_namespace, device, _", yield_namespace_device_dtype_combinations() +) +def test_confusion_matrix_array_api(array_namespace, device, _): + """Test that confusion_matrix works for all array types index conversion is done + and 
that it raises if not at least one label from `y_pred` is in `y_true`.""" + xp = _array_api_for_tests(array_namespace, device) + + y_true = xp.asarray([1, 2, 3], device=device) + y_pred = xp.asarray([4, 5, 6], device=device) + + with config_context(array_api_dispatch=True): + confusion_matrix(y_true, y_pred) diff --git a/sklearn/utils/_array_api.py b/sklearn/utils/_array_api.py index acdfb5e98eb8e..3ff5e25f3a97e 100644 --- a/sklearn/utils/_array_api.py +++ b/sklearn/utils/_array_api.py @@ -1104,10 +1104,9 @@ def _tolist(array, xp=None): def _nan_to_num(array, xp=None): - """Substitutes NaN values with 0 and inf values with the maximum or minimum - numbers available for the dtype respectively; like np.nan_to_num.""" - if xp is None: - xp, _ = get_namespace(array, xp=xp) + """Substitutes NaN values of an array with 0 and inf values with the maximum or + minimum numbers available for the dtype respectively; like np.nan_to_num.""" + xp, _ = get_namespace(array, xp=xp) try: array = xp.nan_to_num(array) except AttributeError: # currently catching exceptions from array_api_strict From fdb25f6bfd55fedb52548c6ba8ae946ff220fe45 Mon Sep 17 00:00:00 2001 From: Stefanie Senger Date: Thu, 12 Dec 2024 01:36:20 +0100 Subject: [PATCH 09/12] fix sclar dtype --- sklearn/metrics/_classification.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/sklearn/metrics/_classification.py b/sklearn/metrics/_classification.py index 5df10234805c2..aef25578df9f0 100644 --- a/sklearn/metrics/_classification.py +++ b/sklearn/metrics/_classification.py @@ -374,12 +374,19 @@ def confusion_matrix( and xp.min(y_pred) >= 0 ) if need_index_conversion: - label_to_ind = {entry: idx for idx, entry in enumerate(labels)} + # convert 0D array into scalar type, see https://github.com/data-apis/array-api-strict/issues/109: + if xp.isdtype(labels.dtype, ("real floating")): + scalar_dtype = float + else: + scalar_dtype = str + label_to_ind = {scalar_dtype(entry): idx for idx, entry 
in enumerate(labels)} y_pred = xp.asarray( - [label_to_ind.get(x, n_labels + 1) for x in y_pred], device=device_ + [label_to_ind.get(scalar_dtype(x), n_labels + 1) for x in y_pred], + device=device_, ) y_true = xp.asarray( - [label_to_ind.get(x, n_labels + 1) for x in y_true], device=device_ + [label_to_ind.get(scalar_dtype(x), n_labels + 1) for x in y_true], + device=device_, ) # intersect y_pred, y_true with labels, eliminate items not in labels From 49f75b7d8b997f60e3c4ebdb6e84c525de12d5eb Mon Sep 17 00:00:00 2001 From: Stefanie Senger Date: Thu, 12 Dec 2024 01:48:29 +0100 Subject: [PATCH 10/12] fix typos --- sklearn/metrics/_classification.py | 2 +- sklearn/metrics/tests/test_classification.py | 7 ++++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/sklearn/metrics/_classification.py b/sklearn/metrics/_classification.py index aef25578df9f0..f9caf087cd773 100644 --- a/sklearn/metrics/_classification.py +++ b/sklearn/metrics/_classification.py @@ -351,7 +351,7 @@ def confusion_matrix( labels = xp.asarray(labels) n_labels = labels.size if n_labels == 0: - raise ValueError("'labels' should contains at least one label.") + raise ValueError("'labels' should contain at least one label.") elif y_true.size == 0: return xp.zeros((n_labels, n_labels), dtype=xp.int64, device=device_) elif not _isin(labels, y_true, xp=xp).any(): diff --git a/sklearn/metrics/tests/test_classification.py b/sklearn/metrics/tests/test_classification.py index bb245d5b33d71..a9f3b9a268864 100644 --- a/sklearn/metrics/tests/test_classification.py +++ b/sklearn/metrics/tests/test_classification.py @@ -1145,7 +1145,7 @@ def test_confusion_matrix_multiclass_subset_labels(): @pytest.mark.parametrize( "labels, err_msg", [ - ([], "'labels' should contains at least one label."), + ([], "'labels' should contain at least one label."), ([3, 4], "At least one label specified must be in y_true"), ], ids=["empty list", "unknown labels"], @@ -3104,8 +3104,9 @@ def 
test_d2_log_loss_score_raises(): "array_namespace, device, _", yield_namespace_device_dtype_combinations() ) def test_confusion_matrix_array_api(array_namespace, device, _): - """Test that confusion_matrix works for all array types index conversion is done - and that it raises if not at least one label from `y_pred` is in `y_true`.""" + """Test that `confusion_matrix` works for all array types if need_index_conversion + evaluates to `True` and that it raises if not at least one label from `y_pred` is in + `y_true`.""" xp = _array_api_for_tests(array_namespace, device) y_true = xp.asarray([1, 2, 3], device=device) From 914bb630255991e20bac3063802c863dc1b7a910 Mon Sep 17 00:00:00 2001 From: Stefanie Senger Date: Wed, 18 Dec 2024 12:33:54 +0100 Subject: [PATCH 11/12] convert_to_numpy and coo_matrix instead of python loop --- sklearn/metrics/_classification.py | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/sklearn/metrics/_classification.py b/sklearn/metrics/_classification.py index f9caf087cd773..154e2bae38ee7 100644 --- a/sklearn/metrics/_classification.py +++ b/sklearn/metrics/_classification.py @@ -29,6 +29,7 @@ from ..utils._array_api import ( _average, _bincount, + _convert_to_numpy, _count_nonzero, _find_matching_floating_dtype, _is_numpy_namespace, @@ -399,20 +400,18 @@ def confusion_matrix( # Choose the accumulator dtype to always have high precision if xp.isdtype(sample_weight.dtype, ("signed integer", "unsigned integer", "bool")): - dtype = xp.int64 + dtype = np.int64 else: - dtype = xp.float64 - - if _is_numpy_namespace(xp): - cm = coo_matrix( - (sample_weight, (y_true, y_pred)), - shape=(n_labels, n_labels), - dtype=dtype, - ).toarray() - else: - cm = xp.zeros((n_labels, n_labels), dtype=dtype, device=device_) - for true, pred, weight in zip(y_true, y_pred, sample_weight): - cm[true, pred] += weight + dtype = np.float64 + cm = coo_matrix( + ( + _convert_to_numpy(sample_weight, xp=xp), + 
(_convert_to_numpy(y_true, xp=xp), _convert_to_numpy(y_pred, xp=xp)), + ), + shape=(n_labels, n_labels), + dtype=dtype, + ).toarray() + cm = xp.asarray(cm) with np.errstate(all="ignore"): if normalize == "true": From 78b96128b6673f01a4f58a1933c7af253db95af4 Mon Sep 17 00:00:00 2001 From: Stefanie Senger Date: Fri, 3 Jan 2025 09:56:14 +0100 Subject: [PATCH 12/12] satisfy ruff --- sklearn/metrics/_classification.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/metrics/_classification.py b/sklearn/metrics/_classification.py index 1f050c6af160b..677581f5898d4 100644 --- a/sklearn/metrics/_classification.py +++ b/sklearn/metrics/_classification.py @@ -34,8 +34,8 @@ _find_matching_floating_dtype, _is_numpy_namespace, _isin, - _nan_to_num, _max_precision_float_dtype, + _nan_to_num, _searchsorted, _setdiff1d, _tolist,