9
9
from sklearn ._config import config_context
10
10
from sklearn .base import BaseEstimator
11
11
from sklearn .utils ._array_api import (
12
- _ArrayAPIWrapper ,
13
12
_asarray_with_order ,
14
13
_atol_for_type ,
15
14
_average ,
@@ -104,48 +103,6 @@ def mock_getenv(key):
104
103
xp_out , is_array_api_compliant = get_namespace (X_xp )
105
104
106
105
107
- class _AdjustableNameAPITestWrapper (_ArrayAPIWrapper ):
108
- """API wrapper that has an adjustable name. Used for testing."""
109
-
110
- def __init__ (self , array_namespace , name ):
111
- super ().__init__ (array_namespace = array_namespace )
112
- self .__name__ = name
113
-
114
-
115
- def test_array_api_wrapper_astype ():
116
- """Test _ArrayAPIWrapper for ArrayAPIs that is not NumPy."""
117
- array_api_strict = pytest .importorskip ("array_api_strict" )
118
- xp_ = _AdjustableNameAPITestWrapper (array_api_strict , "array_api_strict" )
119
- xp = _ArrayAPIWrapper (xp_ )
120
-
121
- X = xp .asarray (([[1 , 2 , 3 ], [3 , 4 , 5 ]]), dtype = xp .float64 )
122
- X_converted = xp .astype (X , xp .float32 )
123
- assert X_converted .dtype == xp .float32
124
-
125
- X_converted = xp .asarray (X , dtype = xp .float32 )
126
- assert X_converted .dtype == xp .float32
127
-
128
-
129
- def test_array_api_wrapper_maximum ():
130
- """Test _ArrayAPIWrapper `maximum` for ArrayAPIs other than NumPy.
131
-
132
- This is mainly used to test for `cupy.array_api` but since that is
133
- not available on our coverage-enabled PR CI, we resort to using
134
- `array-api-strict`.
135
- """
136
- array_api_strict = pytest .importorskip ("array_api_strict" )
137
- xp_ = _AdjustableNameAPITestWrapper (array_api_strict , "array_api_strict" )
138
- xp = _ArrayAPIWrapper (xp_ )
139
-
140
- x1 = xp .asarray (([[1 , 2 , 3 ], [3 , 9 , 5 ]]), dtype = xp .int64 )
141
- x2 = xp .asarray (([[0 , 1 , 6 ], [8 , 4 , 5 ]]), dtype = xp .int64 )
142
- result = xp .asarray ([[1 , 2 , 6 ], [8 , 9 , 5 ]], dtype = xp .int64 )
143
-
144
- x_max = xp .maximum (x1 , x2 )
145
- assert x_max .dtype == x1 .dtype
146
- assert xp .all (xp .equal (x_max , result ))
147
-
148
-
149
106
@pytest .mark .parametrize ("array_api" , ["numpy" , "array_api_strict" ])
150
107
def test_asarray_with_order (array_api ):
151
108
"""Test _asarray_with_order passes along order for NumPy arrays."""
@@ -158,21 +115,6 @@ def test_asarray_with_order(array_api):
158
115
assert X_new_np .flags ["F_CONTIGUOUS" ]
159
116
160
117
161
- def test_asarray_with_order_ignored ():
162
- """Test _asarray_with_order ignores order for Generic ArrayAPI."""
163
- xp = pytest .importorskip ("array_api_strict" )
164
- xp_ = _AdjustableNameAPITestWrapper (xp , "array_api_strict" )
165
-
166
- X = numpy .asarray ([[1.2 , 3.4 , 5.1 ], [3.4 , 5.5 , 1.2 ]], order = "C" )
167
- X = xp_ .asarray (X )
168
-
169
- X_new = _asarray_with_order (X , order = "F" , xp = xp_ )
170
-
171
- X_new_np = numpy .asarray (X_new )
172
- assert X_new_np .flags ["C_CONTIGUOUS" ]
173
- assert not X_new_np .flags ["F_CONTIGUOUS" ]
<
F42D
/tr>174
-
175
-
176
118
@pytest .mark .parametrize (
177
119
"array_namespace, device_, dtype_name" , yield_namespace_device_dtype_combinations ()
178
120
)
@@ -351,8 +293,8 @@ def __init__(self, device_name):
351
293
assert array1 .device == device (array1 , array1 , array2 )
352
294
353
295
354
- # TODO: add cupy and cupy.array_api to the list of libraries once the
355
- # the following upstream issue has been fixed:
296
+ # TODO: add cupy to the list of libraries once the the following upstream issue
297
+ # has been fixed:
356
298
# https://github.com/cupy/cupy/issues/8180
357
299
@skip_if_array_api_compat_not_configured
358
300
@pytest .mark .parametrize ("library" , ["numpy" , "array_api_strict" , "torch" ])
@@ -419,7 +361,7 @@ def test_ravel(namespace, _device, _dtype):
419
361
420
362
421
363
@skip_if_array_api_compat_not_configured
422
- @pytest .mark .parametrize ("library" , ["cupy" , "torch" , "cupy.array_api" ])
364
+ @pytest .mark .parametrize ("library" , ["cupy" , "torch" ])
423
365
def test_convert_to_numpy_gpu (library ): # pragma: nocover
424
366
"""Check convert_to_numpy for GPU backed libraries."""
425
367
xp = pytest .importorskip (library )
@@ -459,7 +401,6 @@ def fit(self, X, y=None):
459
401
[
460
402
("torch" , lambda array : array .cpu ().numpy ()),
461
403
("array_api_strict" , lambda array : numpy .asarray (array )),
462
- ("cupy.array_api" , lambda array : array ._array .get ()),
463
404
],
464
405
)
465
406
def test_convert_estimator_to_ndarray (array_namespace , converter ):
@@ -500,15 +441,9 @@ def test_reshape_behavior():
500
441
xp .reshape (X , - 1 )
501
442
502
443
503
- @pytest .mark .parametrize ("wrapper" , [_ArrayAPIWrapper , _NumPyAPIWrapper ])
504
- def test_get_namespace_array_api_isdtype (wrapper ):
505
- """Test isdtype implementation from _ArrayAPIWrapper and _NumPyAPIWrapper."""
506
-
507
- if wrapper == _ArrayAPIWrapper :
508
- xp_ = pytest .importorskip ("array_api_strict" )
509
- xp = _ArrayAPIWrapper (xp_ )
510
- else :
511
- xp = _NumPyAPIWrapper ()
444
+ def test_get_namespace_array_api_isdtype ():
445
+ """Test isdtype implementation from _NumPyAPIWrapper."""
446
+ xp = _NumPyAPIWrapper ()
512
447
513
448
assert xp .isdtype (xp .float32 , xp .float32 )
514
449
assert xp .isdtype (xp .float32 , "real floating" )
@@ -533,10 +468,9 @@ def test_get_namespace_array_api_isdtype(wrapper):
533
468
534
469
assert not xp .isdtype (xp .float32 , "complex floating" )
535
470
536
- if wrapper == _NumPyAPIWrapper :
537
- assert not xp .isdtype (xp .int8 , "complex floating" )
538
- assert xp .isdtype (xp .complex64 , "complex floating" )
539
- assert xp .isdtype (xp .complex128 , "complex floating" )
471
+ assert not xp .isdtype (xp .int8 , "complex floating" )
472
+ assert xp .isdtype (xp .complex64 , "complex floating" )
473
+ assert xp .isdtype (xp .complex128 , "complex floating" )
540
474
541
475
with pytest .raises (ValueError , match = "Unrecognized data type" ):
542
476
assert xp .isdtype (xp .int16 , "unknown" )
0 commit comments