10000 ENH Adds isdtype to Array API wrapper (#26029) · Veghit/scikit-learn@744c5f1 · GitHub
[go: up one dir, main page]

Skip to content

Commit 744c5f1

Browse files
thomasjpfan authored and Itay committed
ENH Adds isdtype to Array API wrapper (scikit-learn#26029)
1 parent 4c98fba commit 744c5f1

File tree

4 files changed

+133
-6
lines changed

4 files changed

+133
-6
lines changed

sklearn/utils/_array_api.py

Lines changed: 84 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,58 @@
44
import scipy.special as special
55

66

7+
def _is_numpy_namespace(xp):
    """Tell whether the array namespace ``xp`` is one of NumPy's.

    The decision is based solely on ``xp.__name__``: both the top-level
    ``numpy`` module and ``numpy.array_api`` count as NumPy-backed.
    """
    namespace_name = xp.__name__
    return namespace_name in {"numpy", "numpy.array_api"}
10+
11+
12+
def isdtype(dtype, kind, *, xp):
13+
"""Returns a boolean indicating whether a provided dtype is of type "kind".
14+
15+
Included in the v2022.12 of the Array API spec.
16+
https://data-apis.org/array-api/latest/API_specification/generated/array_api.isdtype.html
17+
"""
18+
if isinstance(kind, tuple):
19+
return any(_isdtype_single(dtype, k, xp=xp) for k in kind)
20+
else:
21+
return _isdtype_single(dtype, kind, xp=xp)
22+
23+
24+
def _isdtype_single(dtype, kind, *, xp):
25+
if isinstance(kind, str):
26+
if kind == "bool":
27+
return dtype == xp.bool
28+
elif kind == "signed integer":
29+
return dtype in {xp.int8, xp.int16, xp.int32, xp.int64}
30+
elif kind == "unsigned integer":
31+
return dtype in {xp.uint8, xp.uint16, xp.uint32, xp.uint64}
32+
elif kind == "integral":
33+
return any(
34+
_isdtype_single(dtype, k, xp=xp)
35+
for k in ("signed integer", "unsigned integer")
36+
)
37+
elif kind == "real floating":
38+
return dtype in {xp.float32, xp.float64}
39+
elif kind == "complex floating":
40+
# Some name spaces do not have complex, such as cupy.array_api
41+
# and numpy.array_api
42+
complex_dtypes = set()
43+
if hasattr(xp, "complex64"):
44+
complex_dtypes.add(xp.complex64)
45+
if hasattr(xp, "complex128"):
46+
complex_dtypes.add(xp.complex128)
47+
return dtype in complex_dtypes
48+
elif kind == "numeric":
49+
return any(
50+
_isdtype_single(dtype, k, xp=xp)
51+
for k in ("integral", "real floating", "complex floating")
52+
)
53+
else:
54+
raise ValueError(f"Unrecognized data type kind: {kind!r}")
55+
else:
56+
return dtype == kind
57+
58+
759
class _ArrayAPIWrapper:
860
"""sklearn specific Array API compatibility wrapper
961
@@ -48,6 +100,9 @@ def take(self, X, indices, *, axis):
48100
selected = [X[:, i] for i in indices]
49101
return self._namespace.stack(selected, axis=axis)
50102

103+
def isdtype(self, dtype, kind):
    """Classify ``dtype`` against ``kind`` using the wrapped namespace.

    Delegates to the module-level ``isdtype`` helper, passing the wrapped
    namespace (``self._namespace``) as ``xp``.
    """
    namespace = self._namespace
    return isdtype(dtype, kind, xp=namespace)
105+
51106

52107
class _NumPyAPIWrapper:
53108
"""Array API compat wrapper for any numpy version
@@ -60,8 +115,33 @@ class _NumPyAPIWrapper:
60115
See the `get_namespace()` public function for more details.
61116
"""
62117

118+
# Names that ``__getattr__`` returns as ``numpy.dtype`` instances.
# All of them except float16 are the data types of the Array API spec:
# https://data-apis.org/array-api/latest/API_specification/data_types.html
# float16 is not part of the spec but NumPy provides it; listing it keeps
# ``xp.float16`` the same kind of object (a dtype instance, hashable and
# comparable like the others) as the remaining floating point dtypes.
_DTYPES = {
    "int8",
    "int16",
    "int32",
    "int64",
    "uint8",
    "uint16",
    "uint32",
    "uint64",
    "float16",
    "float32",
    "float64",
    "complex64",
    "complex128",
}
134+
63135
def __getattr__(self, name):
64-
return getattr(numpy, name)
136+
attr = getattr(numpy, name)
137+
# Convert to dtype objects
138+
if name in self._DTYPES:
139+
return numpy.dtype(attr)
140+
return attr
141+
142+
@property
def bool(self):
    """Boolean data type of this namespace, i.e. ``numpy.bool_``."""
    boolean_type = numpy.bool_
    return boolean_type
65145

66146
def astype(self, x, dtype, *, copy=True, casting="unsafe"):
67147
# astype is not defined in the top level NumPy namespace
@@ -86,6 +166,9 @@ def unique_values(self, x):
86166
def concat(self, arrays, *, axis=None):
    """Join ``arrays`` with ``numpy.concatenate``.

    ``axis`` is forwarded as is; with the default ``None`` NumPy flattens the
    inputs before joining them.
    """
    joined = numpy.concatenate(arrays, axis=axis)
    return joined
88168

169+
def isdtype(self, dtype, kind):
    """Classify ``dtype`` against ``kind`` using this wrapper as namespace.

    Delegates to the module-level ``isdtype`` helper with ``xp=self``, so the
    reference dtypes are resolved through this wrapper's attribute lookup.
    """
    namespace = self
    return isdtype(dtype, kind, xp=namespace)
171+
89172

90173
def get_namespace(*arrays):
91174
"""Get namespace of arrays.

sklearn/utils/multiclass.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -374,7 +374,7 @@ def type_of_target(y, input_name=""):
374374
suffix = "" # [1, 2, 3] or [[1], [2], [3]]
375375

376376
# Check float and contains non-integer float values
377-
if y.dtype.kind == "f":
377+
if xp.isdtype(y.dtype, "real floating"):
378378
# [.1, .2, 3] or [[.1, .2, 3]] or [[1., .2]] and not [1., 2., 3.]
379379
data = y.data if issparse(y) else y
380380
if xp.any(data != data.astype(int)):

sklearn/utils/tests/test_array_api.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,3 +187,42 @@ def test_convert_estimator_to_array_api():
187187

188188
new_est = _estimator_with_converted_arrays(est, lambda array: xp.asarray(array))
189189
assert hasattr(new_est.X_, "__array_namespace__")
190+
191+
192+
@pytest.mark.parametrize("wrapper", [_ArrayAPIWrapper, _NumPyAPIWrapper])
def test_get_namespace_array_api_isdtype(wrapper):
    """Test isdtype implementation from _ArrayAPIWrapper and _NumPyAPIWrapper."""

    # Build the namespace under test; the Array API wrapper needs the optional
    # numpy.array_api module and is skipped when it is unavailable.
    if wrapper == _ArrayAPIWrapper:
        xp_ = pytest.importorskip("numpy.array_api")
        xp = _ArrayAPIWrapper(xp_)
    else:
        xp = _NumPyAPIWrapper()

    # A dtype used as ``kind`` is compared by equality.
    assert xp.isdtype(xp.float32, xp.float32)
    assert xp.isdtype(xp.float32, "real floating")
    assert xp.isdtype(xp.float64, "real floating")
    assert not xp.isdtype(xp.int32, "real floating")

    assert xp.isdtype(xp.bool, "bool")
    assert not xp.isdtype(xp.float32, "bool")

    assert xp.isdtype(xp.int16, "signed integer")
    assert not xp.isdtype(xp.uint32, "signed integer")

    assert xp.isdtype(xp.uint16, "unsigned integer")
    assert not xp.isdtype(xp.int64, "unsigned integer")

    assert xp.isdtype(xp.int64, "numeric")
    assert xp.isdtype(xp.float32, "numeric")
    assert xp.isdtype(xp.uint32, "numeric")

    assert not xp.isdtype(xp.float32, "complex floating")

    # Complex dtypes are only checked for the NumPy wrapper: numpy.array_api
    # does not define them.
    if wrapper == _NumPyAPIWrapper:
        assert not xp.isdtype(xp.int8, "complex floating")
        assert xp.isdtype(xp.complex64, "complex floating")
        assert xp.isdtype(xp.complex128, "complex floating")

    with pytest.raises(ValueError, match="Unrecognized data type"):
        xp.isdtype(xp.int16, "unknown")

sklearn/utils/validation.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from ..exceptions import DataConversionWarning
3232
from ..utils._array_api import get_namespace
3333
from ..utils._array_api import _asarray_with_order
34+
from ..utils._array_api import _is_numpy_namespace
3435
from ._isfinite import cy_isfinite, FiniteStatus
3536

3637
FLOAT_DTYPES = (np.float64, np.float32, np.float16)
@@ -111,7 +112,7 @@ def _assert_all_finite(
111112
raise ValueError("Input contains NaN")
112113

113114
# We need only consider float arrays, hence can early return for all else.
114-
if X.dtype.kind not in "fc":
115+
if not xp.isdtype(X.dtype, ("real floating", "complex floating")):
115116
return
116117

117118
# First try an O(n) time, O(1) space solution for the common case that
@@ -759,7 +760,7 @@ def check_array(
759760
dtype_numeric = isinstance(dtype, str) and dtype == "numeric"
760761

761762
dtype_orig = getattr(array, "dtype", None)
762-
if not hasattr(dtype_orig, "kind"):
763+
if not is_array_api and not hasattr(dtype_orig, "kind"):
763764
# not a data type (e.g. a column named dtype in a pandas DataFrame)
764765
dtype_orig = None
765766

@@ -832,6 +833,10 @@ def check_array(
832833
)
833834
)
834835

836+
if dtype is not None and _is_numpy_namespace(xp):
837+
# convert to a dtype object to conform to the Array API so that `xp.isdtype` can be used later
838+
dtype = np.dtype(dtype)
839+
835840
estimator_name = _check_estimator_name(estimator)
836841
context = " by %s" % estimator_name if estimator is not None else ""
837842

@@ -875,12 +880,12 @@ def check_array(
875880
with warnings.catch_warnings():
876881
try:
877882
warnings.simplefilter("error", ComplexWarning)
878-
if dtype is not None and np.dtype(dtype).kind in "iu":
883+
if dtype is not None and xp.isdtype(dtype, "integral"):
879884
# Conversion float -> int should not contain NaN or
880885
# inf (numpy#14412). We cannot use casting='safe' because
881886
# then conversion float -> int would be disallowed.
882887
array = _asarray_with_order(array, order=order, xp=xp)
883-
if array.dtype.kind == "f":
888+
if xp.isdtype(array.dtype, ("real floating", "complex floating")):
884889
_assert_all_finite(
885890
array,
886891
allow_nan=False,

0 commit comments

Comments
 (0)
0