10000 FIX Fixes check_array nonfinite checks with ArrayAPI specification (#… · scikit-learn/scikit-learn@c069158 · GitHub
[go: up one dir, main page]

Skip to content

Commit c069158

Browse files
thomasjpfanjeremiedbb
authored andcommitted
FIX Fixes check_array nonfinite checks with ArrayAPI specification (#25619)
* FIX Fixes check_array nonfinite checks with ArrayAPI specification * DOC Adds PR number * FIX Test on both cupy and numpy
1 parent b472ad8 commit c069158

File tree

3 files changed

+26
-2
lines changed

3 files changed

+26
-2
lines changed

doc/whats_new/v1.2.rst

+7
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,13 @@ Changelog
6565
when the global configuration sets `transform_output="pandas"`.
6666
:pr:`25500` by :user:`Guillaume Lemaitre <glemaitre>`.
6767

68+
:mod:`sklearn.utils`
69+
....................
70+
71+
- |Fix| Fixes a bug in :func:`utils.check_array` which now correctly performs
72+
non-finite validation with the Array API specification. :pr:`25619` by
73+
`Thomas Fan`_.
74+
6875
.. _changes_1_2_1:
6976

7077
Version 1.2.1

sklearn/utils/tests/test_validation.py

+17
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import numpy as np
1414
import scipy.sparse as sp
1515

16+
from sklearn._config import config_context
1617
from sklearn.utils._testing import assert_no_warnings
1718
from sklearn.utils._testing import ignore_warnings
1819
from sklearn.utils._testing import SkipTest
@@ -1759,3 +1760,19 @@ def test_boolean_series_remains_boolean():
17591760

17601761
assert res.dtype == expected.dtype
17611762
assert_array_equal(res, expected)
1763+
1764+
1765+
@pytest.mark.parametrize("array_namespace", ["numpy.array_api", "cupy.array_api"])
1766+
def test_check_array_array_api_has_non_finite(array_namespace):
1767+
"""Checks that Array API arrays checks non-finite correctly."""
1768+
xp = pytest.importorskip(array_namespace)
1769+
1770+
X_nan = xp.asarray([[xp.nan, 1, 0], [0, xp.nan, 3]], dtype=xp.float32)
1771+
with config_context(array_api_dispatch=True):
1772+
with pytest.raises(ValueError, match="Input contains NaN."):
1773+
check_array(X_nan)
1774+
1775+
X_inf = xp.asarray([[xp.inf, 1, 0], [0, xp.inf, 3]], dtype=xp.float32)
1776+
with config_context(array_api_dispatch=True):
1777+
with pytest.raises(ValueError, match="infinity or a value too large"):
1778+
check_array(X_inf)

sklearn/utils/validation.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -131,8 +131,8 @@ def _assert_all_finite(
131131
has_nan_error = False if allow_nan else out == FiniteStatus.has_nan
132132
has_inf = out == FiniteStatus.has_infinite
133133
else:
134-
has_inf = np.isinf(X).any()
135-
has_nan_error = False if allow_nan else xp.isnan(X).any()
134+
has_inf = xp.any(xp.isinf(X))
135+
has_nan_error = False if allow_nan else xp.any(xp.isnan(X))
136136
if has_inf or has_nan_error:
137137
if has_nan_error:
138138
type_err = "NaN"

0 commit comments

Comments
 (0)
0