10000 MNT remove `take` fn in array_api wrapper (#27939) · punndcoder28/scikit-learn@8e10cd7 · GitHub
[go: up one dir, main page]

Skip to content

Commit 8e10cd7

Browse files
authored
MNT remove take fn in array_api wrapper (scikit-learn#27939)
1 parent 3b06962 commit 8e10cd7

File tree

2 files changed

+1
-67
lines changed

2 files changed

+1
-67
lines changed

sklearn/utils/_array_api.py

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -205,30 +205,6 @@ def __getattr__(self, name):
205205
def __eq__(self, other):
206206
return self._namespace == other._namespace
207207

208-
def take(self, X, indices, *, axis=0):
209-
# When array_api supports `take` we can use this directly
210-
# https://github.com/data-apis/array-api/issues/177
211-
if self._namespace.__name__ == "numpy.array_api":
212-
X_np = numpy.take(X, indices, axis=axis)
213-
return self._namespace.asarray(X_np)
214-
215-
# We only support axis in (0, 1) and ndim in (1, 2) because that is all we need
216-
# in scikit-learn
217-
if axis not in {0, 1}:
218-
raise ValueError(f"Only axis in (0, 1) is supported. Got {axis}")
219-
220-
if X.ndim not in {1, 2}:
221-
raise ValueError(f"Only X.ndim in (1, 2) is supported. Got {X.ndim}")
222-
223-
if axis == 0:
224-
if X.ndim == 1:
225-
selected = [X[i] for i in indices]
226-
else: # X.ndim == 2
227-
selected = [X[i, :] for i in indices]
228-
else: # axis == 1
229-
selected = [X[:, i] for i in indices]
230-
return self._namespace.stack(selected, axis=axis)
231-
232208
def isdtype(self, dtype, kind):
233209
return isdtype(dtype, kind, xp=self._namespace)
234210

sklearn/utils/tests/test_array_api.py

Lines changed: 1 addition & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import numpy
44
import pytest
5-
from numpy.testing import assert_allclose, assert_array_equal
5+
from numpy.testing import assert_allclose
66

77
from sklearn._config import config_context
88
from sklearn.base import BaseEstimator
@@ -101,48 +101,6 @@ def test_array_api_wrapper_astype():
101101
assert X_converted.dtype == xp.float32
102102

103103

104-
def test_array_api_wrapper_take_for_numpy_api():
105-
"""Test that fast path is called for numpy.array_api."""
106-
numpy_array_api = pytest.importorskip("numpy.array_api")
107-
# USe the same name as numpy.array_api
108-
xp_ = _AdjustableNameAPITestWrapper(numpy_array_api, "numpy.array_api")
109-
xp = _ArrayAPIWrapper(xp_)
110-
111-
X = xp.asarray(([[1, 2, 3], [3, 4, 5]]), dtype=xp.float64)
112-
X_take = xp.take(X, xp.asarray([1]), axis=0)
113-
assert hasattr(X_take, "__array_namespace__")
114-
assert_array_equal(X_take, numpy.take(X, [1], axis=0))
115-
116-
117-
def test_array_api_wrapper_take():
118-
"""Test _ArrayAPIWrapper API for take."""
119-
numpy_array_api = pytest.importorskip("numpy.array_api")
120-
xp_ = _AdjustableNameAPITestWrapper(numpy_array_api, "wrapped_numpy.array_api")
121-
xp = _ArrayAPIWrapper(xp_)
122-
123-
# Check take compared to NumPy's with axis=0
124-
X_1d = xp.asarray([1, 2, 3], dtype=xp.float64)
125-
X_take = xp.take(X_1d, xp.asarray([1]), axis=0)
126-
assert hasattr(X_take, "__array_namespace__")
127-
assert_array_equal(X_take, numpy.take(X_1d, [1], axis=0))
128-
129-
X = xp.asarray(([[1, 2, 3], [3, 4, 5]]), dtype=xp.float64)
130-
X_take = xp.take(X, xp.asarray([0]), axis=0)
131-
assert hasattr(X_take, "__array_namespace__")
132-
assert_array_equal(X_take, numpy.take(X, [0], axis=0))
133-
134-
# Check take compared to NumPy's with axis=1
135-
X_take = xp.take(X, xp.asarray([0, 2]), axis=1)
136-
assert hasattr(X_take, "__array_namespace__")
137-
assert_array_equal(X_take, numpy.take(X, [0, 2], axis=1))
138-
139-
with pytest.raises(ValueError, match=r"Only axis in \(0, 1\) is supported"):
140-
xp.take(X, xp.asarray([0]), axis=2)
141-
142-
with pytest.raises(ValueError, match=r"Only X.ndim in \(1, 2\) is supported"):
143-
xp.take(xp.asarray([[[0]]]), xp.asarray([0]), axis=0)
144-
145-
146104
@pytest.mark.parametrize("array_api", ["numpy", "numpy.array_api"])
147105
def test_asarray_with_order(array_api):
148106
"""Test _asarray_with_order passes along order for NumPy arrays."""

0 commit comments

Comments
 (0)
0