8000 ENH Adds isdtype to Array API wrapper by thomasjpfan · Pull Request #26029 · scikit-learn/scikit-learn · GitHub
[go: up one dir, main page]

Skip to content

ENH Adds isdtype to Array API wrapper #26029

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 84 additions & 1 deletion sklearn/utils/_array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,58 @@
import scipy.special as special


def _is_numpy_namespace(xp):
"""Return True if xp is backed by NumPy."""
return xp.__name__ in {"numpy", "numpy.array_api"}


def isdtype(dtype, kind, *, xp):
"""Returns a boolean indicating whether a provided dtype is of type "kind".

Included in the v2022.12 of the Array API spec.
https://data-apis.org/array-api/latest/API_specification/generated/array_api.isdtype.html
"""
if isinstance(kind, tuple):
return any(_isdtype_single(dtype, k, xp=xp) for k in kind)
else:
return _isdtype_single(dtype, kind, xp=xp)


def _isdtype_single(dtype, kind, *, xp):
if isinstance(kind, str):
if kind == "bool":
return dtype == xp.bool
elif kind == "signed integer":
return dtype in {xp.int8, xp.int16, xp.int32, xp.int64}
elif kind == "unsigned integer":
return dtype in {xp.uint8, xp.uint16, xp.uint32, xp.uint64}
elif kind == "integral":
return any(
_isdtype_single(dtype, k, xp=xp)
for k in ("signed integer", "unsigned integer")
)
elif kind == "real floating":
return dtype in {xp.float32, xp.float64}
elif kind == "complex floating":
# Some name spaces do not have complex, such as cupy.array_api
# and numpy.array_api
complex_dtypes = set()
if hasattr(xp, "complex64"):
complex_dtypes.add(xp.complex64)
if hasattr(xp, "complex128"):
complex_dtypes.add(xp.complex128)
return dtype in complex_dtypes
elif kind == "numeric":
return any(
_isdtype_single(dtype, k, xp=xp)
for k in ("integral", "real floating", "complex floating")
)
else:
raise ValueError(f"Unrecognized data type kind: {kind!r}")
else:
return dtype == kind


class _ArrayAPIWrapper:
"""sklearn specific Array API compatibility wrapper

Expand Down Expand Up @@ -48,6 +100,9 @@ def take(self, X, indices, *, axis):
selected = [X[:, i] for i in indices]
return self._namespace.stack(selected, axis=axis)

def isdtype(self, dtype, kind):
return isdtype(dtype, kind, xp=self._namespace)


class _NumPyApiWrapper:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

While we are at it, maybe rename this to _NumPyAPIWrapper to be consistent with the casing of _ArrayAPIWrapper.

Or in a follow-up PR to keep the diff easier to review for the second review.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The renaming makes sense. I opened #26039 to do the name change.

"""Array API compat wrapper for any numpy version
Expand All @@ -60,8 +115,33 @@ class _NumPyApiWrapper:
See the `get_namespace()` public function for more details.
"""

# Data types in spec
# https://data-apis.org/array-api/latest/API_specification/data_types.html
_DTYPES = {
"int8",
"int16",
"int32",
"int64",
"uint8",
"uint16",
"uint32",
"uint64",
"float32",
"float64",
"complex64",
"complex128",
}

def __getattr__(self, name):
return getattr(numpy, name)
attr = getattr(numpy, name)
# Convert to dtype objects
if name in self._DTYPES:
return numpy.dtype(attr)
return attr

@property
def bool(self):
return numpy.bool_

def astype(self, x, dtype, *, copy=True, casting="unsafe"):
# astype is not defined in the top level NumPy namespace
Expand All @@ -86,6 +166,9 @@ def unique_values(self, x):
def concat(self, arrays, *, axis=None):
return numpy.concatenate(arrays, axis=axis)

def isdtype(self, dtype, kind):
return isdtype(dtype, kind, xp=self)


def get_namespace(*arrays):
"""Get namespace of arrays.
Expand Down
2 changes: 1 addition & 1 deletion sklearn/utils/multiclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,7 @@ def type_of_target(y, input_name=""):
suffix = "" # [1, 2, 3] or [[1], [2], [3]]

# Check float and contains non-integer float values
if y.dtype.kind == "f":
if xp.isdtype(y.dtype, "real floating"):
# [.1, .2, 3] or [[.1, .2, 3]] or [[1., .2]] and not [1., 2., 3.]
data = y.data if issparse(y) else y
if xp.any(data != data.astype(int)):
Expand Down
39 changes: 39 additions & 0 deletions sklearn/utils/tests/test_array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,3 +187,42 @@ def test_convert_estimator_to_array_api():

new_est = _estimator_with_converted_arrays(est, lambda array: xp.asarray(array))
assert hasattr(new_est.X_, "__array_namespace__")


@pytest.mark.parametrize("wrapper", [_ArrayAPIWrapper, _NumPyApiWrapper])
def test_get_namespace_array_api_isdtype(wrapper):
"""Test isdtype implementation from _ArrayAPIWrapper and _NumPyApiWrapper."""

if wrapper == _ArrayAPIWrapper:
xp_ = pytest.importorskip("numpy.array_api")
xp = _ArrayAPIWrapper(xp_)
else:
xp = _NumPyApiWrapper()

assert xp.isdtype(xp.float32, xp.float32)
assert xp.isdtype(xp.float32, "real floating")
assert xp.isdtype(xp.float64, "real floating")
assert not xp.isdtype(xp.int32, "real floating")

assert xp.isdtype(xp.bool, "bool")
assert not xp.isdtype(xp.float32, "bool")

assert xp.isdtype(xp.int16, "signed integer")
assert not xp.isdtype(xp.uint32, "signed integer")

assert xp.isdtype(xp.uint16, "unsigned integer")
assert not xp.isdtype(xp.int64, "unsigned integer")

assert xp.isdtype(xp.int64, "numeric")
assert xp.isdtype(xp.float32, "numeric")
assert xp.isdtype(xp.uint32, "numeric")

assert not xp.isdtype(xp.float32, "complex floating")

if wrapper == _NumPyApiWrapper:
assert not xp.isdtype(xp.int8, "complex floating")
assert xp.isdtype(xp.complex64, "complex floating")
assert xp.isdtype(xp.complex128, "complex floating")

with pytest.raises(ValueError, match="Unrecognized data type"):
assert xp.isdtype(xp.int16, "unknown")
13 changes: 9 additions & 4 deletions sklearn/utils/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from ..exceptions import DataConversionWarning
from ..utils._array_api import get_namespace
from ..utils._array_api import _asarray_with_order
from ..utils._array_api import _is_numpy_namespace
from ._isfinite import cy_isfinite, FiniteStatus

FLOAT_DTYPES = (np.float64, np.float32, np.float16)
Expand Down Expand Up @@ -111,7 +112,7 @@ def _assert_all_finite(
raise ValueError("Input contains NaN")

# We need only consider float arrays, hence can early return for all else.
if X.dtype.kind not in "fc":
if not xp.isdtype(X.dtype, ("real floating", "complex floating")):
return

# First try an O(n) time, O(1) space solution for the common case that
Expand Down Expand Up @@ -759,7 +760,7 @@ def check_array(
dtype_numeric = isinstance(dtype, str) and dtype == "numeric"

dtype_orig = getattr(array, "dtype", None)
if not hasattr(dtype_orig, "kind"):
if not is_array_api and not hasattr(dtype_orig, "kind"):
# not a data type (e.g. a column named dtype in a pandas DataFrame)
dtype_orig = None

Expand Down Expand Up @@ -832,6 +833,10 @@ def check_array(
)
)

if dtype is not None and _is_numpy_namespace(xp):
# convert to dtype object to conform to Array API to be use `xp.isdtype` later
dtype = np.dtype(dtype)

estimator_name = _check_estimator_name(estimator)
context = " by %s" % estimator_name if estimator is not None else ""

Expand Down Expand Up @@ -875,12 +880,12 @@ def check_array(
with warnings.catch_warnings():
try:
warnings.simplefilter("error", ComplexWarning)
if dtype is not None and np.dtype(dtype).kind in "iu":
if dtype is not None and xp.isdtype(dtype, "integral"):
# Conversion float -> int should not contain NaN or
# inf (numpy#14412). We cannot use casting='safe' because
# then conversion float -> int would be disallowed.
array = _asarray_with_order(array, order=order, xp=xp)
if array.dtype.kind == "f":
if xp.isdtype(array.dtype, ("real floating", "complex floating")):
_assert_all_finite(
array,
allow_nan=False,
Expand Down
0