8000 Add float16 to Numpy wrapper, add float dtype test · scikit-learn/scikit-learn@8e8e965 · GitHub
[go: up one dir, main page]

Skip to content

Commit 8e8e965

Browse files
committed
Add float16 to Numpy wrapper, add float dtype test
1 parent e8e7419 commit 8e8e965

File tree

2 files changed

+14
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,7 @@ class _NumPyAPIWrapper:
226226
"uint16",
227227
"uint32",
228228
"uint64",
229+
"float16",
229230
"float32",
230231
"float64",
231232
"complex64",
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,16 @@
22
import pytest
33
from numpy.testing import assert_allclose, assert_array_equal
44

5+
from sklearn.base import BaseEstimator
6+
from sklearn.utils._array_api import get_namespace
7+
from sklearn.utils._array_api import _NumPyAPIWrapper
8+
from sklearn.utils._array_api import _ArrayAPIWrapper
9+
from sklearn.utils._array_api import _asarray_with_order
10+
from sklearn.utils._array_api import _convert_to_numpy
11+
from sklearn.utils._array_api import _estimator_with_converted_arrays
12+
from sklearn.utils._array_api import supported_float_dtypes
13+
from sklearn.utils._testing import skip_if_array_api_compat_not_configured
14+
515
from sklearn._config import config_context
616
from sklearn.base import BaseEstimator
717
from sklearn.utils._array_api import (
@@ -256,6 +266,9 @@ def test_get_namespace_array_api_isdtype(wrapper):
256266
assert xp.isdtype(xp.float64, "real floating")
257267
assert not xp.isdtype(xp.int32, "real floating")
258268

269+
for dtype in supported_float_dtypes(xp):
270+
assert xp.isdtype(dtype, "real floating")
271+
259272
assert xp.isdtype(xp.bool, "bool")
260273
assert not xp.isdtype(xp.float32, "bool")
261274