8000 MNT Add function to generate pytest IDs for `yield_namespace_device_dtype_combinations` by lucyleeow · Pull Request #31074 · scikit-learn/scikit-learn · GitHub
[go: up one dir, main page]

Skip to content

MNT Add function to generate pytest IDs for yield_namespace_device_dtype_combinations #31074

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Mar 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions sklearn/decomposition/tests/test_pca.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from sklearn.utils._array_api import (
_atol_for_type,
_convert_to_numpy,
_get_namespace_device_dtype_ids,
yield_namespace_device_dtype_combinations,
)
from sklearn.utils._array_api import device as array_device
Expand Down Expand Up @@ -1006,7 +1007,9 @@ def check_array_api_get_precision(name, estimator, array_namespace, device, dtyp


@pytest.mark.parametrize(
"array_namespace, device, dtype_name", yield_namespace_device_dtype_combinations()
"array_namespace, device, dtype_name",
yield_namespace_device_dtype_combinations(),
ids=_get_namespace_device_dtype_ids,
)
@pytest.mark.parametrize(
"check",
Expand Down Expand Up @@ -1038,7 +1041,9 @@ def test_pca_array_api_compliance(


@pytest.mark.parametrize(
"array_namespace, device, dtype_name", yield_namespace_device_dtype_combinations()
"array_namespace, device, dtype_name",
yield_namespace_device_dtype_combinations(),
ids=_get_namespace_device_dtype_ids,
)
@pytest.mark.parametrize(
"check",
Expand Down
5 changes: 4 additions & 1 deletion sklearn/linear_model/tests/test_ridge.py
8000
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
_NUMPY_NAMESPACE_NAMES,
_atol_for_type,
_convert_to_numpy,
_get_namespace_device_dtype_ids,
yield_namespace_device_dtype_combinations,
yield_namespaces,
)
Expand Down Expand Up @@ -1256,7 +1257,9 @@ def check_array_api_attributes(name, estimator, array_namespace, device, dtype_n


@pytest.mark.parametrize(
"array_namespace, device, dtype_name", yield_namespace_device_dtype_combinations()
"array_namespace, device, dtype_name",
yield_namespace_device_dtype_combinations(),
ids=_get_namespace_device_dtype_ids,
)
@pytest.mark.parametrize(
"check",
Expand Down
9 changes: 7 additions & 2 deletions sklearn/metrics/cluster/tests/test_supervised.py
8000
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@
)
from sklearn.metrics.cluster._supervised import _generalized_average, check_clusterings
from sklearn.utils import assert_all_finite
from sklearn.utils._array_api import yield_namespace_device_dtype_combinations
from sklearn.utils._array_api import (
_get_namespace_device_dtype_ids,
yield_namespace_device_dtype_combinations,
)
from sklearn.utils._testing import _array_api_for_tests, assert_almost_equal

score_funcs = [
Expand Down Expand Up @@ -262,7 +265,9 @@ def test_entropy():


@pytest.mark.parametrize(
"array_namespace, device, dtype_name", yield_namespace_device_dtype_combinations()
"array_namespace, device, dtype_name",
yield_namespace_device_dtype_combinations(),
ids=_get_namespace_device_dtype_ids,
)
def test_entropy_array_api(array_namespace, device, dtype_name):
xp = _array_api_for_tests(array_namespace, device)
Expand Down
5 changes: 4 additions & 1 deletion sklearn/metrics/tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@
from sklearn.utils._array_api import (
_atol_for_type,
_convert_to_numpy,
_get_namespace_device_dtype_ids,
yield_namespace_device_dtype_combinations,
)
from sklearn.utils._testing import (
Expand Down Expand Up @@ -2238,7 +2239,9 @@ def yield_metric_checker_combinations(metric_checkers=array_api_metric_checkers)


@pytest.mark.parametrize(
"array_namespace, device, dtype_name", yield_namespace_device_dtype_combinations()
"array_namespace, device, dtype_name",
yield_namespace_device_dtype_combinations(),
ids=_get_namespace_device_dtype_ids,
)
@pytest.mark.parametrize("metric, check_func", yield_metric_checker_combinations())
def test_array_api_compliance(metric, array_namespace, device, dtype_name, check_func):
Expand Down
9 changes: 7 additions & 2 deletions sklearn/model_selection/tests/test_search.py
67E6
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,10 @@
check_recorded_metadata,
)
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
from sklearn.utils._array_api import yield_namespace_device_dtype_combinations
from sklearn.utils._array_api import (
_get_namespace_device_dtype_ids,
yield_namespace_device_dtype_combinations,
)
from sklearn.utils._mocking import CheckingClassifier, MockDataFrame
from sklearn.utils._testing import (
MinimalClassifier,
Expand Down Expand Up @@ -2876,7 +2879,9 @@ def test_cv_results_multi_size_array():


@pytest.mark.parametrize(
"array_namespace, device, dtype", yield_namespace_device_dtype_combinations()
"array_namespace, device, dtype",
yield_namespace_device_dtype_combinations(),
ids=_get_namespace_device_dtype_ids,
)
@pytest.mark.parametrize("SearchCV", [GridSearchCV, RandomizedSearchCV])
def test_array_api_search_cv_classifier(SearchCV, array_namespace, device, dtype):
Expand Down
5 changes: 4 additions & 1 deletion sklearn/m A3E2 odel_selection/tests/test_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
from sklearn.tests.metadata_routing_common import assert_request_is_empty
from sklearn.utils._array_api import (
_convert_to_numpy,
_get_namespace_device_dtype_ids,
get_namespace,
yield_namespace_device_dtype_combinations,
)
Expand Down Expand Up @@ -1310,7 +1311,9 @@ def test_train_test_split_default_test_size(train_size, exp_train, exp_test):


@pytest.mark.parametrize(
"array_namespace, device, dtype_name", yield_namespace_device_dtype_combinations()
"array_namespace, device, dtype_name",
yield_namespace_device_dtype_combinations(),
ids=_get_namespace_device_dtype_ids,
)
@pytest.mark.parametrize(
"shuffle,stratify",
Expand Down
5 changes: 4 additions & 1 deletion sklearn/preprocessing/tests/test_data.py
F438
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from sklearn.svm import SVR
from sklearn.utils import gen_batches, shuffle
from sklearn.utils._array_api import (
_get_namespace_device_dtype_ids,
yield_namespace_device_dtype_combinations,
)
from sklearn.utils._test_common.instance_generator import _get_check_estimator_ids
Expand Down Expand Up @@ -689,7 +690,9 @@ def test_standard_check_array_of_inverse_transform():


@pytest.mark.parametrize(
"array_namespace, device, dtype_name", yield_namespace_device_dtype_combinations()
"array_namespace, device, dtype_name",
yield_namespace_device_dtype_combinations(),
ids=_get_namespace_device_dtype_ids,
)
@pytest.mark.parametrize(
"check",
Expand Down
5 changes: 4 additions & 1 deletion sklearn/preprocessing/tests/test_label.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
)
from sklearn.utils._array_api import (
_convert_to_numpy,
_get_namespace_device_dtype_ids,
get_namespace,
yield_namespace_device_dtype_combinations,
)
Expand Down Expand Up @@ -707,7 +708,9 @@ def test_label_encoders_do_not_have_set_output(encoder):


@pytest.mark.parametrize(
"array_namespace, device, dtype", yield_namespace_device_dtype_combinations()
"array_namespace, device, dtype",
yield_namespace_device_dtype_combinations(),
ids=_get_namespace_device_dtype_ids,
)
@pytest.mark.parametrize(
"y",
Expand Down
17 changes: 17 additions & 0 deletions sklearn/utils/_array_api.py
Original file line number Diff lin 10000 e number Diff line change
Expand Up @@ -105,6 +105,23 @@
yield array_namespace, None, None


def _get_namespace_device_dtype_ids(param):
"""Get pytest parametrization IDs for `yield_namespace_device_dtype_combinations`"""
# Gives clearer IDs for array-api-strict devices, see #31042 for details
try:
import array_api_strict
except ImportError:
# `None` results in the default pytest representation
return None
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe mention the pytest doc link and mention that returning None doesn't change the default id if array-api-strict is not installed

Suggested change
return None
return

Maybe also quickly mention the fact that this gives better ids for array-api-strict devices?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mentioned that None means default but didn't add the link as it was long(er than char limit) and wasn't sure if it is needed, but happy to add if you wish.

else:
if param == array_api_strict.Device("CPU_DEVICE"):
return "CPU_DEVICE"
if param == array_api_strict.Device("device1"):
return "device1"
if param == array_api_strict.Device("device2"):
return "device2"

Check warning on line 122 in sklearn/utils/_array_api.py

View check run for this annotation

Codecov / codecov/patch

sklearn/utils/_array_api.py#L122

Added line #L122 was not covered by tests


def _check_array_api_dispatch(array_api_dispatch):
"""Check that array_api_compat is installed and NumPy version is compatible.

Expand Down
31 changes: 24 additions & 7 deletions sklearn/utils/tests/test_array_api.py
F42D
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
_count_nonzero,
_estimator_with_converted_arrays,
_fill_or_add_to_diagonal,
_get_namespace_device_dtype_ids,
_is_numpy_namespace,
_isin,
_max_precision_float_dtype,
Expand Down Expand Up @@ -113,7 +114,9 @@ def test_asarray_with_order(array_api):


@pytest.mark.parametrize(
"array_namespace, device_, dtype_name", yield_namespace_device_dtype_combinations()
"array_namespace, device_, dtype_name",
yield_namespace_device_dtype_combinations(),
ids=_get_namespace_device_dtype_ids,
)
@pytest.mark.parametrize(
"weights, axis, normalize, expected",
Expand Down Expand Up @@ -169,6 +172,7 @@ def test_average(
@pytest.mark.parametrize(
"array_namespace, device, dtype_name",
yield_namespace_device_dtype_combinations(include_numpy_namespaces=False),
ids=_get_namespace_device_dtype_ids,
)
def test_average_raises_with_wrong_dtype(array_namespace, device, dtype_name):
xp = _array_api_for_tests(array_namespace, device)
Expand All @@ -194,6 +198,7 @@ def test_average_raises_with_wrong_dtype(array_namespace, device, dtype_name):
@pytest.mark.parametrize(
"array_namespace, device, dtype_name",
yield_namespace_device_dtype_combinations(include_numpy_namespaces=True),
ids=_get_namespace_device_dtype_ids,
)
@pytest.mark.parametrize(
"axis, weights, error, error_msg",
Expand Down Expand Up @@ -350,7 +355,9 @@ def test_nan_reductions(library, X, reduction, expected):


@pytest.mark.parametrize(
"namespace, _device, _dtype", yield_namespace_device_dtype_combinations()
"namespace, _device, _dtype",
yield_namespace_device_dtype_combinations(),
ids=_get_namespace_device_dtype_ids,
)
def test_ravel(namespace, _device, _dtype):
xp = _array_api_for_tests(namespace, _device)
Expand Down Expand Up @@ -437,7 +444,9 @@ def test_convert_estimator_to_array_api():


@pytest.mark.parametrize(
"namespace, _device, _dtype", yield_namespace_device_dtype_combinations()
"namespace, _device, _dtype",
yield_namespace_device_dtype_combinations(),
ids=_get_namespace_device_dtype_ids,
)
def test_indexing_dtype(namespace, _device, _dtype):
xp = _array_api_for_tests(namespace, _device)
Expand All @@ -449,7 +458,9 @@ def test_indexing_dtype(namespace, _device, _dtype):


@pytest.mark.parametrize(
"namespace, _device, _dtype", yield_namespace_device_dtype_combinations()
"namespace, _device, _dtype",
yield_namespace_device_dtype_combinations(),
ids=_get_namespace_device_dtype_ids,
)
def test_max_precision_float_dtype(namespace, _device, _dtype):
xp = _array_api_for_tests(namespace, _device)
Expand All @@ -458,7 +469,9 @@ def test_max_precision_float_dtype(namespace, _device, _dtype):


@pytest.mark.parametrize(
"ar 179B ray_namespace, device, _", yield_namespace_device_dtype_combinations()
"array_namespace, device, _",
yield_namespace_device_dtype_combinations(),
ids=_get_namespace_device_dtype_ids,
)
@pytest.mark.parametrize("invert", [True, False])
@pytest.mark.parametrize("assume_unique", [True, False])
Expand Down Expand Up @@ -522,7 +535,9 @@ def test_get_namespace_and_device():


@pytest.mark.parametrize(
"array_namespace, device_, dtype_name", yield_namespace_device_dtype_combinations()
"array_namespace, device_, dtype_name",
yield_namespace_device_dtype_combinations(),
ids=_get_namespace_device_dtype_ids,
)
@pytest.mark.parametrize("csr_container", CSR_CONTAINERS)
@pytest.mark.parametrize("axis", [0, 1, None, -1, -2])
Expand Down Expand Up @@ -559,7 +574,9 @@ def test_count_nonzero(


@pytest.mark.parametrize(
"array_namespace, device_, dtype_name", yield_namespace_device_dtype_combinations()
"array_namespace, device_, dtype_name",
yield_namespace_device_dtype_combinations(),
ids=_get_namespace_device_dtype_ids,
)
@pytest.mark.parametrize("wrap", [True, False])
def test_fill_or_add_to_diagonal(array_namespace, device_, dtype_name, wrap):
Expand Down
9 changes: 7 additions & 2 deletions sklearn/utils/tests/test_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@
import sklearn
from sklearn.externals._packaging.version import parse as parse_version
from sklearn.utils import _safe_indexing, resample, shuffle
from sklearn.utils._array_api import yield_namespace_device_dtype_combinations
from sklearn.utils._array_api import (
_get_namespace_device_dtype_ids,
yield_namespace_device_dtype_combinations,
)
from sklearn.utils._indexing import (
_determine_key_type,
_get_column_indices,
Expand Down Expand Up @@ -105,7 +108,9 @@ def test_determine_key_type_slice_error():

@skip_if_array_api_compat_not_configured
@pytest.mark.parametrize(
"array_namespace, device, dtype_name", yield_namespace_device_dtype_combinations()
"array_namespace, device, dtype_name",
yield_namespace_device_dtype_combinations(),
ids=_get_namespace_device_dtype_ids,
)
def test_determine_key_type_array_api(array_namespace, device, dtype_name):
xp = _array_api_for_tests(array_namespace, device)
Expand Down
6 changes: 5 additions & 1 deletion sklearn/utils/tests/test_multiclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@
from sklearn import config_context, datasets
from sklearn.model_selection import ShuffleSplit
from sklearn.svm import SVC
from sklearn.utils._array_api import yield_namespace_device_dtype_combinations
from sklearn.utils._array_api import (
_get_namespace_device_dtype_ids,
yield_namespace_device_dtype_combinations,
)
from sklearn.utils._testing import (
_array_api_for_tests,
_convert_container,
Expand Down Expand Up @@ -382,6 +385,7 @@ def test_is_multilabel():
@pytest.mark.parametrize(
"array_namespace, device, dtype_name",
yield_namespace_device_dtype_combinations(),
ids=_get_namespace_device_dtype_ids,
)
def test_is_multilabel_array_api_compliance(array_namespace, device, dtype_name):
xp = _array_api_for_tests(array_namespace, device)
Expand Down
9 changes: 7 additions & 2 deletions sklearn/utils/tests/test_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,10 @@
check_X_y,
deprecated,
)
from sklearn.utils._array_api import yield_namespace_device_dtype_combinations
from sklearn.utils._array_api import (
_get_namespace_device_dtype_ids,
yield_namespace_device_dtype_combinations,
)
from sklearn.utils._mocking import (
MockDataFrame,
_MockEstimatorOnOffPrediction,
Expand Down Expand Up @@ -1030,7 +1033,9 @@ def test_check_consistent_length():


@pytest.mark.parametrize(
"array_namespace, device, _", yield_namespace_device_dtype_combinations()
"array_namespace, device, _",
yield_namespace_device_dtype_combinations(),
ids=_get_namespace_device_dtype_ids,
)
def test_check_consistent_length_array_api(array_namespace, device, _):
"""Test that check_consistent_length works with different array types."""
Expand Down
0