|
2 | 2 |
|
3 | 3 | import numpy
|
4 | 4 | import pytest
|
5 |
| -from numpy.testing import assert_allclose, assert_array_equal |
| 5 | +from numpy.testing import assert_allclose |
6 | 6 |
|
7 | 7 | from sklearn._config import config_context
|
8 | 8 | from sklearn.base import BaseEstimator
|
@@ -101,48 +101,6 @@ def test_array_api_wrapper_astype():
|
101 | 101 | assert X_converted.dtype == xp.float32
|
102 | 102 |
|
103 | 103 |
|
104 |
| -def test_array_api_wrapper_take_for_numpy_api(): |
105 |
| - """Test that fast path is called for numpy.array_api.""" |
106 |
| - numpy_array_api = pytest.importorskip("numpy.array_api") |
107 |
| - # USe the same name as numpy.array_api |
108 |
| - xp_ = _AdjustableNameAPITestWrapper(numpy_array_api, "numpy.array_api") |
109 |
| - xp = _ArrayAPIWrapper(xp_) |
110 |
| - |
111 |
| - X = xp.asarray(([[1, 2, 3], [3, 4, 5]]), dtype=xp.float64) |
112 |
| - X_take = xp.take(X, xp.asarray([1]), axis=0) |
113 |
| - assert hasattr(X_take, "__array_namespace__") |
114 |
| - assert_array_equal(X_take, numpy.take(X, [1], axis=0)) |
115 |
| - |
116 |
| - |
117 |
| -def test_array_api_wrapper_take(): |
118 |
| - """Test _ArrayAPIWrapper API for take.""" |
119 |
| - numpy_array_api = pytest.importorskip("numpy.array_api") |
120 |
| - xp_ = _AdjustableNameAPITestWrapper(numpy_array_api, "wrapped_numpy.array_api") |
121 |
| - xp = _ArrayAPIWrapper(xp_) |
122 |
| - |
123 |
| - # Check take compared to NumPy's with axis=0 |
124 |
| - X_1d = xp.asarray([1, 2, 3], dtype=xp.float64) |
125 |
| - X_take = xp.take(X_1d, xp.asarray([1]), axis=0) |
126 |
| - assert hasattr(X_take, "__array_namespace__") |
127 |
| - assert_array_equal(X_take, numpy.take(X_1d, [1], axis=0)) |
128 |
| - |
129 |
| - X = xp.asarray(([[1, 2, 3], [3, 4, 5]]), dtype=xp.float64) |
130 |
| - X_take = xp.take(X, xp.asarray([0]), axis=0) |
131 |
| - assert hasattr(X_take, "__array_namespace__") |
132 |
| - assert_array_equal(X_take, numpy.take(X, [0], axis=0)) |
133 |
| - |
134 |
| - # Check take compared to NumPy's with axis=1 |
135 |
| - X_take = xp.take(X, xp.asarray([0, 2]), axis=1) |
136 |
| - assert hasattr(X_take, "__array_namespace__") |
137 |
| - assert_array_equal(X_take, numpy.take(X, [0, 2], axis=1)) |
138 |
| - |
139 |
| - with pytest.raises(ValueError, match=r"Only axis in \(0, 1\) is supported"): |
140 |
| - xp.take(X, xp.asarray([0]), axis=2) |
141 |
| - |
142 |
| - with pytest.raises(ValueError, match=r"Only X.ndim in \(1, 2\) is supported"): |
143 |
| - xp.take(xp.asarray([[[0]]]), xp.asarray([0]), axis=0) |
144 |
| - |
145 |
| - |
146 | 104 | @pytest.mark.parametrize("array_api", ["numpy", "numpy.array_api"])
|
147 | 105 | def test_asarray_with_order(array_api):
|
148 | 106 | """Test _asarray_with_order passes along order for NumPy arrays."""
|
|
0 commit comments