8000 FIX Apply dtype param in `check_array_api_compute_metric` unit test (… · punndcoder28/scikit-learn@77aeb82 · GitHub
[go: up one dir, main page]

Skip to content

Commit 77aeb82

Browse files
authored
FIX Apply dtype param in check_array_api_compute_metric unit test (scikit-learn#27940)
1 parent 94b8471 commit 77aeb82

File tree

8 files changed

+100
-66
lines changed

8 files changed

+100
-66
lines changed

sklearn/decomposition/tests/test_pca.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -817,9 +817,9 @@ def test_variance_correctness(copy):
817817
np.testing.assert_allclose(pca_var, true_var)
818818

819819

820-
def check_array_api_get_precision(name, estimator, array_namespace, device, dtype):
820+
def check_array_api_get_precision(name, estimator, array_namespace, device, dtype_name):
821821
xp = _array_api_for_tests(array_namespace, device)
822-
iris_np = iris.data.astype(dtype)
822+
iris_np = iris.data.astype(dtype_name)
823823
iris_xp = xp.asarray(iris_np, device=device)
824824

825825
estimator.fit(iris_np)
@@ -835,7 +835,7 @@ def check_array_api_get_precision(name, estimator, array_namespace, device, dtyp
835835
assert_allclose(
836836
_convert_to_numpy(precision_xp, xp=xp),
837837
precision_np,
838-
atol=_atol_for_type(dtype),
838+
atol=_atol_for_type(dtype_name),
839839
)
840840
covariance_xp = estimator_xp.get_covariance()
841841
assert covariance_xp.shape == (4, 4)
@@ -844,12 +844,12 @@ def check_array_api_get_precision(name, estimator, array_namespace, device, dtyp
844844
assert_allclose(
845845
_convert_to_numpy(covariance_xp, xp=xp),
846846
covariance_np,
847-
atol=_atol_for_type(dtype),
847+
atol=_atol_for_type(dtype_name),
848848
)
849849

850850

851851
@pytest.mark.parametrize(
852-
"array_namespace, device, dtype", yield_namespace_device_dtype_combinations()
852+
"array_namespace, device, dtype_name", yield_namespace_device_dtype_combinations()
853853
)
854854
@pytest.mark.parametrize(
855855
"check",
@@ -870,13 +870,15 @@ def check_array_api_get_precision(name, estimator, array_namespace, device, dtyp
870870
],
871871
ids=_get_check_estimator_ids,
872872
)
873-
def test_pca_array_api_compliance(estimator, check, array_namespace, device, dtype):
873+
def test_pca_array_api_compliance(
874+
estimator, check, array_namespace, device, dtype_name
875+
):
874876
name = estimator.__class__.__name__
875-
check(name, estimator, array_namespace, device=device, dtype=dtype)
877+
check(name, estimator, array_namespace, device=device, dtype_name=dtype_name)
876878

877879

878880
@pytest.mark.parametrize(
879-
"array_namespace, device, dtype", yield_namespace_device_dtype_combinations()
881+
"array_namespace, device, dtype_name", yield_namespace_device_dtype_combinations()
880882
)
881883
@pytest.mark.parametrize(
882884
"check",
@@ -892,9 +894,11 @@ def test_pca_array_api_compliance(estimator, check, array_namespace, device, dty
892894
],
893895
ids=_get_check_estimator_ids,
894896
)
895-
def test_pca_mle_array_api_compliance(estimator, check, array_namespace, device, dtype):
897+
def test_pca_mle_array_api_compliance(
898+
estimator, check, array_namespace, device, dtype_name
899+
):
896900
name = estimator.__class__.__name__
897-
check(name, estimator, array_namespace, device=device, dtype=dtype)
901+
check(name, estimator, array_namespace, device=device, dtype_name=dtype_name)
898902

899903

900904
def test_array_api_error_and_warnings_on_unsupported_params():

sklearn/metrics/tests/test_common.py

Lines changed: 53 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1733,67 +1733,87 @@ def test_metrics_pos_label_error_str(metric, y_pred_threshold, dtype_y_str):
17331733

17341734

17351735
def check_array_api_metric(
1736-
metric, array_namespace, device, dtype, y_true_np, y_pred_np, sample_weight=None
1736+
metric, array_namespace, device, dtype_name, y_true_np, y_pred_np, sample_weight
17371737
):
17381738
xp = _array_api_for_tests(array_namespace, device)
1739+
17391740
y_true_xp = xp.asarray(y_true_np, device=device)
17401741
y_pred_xp = xp.asarray(y_pred_np, device=device)
17411742

17421743
metric_np = metric(y_true_np, y_pred_np, sample_weight=sample_weight)
17431744

1745+
if sample_weight is not None:
1746+
sample_weight = xp.asarray(sample_weight, device=device)
1747+
17441748
with config_context(array_api_dispatch=True):
1745-
if sample_weight is not None:
1746-
sample_weight = xp.asarray(sample_weight, device=device)
17471749
metric_xp = metric(y_true_xp, y_pred_xp, sample_weight=sample_weight)
17481750

17491751
assert_allclose(
17501752
metric_xp,
17511753
metric_np,
1752-
atol=_atol_for_type(dtype),
1754+
atol=_atol_for_type(dtype_name),
17531755
)
17541756

17551757

17561758
def check_array_api_binary_classification_metric(
1757-
metric, array_namespace, device, dtype
1759+
metric, array_namespace, device, dtype_name
17581760
):
17591761
y_true_np = np.array([0, 0, 1, 1])
17601762
y_pred_np = np.array([0, 1, 0, 1])
1763+
17611764
check_array_api_metric(
1762-
metric, array_namespace, device, dtype, y_true_np=y_true_np, y_pred_np=y_pred_np
1765+
metric,
1766+
array_namespace,
1767+
device,
1768+
dtype_name,
1769+
y_true_np=y_true_np,
1770+
y_pred_np=y_pred_np,
1771+
sample_weight=None,
1772+
)
1773+
1774+
sample_weight = np.array([0.0, 0.1, 2.0, 1.0], dtype=dtype_name)
1775+
1776+
check_array_api_metric(
1777+
metric,
1778+
array_namespace,
1779+
device,
1780+
dtype_name,
1781+
y_true_np=y_true_np,
1782+
y_pred_np=y_pred_np,
1783+
sample_weight=sample_weight,
17631784
)
1764-
if "sample_weight" in signature(metric).parameters:
1765-
check_array_api_metric(
1766-
metric,
1767-
array_namespace,
1768-
device,
1769-
dtype,
1770-
y_true_np=y_true_np,
1771-
y_pred_np=y_pred_np,
1772-
sample_weight=np.array([0.0, 0.1, 2.0, 1.0]),
1773-
)
17741785

17751786

17761787
def check_array_api_multiclass_classification_metric(
1777-
metric, array_namespace, device, dtype
1788+
metric, array_namespace, device, dtype_name
17781789
):
17791790
y_true_np = np.array([0, 1, 2, 3])
17801791
y_pred_np = np.array([0, 1, 0, 2])
1792+
17811793
check_array_api_metric(
1782-
metric, array_namespace, device, dtype, y_true_np=y_true_np, y_pred_np=y_pred_np
1794+
metric,
1795+
array_namespace,
1796+
device,
1797+
dtype_name,
1798+
y_true_np=y_true_np,
1799+
y_pred_np=y_pred_np,
180 F987 0+
sample_weight=None,
1801+
)
1802+
1803+
sample_weight = np.array([0.0, 0.1, 2.0, 1.0], dtype=dtype_name)
1804+
1805+
check_array_api_metric(
1806+
metric,
1807+
array_namespace,
1808+
device,
1809+
dtype_name,
1810+
y_true_np=y_true_np,
1811+
y_pred_np=y_pred_np,
1812+
sample_weight=sample_weight,
17831813
)
1784-
if "sample_weight" in signature(metric).parameters:
1785-
check_array_api_metric(
1786-
metric,
1787-
array_namespace,
1788-
device,
1789-
dtype,
1790-
y_true_np=y_true_np,
1791-
y_pred_np=y_pred_np,
1792-
sample_weight=np.array([0.0, 0.1, 2.0, 1.0]),
1793-
)
17941814

17951815

1796-
metric_checkers = {
1816+
array_api_metric_checkers = {
17971817
accuracy_score: [
17981818
check_array_api_binary_classification_metric,
17991819
check_array_api_multiclass_classification_metric,
@@ -1805,15 +1825,15 @@ def check_array_api_multiclass_classification_metric(
18051825
}
18061826

18071827

1808-
def yield_metric_checker_combinations(metric_checkers=metric_checkers):
1828+
def yield_metric_checker_combinations(metric_checkers=array_api_metric_checkers):
18091829
for metric, checkers in metric_checkers.items():
18101830
for checker in checkers:
18111831
yield metric, checker
18121832

18131833

18141834
@pytest.mark.parametrize(
1815-
"array_namespace, device, dtype", yield_namespace_device_dtype_combinations()
1835+
"array_namespace, device, dtype_name", yield_namespace_device_dtype_combinations()
18161836
)
18171837
@pytest.mark.parametrize("metric, check_func", yield_metric_checker_combinations())
1818-
def test_array_api_compliance(metric, array_namespace, device, dtype, check_func):
1819-
check_func(metric, array_namespace, device, dtype)
1838+
def test_array_api_compliance(metric, array_namespace, device, dtype_name, check_func):
1839+
check_func(metric, array_namespace, device, dtype_name)

sklearn/model_selection/tests/test_split.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1267,7 +1267,7 @@ def test_train_test_split_default_test_size(train_size, exp_train, exp_test):
12671267

12681268

12691269
@pytest.mark.parametrize(
1270-
"array_namespace, device, dtype", yield_namespace_device_dtype_combinations()
1270+
"array_namespace, device, dtype_name", yield_namespace_device_dtype_combinations()
12711271
)
12721272
@pytest.mark.parametrize(
12731273
"shuffle,stratify",
@@ -1278,16 +1278,18 @@ def test_train_test_split_default_test_size(train_size, exp_train, exp_test):
12781278
(False, None),
12791279
),
12801280
)
1281-
def test_array_api_train_test_split(shuffle, stratify, array_namespace, device, dtype):
1281+
def test_array_api_train_test_split(
1282+
shuffle, stratify, array_namespace, device, dtype_name
1283+
):
12821284
xp = _array_api_for_tests(array_namespace, device)
12831285

12841286
X = np.arange(100).reshape((10, 10))
12851287
y = np.arange(10)
12861288

1287-
X_np = X.astype(dtype)
1289+
X_np = X.astype(dtype_name)
12881290
X_xp = xp.asarray(X_np, device=device)
12891291

1290-
y_np = y.astype(dtype)
1292+
y_np = y.astype(dtype_name)
12911293
y_xp = xp.asarray(y_np, device=device)
12921294

12931295
X_train_np, X_test_np, y_train_np, y_test_np = train_test_split(

sklearn/preprocessing/tests/test_data.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -682,7 +682,7 @@ def test_standard_check_array_of_inverse_transform():
682682

683683

684684
@pytest.mark.parametrize(
685-
"array_namespace, device, dtype", yield_namespace_device_dtype_combinations()
685+
"array_namespace, device, dtype_name", yield_namespace_device_dtype_combinations()
686686
)
687687
@pytest.mark.parametrize(
688688
"check",
@@ -701,9 +701,11 @@ def test_standard_check_array_of_inverse_transform():
701701
],
702702
ids=_get_check_estimator_ids,
703703
)
704-
def test_scaler_array_api_compliance(estimator, check, array_namespace, device, dtype):
704+
def test_scaler_array_api_compliance(
705+
estimator, check, array_namespace, device, dtype_name
706+
):
705707
name = estimator.__class__.__name__
706-
check(name, estimator, array_namespace, device=device, dtype=dtype)
708+
check(name, estimator, array_namespace, device=device, dtype_name=dtype_name)
707709

708710

709711
def test_min_max_scaler_iris():

sklearn/utils/_array_api.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def yield_namespace_device_dtype_combinations():
2424
The name of the device on which to allocate the arrays. Can be None to
2525
indicate that the default value should be used.
2626
27-
dtype : str
27+
dtype_name : str
2828
The name of the data type to use for arrays. Can be None to indicate
2929
that the default value should be used.
3030
"""
@@ -444,7 +444,9 @@ def _weighted_sum(sample_score, sample_weight, normalize=False, xp=None):
444444
sample_score = xp.astype(xp.asarray(sample_score, device="cpu"), xp.float64)
445445

446446
if sample_weight is not None:
447-
sample_weight = xp.asarray(sample_weight, dtype=sample_score.dtype)
447+
sample_weight = xp.asarray(
448+
sample_weight, dtype=sample_score.dtype, device=device(sample_score)
449+
)
448450
if not xp.isdtype(sample_weight.dtype, "real floating"):
449451
sample_weight = xp.astype(sample_weight, xp.float64)
450452

sklearn/utils/estimator_checks.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -311,11 +311,15 @@ def _yield_outliers_checks(estimator):
311311

312312

313313
def _yield_array_api_checks(estimator):
314-
for array_namespace, device, dtype in yield_namespace_device_dtype_combinations():
314+
for (
315+
array_namespace,
316+
device,
317+
dtype_name,
318+
) in yield_namespace_device_dtype_combinations():
315319
yield partial(
316320
check_array_api_input,
317321
array_namespace=array_namespace,
318-
dtype=dtype,
322+
dtype_name=dtype_name,
319323
device=device,
320324
)
321325

@@ -864,7 +868,7 @@ def check_array_api_input(
864868
estimator_orig,
865869
array_namespace,
866870
device=None,
867-
dtype="float64",
871+
dtype_name="float64",
868872
check_values=False,
869873
):
870874
"""Check that the estimator can work consistently with the Array API
@@ -878,7 +882,7 @@ def check_array_api_input(
878882
xp = _array_api_for_tests(array_namespace, device)
879883

880884
X, y = make_classification(random_state=42)
881-
X = X.astype(dtype, copy=False)
885+
X = X.astype(dtype_name, copy=False)
882886

883887
X = _enforce_estimator_tags_X(estimator_orig, X)
884888
y = _enforce_estimator_tags_y(estimator_orig, y)
@@ -1007,14 +1011,14 @@ def check_array_api_input_and_values(
10071011
estimator_orig,
10081012
array_namespace,
10091013
device=None,
1010-
dtype="float64",
1014+
dtype_name="float64",
10111015
):
10121016
return check_array_api_input(
10131017
name,
10141018
estimator_orig,
10151019
array_namespace=array_namespace,
10161020
device=device,
1017-
dtype=dtype,
1021+
dtype_name=dtype_name,
10181022
check_values=True,
10191023
)
10201024

sklearn/utils/tests/test_array_api.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ def test_asarray_with_order_ignored():
129129

130130

131131
@pytest.mark.parametrize(
132-
"array_namespace, device, dtype", yield_namespace_device_dtype_combinations()
132+
"array_namespace, device, dtype_name", yield_namespace_device_dtype_combinations()
133133
)
134134
@pytest.mark.parametrize(
135135
"sample_weight, normalize, expected",
@@ -143,20 +143,20 @@ def test_asarray_with_order_ignored():
143143
],
144144
)
145145
def test_weighted_sum(
146-
array_namespace, device, dtype, sample_weight, normalize, expected
146+
array_namespace, device, dtype_name, sample_weight, normalize, expected
147147
):
148148
xp = _array_api_for_tests(array_namespace, device)
149-
sample_score = numpy.asarray([1, 2, 3, 4], dtype=dtype)
149+
sample_score = numpy.asarray([1, 2, 3, 4], dtype=dtype_name)
150150
sample_score = xp.asarray(sample_score, device=device)
151151
if sample_weight is not None:
152-
sample_weight = numpy.asarray(sample_weight, dtype=dtype)
152+
sample_weight = numpy.asarray(sample_weight, dtype=dtype_name)
153153
sample_weight = xp.asarray(sample_weight, device=device)
154154

155155
with config_context(array_api_dispatch=True):
156156
result = _weighted_sum(sample_score, sample_weight, normalize)
157157

158158
assert isinstance(result, float)
159-
assert_allclose(result, expected, atol=_atol_for_type(dtype))
159+
assert_allclose(result, expected, atol=_atol_for_type(dtype_name))
160160

161161

162162
@skip_if_array_api_compat_not_configured

sklearn/utils/tests/test_multiclass.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -379,17 +379,17 @@ def test_is_multilabel():
379379

380380

381381
@pytest.mark.parametrize(
382-
"array_namespace, device, dtype",
382+
"array_namespace, device, dtype_name",
383383
yield_namespace_device_dtype_combinations(),
384384
)
385-
def test_is_multilabel_array_api_compliance(array_namespace, device, dtype):
385+
def test_is_multilabel_array_api_compliance(array_namespace, device, dtype_name):
386386
xp = _array_api_for_tests(array_namespace, device)
387387

388388
for group, group_examples in ARRAY_API_EXAMPLES.items():
389389
dense_exp = group == "multilabel-indicator"
390390
for example in group_examples:
391391
if np.asarray(example).dtype.kind == "f":
392-
example = np.asarray(example, dtype=dtype)
392+
example = np.asarray(example, dtype=dtype_name)
393393
else:
394394
example = np.asarray(example)
395395
example = xp.asarray(example, device=device)

0 commit comments

Comments
 (0)
0