10000 MNT Add function to generate pytest IDs for `yield_namespace_device_d… · scikit-learn/scikit-learn@6778690 · GitHub
[go: up one dir, main page]

Skip to content

Commit 6778690

Browse files
authored
MNT Add function to generate pytest IDs for yield_namespace_device_dtype_combinations (#31074)
1 parent 4f847f5 commit 6778690

File tree

13 files changed

+101
-23
lines changed

13 files changed

+101
-23
lines changed

sklearn/decomposition/tests/test_pca.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from sklearn.utils._array_api import (
1616
_atol_for_type,
1717
_convert_to_numpy,
18+
_get_namespace_device_dtype_ids,
1819
yield_namespace_device_dtype_combinations,
1920
)
2021
from sklearn.utils._array_api import device as array_device
@@ -1006,7 +1007,9 @@ def check_array_api_get_precision(name, estimator, array_namespace, device, dtyp
10061007

10071008

10081009
@pytest.mark.parametrize(
1009-
"array_namespace, device, dtype_name", yield_namespace_device_dtype_combinations()
1010+
"array_namespace, device, dtype_name",
1011+
yield_namespace_device_dtype_combinations(),
1012+
ids=_get_namespace_device_dtype_ids,
10101013
)
10111014
@pytest.mark.parametrize(
10121015
"check",
@@ -1038,7 +1041,9 @@ def test_pca_array_api_compliance(
10381041

10391042

10401043
@pytest.mark.parametrize(
1041-
"array_namespace, device, dtype_name", yield_namespace_device_dtype_combinations()
1044+
"array_namespace, device, dtype_name",
1045+
yield_namespace_device_dtype_combinations(),
1046+
ids=_get_namespace_device_dtype_ids,
10421047
)
10431048
@pytest.mark.parametrize(
10441049
"check",

sklearn/linear_model/tests/test_ridge.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
_NUMPY_NAMESPACE_NAMES,
4646
_atol_for_type,
4747
_convert_to_numpy,
48+
_get_namespace_device_dtype_ids,
4849
yield_namespace_device_dtype_combinations,
4950
yield_namespaces,
5051
)
@@ -1256,7 +1257,9 @@ def check_array_api_attributes(name, estimator, array_namespace, device, dtype_n
12561257

12571258

12581259
@pytest.mark.parametrize(
1259-
"array_namespace, device, dtype_name", yield_namespace_device_dtype_combinations()
1260+
"array_namespace, device, dtype_name",
1261+
yield_namespace_device_dtype_combinations(),
1262+
ids=_get_namespace_device_dtype_ids,
12601263
)
12611264
@pytest.mark.parametrize(
12621265
"check",

sklearn/metrics/cluster/tests/test_supervised.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,10 @@
2323
)
2424
from sklearn.metrics.cluster._supervised import _generalized_average, check_clusterings
2525
from sklearn.utils import assert_all_finite
26-
from sklearn.utils._array_api import yield_namespace_device_dtype_combinations
26+
from sklearn.utils._array_api import (
27+
_get_namespace_device_dtype_ids,
28+
yield_namespace_device_dtype_combinations,
29+
)
2730
from sklearn.utils._testing import _array_api_for_tests, assert_almost_equal
2831

2932
score_funcs = [
@@ -262,7 +265,9 @@ def test_entropy():
262265

263266

264267
@pytest.mark.parametrize(
265-
"array_namespace, device, dtype_name", yield_namespace_device_dtype_combinations()
268+
"array_namespace, device, dtype_name",
269+
yield_namespace_device_dtype_combinations(),
270+
ids=_get_namespace_device_dtype_ids,
266271
)
267272
def test_entropy_array_api(array_namespace, device, dtype_name):
268273
xp = _array_api_for_tests(array_namespace, device)

sklearn/metrics/tests/test_common.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@
7474
from sklearn.utils._array_api import (
7575
_atol_for_type,
7676
_convert_to_numpy,
77+
_get_namespace_device_dtype_ids,
7778
yield_namespace_device_dtype_combinations,
7879
)
7980
from sklearn.utils._testing import (
@@ -2238,7 +2239,9 @@ def yield_metric_checker_combinations(metric_checkers=array_api_metric_checkers)
22382239

22392240

22402241
@pytest.mark.parametrize(
2241-
"array_namespace, device, dtype_name", yield_namespace_device_dtype_combinations()
2242+
"array_namespace, device, dtype_name",
2243+
yield_namespace_device_dtype_combinations(),
2244+
ids=_get_namespace_device_dtype_ids,
22422245
)
22432246
@pytest.mark.parametrize("metric, check_func", yield_metric_checker_combinations())
22442247
def test_array_api_compliance(metric, array_namespace, device, dtype_name, check_func):

sklearn/model_selection/tests/test_search.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,10 @@
8282
check_recorded_metadata,
8383
)
8484
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
85-
from sklearn.utils._array_api import yield_namespace_device_dtype_combinations
85+
from sklearn.utils._array_api import (
86+
_get_namespace_device_dtype_ids,
87+
yield_namespace_device_dtype_combinations,
88+
)
8689
from sklearn.utils._mocking import CheckingClassifier, MockDataFrame
8790
from sklearn.utils._testing import (
8891
MinimalClassifier,
@@ -2876,7 +2879,9 @@ def test_cv_results_multi_size_array():
28762879

28772880

28782881
@pytest.mark.parametrize(
2879-
"array_namespace, device, dtype", yield_namespace_device_dtype_combinations()
2882+
"array_namespace, device, dtype",
2883+
yield_namespace_device_dtype_combinations(),
2884+
ids=_get_namespace_device_dtype_ids,
28802885
)
28812886
@pytest.mark.parametrize("SearchCV", [GridSearchCV, RandomizedSearchCV])
28822887
def test_array_api_search_cv_classifier(SearchCV, array_namespace, device, dtype):

sklearn/model_selection/tests/test_split.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
from sklearn.tests.metadata_routing_common import assert_request_is_empty
4444
from sklearn.utils._array_api import (
4545
_convert_to_numpy,
46+
_get_namespace_device_dtype_ids,
4647
get_namespace,
4748
yield_namespace_device_dtype_combinations,
4849
)
@@ -1310,7 +1311,9 @@ def test_train_test_split_default_test_size(train_size, exp_train, exp_test):
13101311

13111312

13121313
@pytest.mark.parametrize(
1313-
"array_namespace, device, dtype_name", yield_namespace_device_dtype_combinations()
1314+
"array_namespace, device, dtype_name",
1315+
yield_namespace_device_dtype_combinations(),
1316+
ids=_get_namespace_device_dtype_ids,
13141317
)
13151318
@pytest.mark.parametrize(
13161319
"shuffle,stratify",

sklearn/preprocessing/tests/test_data.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
from sklearn.svm import SVR
3939
from sklearn.utils import gen_batches, shuffle
4040
from sklearn.utils._array_api import (
41+
_get_namespace_device_dtype_ids,
4142
yield_namespace_device_dtype_combinations,
4243
)
4344
from sklearn.utils._test_common.instance_generator import _get_check_estimator_ids
@@ -689,7 +690,9 @@ def test_standard_check_array_of_inverse_transform():
689690

690691

691692
@pytest.mark.parametrize(
692-
"array_namespace, device, dtype_name", yield_namespace_device_dtype_combinations()
693+
"array_namespace, device, dtype_name",
694+
yield_namespace_device_dtype_combinations(),
695+
ids=_get_namespace_device_dtype_ids,
693696
)
694697
@pytest.mark.parametrize(
695698
"check",

sklearn/preprocessing/tests/test_label.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
)
1414
from sklearn.utils._array_api import (
1515
_convert_to_numpy,
16+
_get_namespace_device_dtype_ids,
1617
get_namespace,
1718
yield_namespace_device_dtype_combinations,
1819
)
@@ -707,7 +708,9 @@ def test_label_encoders_do_not_have_set_output(encoder):
707708

708709

709710
@pytest.mark.parametrize(
710-
"array_namespace, device, dtype", yield_namespace_device_dtype_combinations()
711+
"array_namespace, device, dtype",
712+
yield_namespace_device_dtype_combinations(),
713+
ids=_get_namespace_device_dtype_ids,
711714
)
712715
@pytest.mark.parametrize(
713716
"y",

sklearn/utils/_array_api.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,23 @@ def yield_namespace_device_dtype_combinations(include_numpy_namespaces=True):
105105
yield array_namespace, None, None
106106

107107

108+
def _get_namespace_device_dtype_ids(param):
109+
"""Get pytest parametrization IDs for `yield_namespace_device_dtype_combinations`"""
110+
# Gives clearer IDs for array-api-strict devices, see #31042 for details
111+
try:
112+
import array_api_strict
113+
except ImportError:
114+
# `None` results in the default pytest representation
115+
return None
116+
else:
117+
if param == array_api_strict.Device("CPU_DEVICE"):
118+
return "CPU_DEVICE"
119+
if param == array_api_strict.Device("device1"):
120+
return "device1"
121+
if param == array_api_strict.Device("device2"):
122+
return "device2"
123+
124+
108125
def _check_array_api_dispatch(array_api_dispatch):
109126
"""Check that array_api_compat is installed and NumPy version is compatible.
110127

sklearn/utils/tests/test_array_api.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
_count_nonzero,
1616
_estimator_with_converted_arrays,
1717
_fill_or_add_to_diagonal,
18+
_get_namespace_device_dtype_ids,
1819
_is_numpy_namespace,
1920
_isin,
2021
_max_precision_float_dtype,
@@ -113,7 +114,9 @@ def test_asarray_with_order(array_api):
113114

114115

115116
@pytest.mark.parametrize(
116-
"array_namespace, device_, dtype_name", yield_namespace_device_dtype_combinations()
117+
"array_namespace, device_, dtype_name",
118+
yield_namespace_device_dtype_combinations(),
119+
ids=_get_namespace_device_dtype_ids,
117120
)
118121
@pytest.mark.parametrize(
119122
"weights, axis, normalize, expected",
@@ -169,6 +172,7 @@ def test_average(
169172
@pytest.mark.parametrize(
170173
"array_namespace, device, dtype_name",
171174
yield_namespace_device_dtype_combinations(include_numpy_namespaces=False),
175+
ids=_get_namespace_device_dtype_ids,
172176
)
173177
def test_average_raises_with_wrong_dtype(array_namespace, device, dtype_name):
174178
xp = _array_api_for_tests(array_namespace, device)
@@ -194,6 +198,7 @@ def test_average_raises_with_wrong_dtype(array_namespace, device, dtype_name):
194198
@pytest.mark.parametrize(
195199
"array_namespace, device, dtype_name",
196200
yield_namespace_device_dtype_combinations(include_numpy_namespaces=True),
201+
ids=_get_namespace_device_dtype_ids,
197202
)
198203
@pytest.mark.parametrize(
199204
"axis, weights, error, error_msg",
@@ -350,7 +355,9 @@ def test_nan_reductions(library, X, reduction, expected):
350355

351356

352357
@pytest.mark.parametrize(
353-
"namespace, _device, _dtype", yield_namespace_device_dtype_combinations()
358+
"namespace, _device, _dtype",
359+
yield_namespace_device_dtype_combinations(),
360+
ids=_get_namespace_device_dtype_ids,
354361
)
355362
def test_ravel(namespace, _device, _dtype):
356363
xp = _array_api_for_tests(namespace, _device)
@@ -437,7 +444,9 @@ def test_convert_estimator_to_array_api():
437444

438445

439446
@pytest.mark.parametrize(
440-
"namespace, _device, _dtype", yield_namespace_device_dtype_combinations()
447+
"namespace, _device, _dtype",
448+
yield_namespace_device_dtype_combinations(),
449+
ids=_get_namespace_device_dtype_ids,
441450
)
442451
def test_indexing_dtype(namespace, _device, _dtype):
443452
xp = _array_api_for_tests(namespace, _device)
@@ -449,7 +458,9 @@ def test_indexing_dtype(namespace, _device, _dtype):
449458

450459

451460
@pytest.mark.parametrize(
452-
"namespace, _device, _dtype", yield_namespace_device_dtype_combinations()
461+
"namespace, _device, _dtype",
462+
yield_namespace_device_dtype_combinations(),
463+
ids=_get_namespace_device_dtype_ids,
453464
)
454465
def test_max_precision_float_dtype(namespace, _device, _dtype):
455466
xp = _array_api_for_tests(namespace, _device)
@@ -458,7 +469,9 @@ def test_max_precision_float_dtype(namespace, _device, _dtype):
458469

459470

460471
@pytest.mark.parametrize(
461-
"array_namespace, device, _", yield_namespace_device_dtype_combinations()
472+
"array_namespace, device, _",
473+
yield_namespace_device_dtype_combinations(),
474+
ids=_get_namespace_device_dtype_ids,
462475
)
463476
@pytest.mark.parametrize("invert", [True, False])
464477
@pytest.mark.parametrize("assume_unique", [True, False])
@@ -522,7 +535,9 @@ def test_get_namespace_and_device():
522535

523536

524537
@pytest.mark.parametrize(
525-
"array_namespace, device_, dtype_name", yield_namespace_device_dtype_combinations()
538+
"array_namespace, device_, dtype_name",
539+
yield_namespace_device_dtype_combinations(),
540+
ids=_get_namespace_device_dtype_ids,
526541
)
527542
@pytest.mark.parametrize("csr_container", CSR_CONTAINERS)
528543
@pytest.mark.parametrize("axis", [0, 1, None, -1, -2])
@@ -559,7 +574,9 @@ def test_count_nonzero(
559574

560575

561576
@pytest.mark.parametrize(
562-
"array_namespace, device_, dtype_name", yield_namespace_device_dtype_combinations()
577+
"array_namespace, device_, dtype_name",
578+
yield_namespace_device_dtype_combinations(),
579+
ids=_get_namespace_device_dtype_ids,
563580
)
564581
@pytest.mark.parametrize("wrap", [True, False])
565582
def test_fill_or_add_to_diagonal(array_namespace, device_, dtype_name, wrap):

0 commit comments

Comments
 (0)
0