8000 TST use approximate equality for float comparison (#13749) · scikit-learn/scikit-learn@cc0179a · GitHub
[go: up one dir, main page]

Skip to content

Commit cc0179a

Browse files
authored
TST use approximate equality for float comparison (#13749)
1 parent e7bd8a3 commit cc0179a

File tree

1 file changed

+6
-5
lines changed

1 file changed

+6
-5
lines changed

sklearn/model_selection/tests/test_validation.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from sklearn.utils.testing import assert_less
2525
from sklearn.utils.testing import assert_array_almost_equal
2626
from sklearn.utils.testing import assert_array_equal
27+
from sklearn.utils.testing import assert_allclose
2728
from sklearn.utils.mocking import CheckingClassifier, MockDataFrame
2829

2930
from sklearn.model_selection import cross_val_score, ShuffleSplit
@@ -1333,8 +1334,8 @@ def check_cross_val_predict_binary(est, X, y, method):
13331334

13341335
# Check actual outputs for several representations of y
13351336
for tg in [y, y + 1, y - 2, y.astype('str')]:
1336-
assert_array_equal(cross_val_predict(est, X, tg, method=method, cv=cv),
1337-
expected_predictions)
1337+
assert_allclose(cross_val_predict(est, X, tg, method=method, cv=cv),
1338+
expected_predictions)
13381339

13391340

13401341
def check_cross_val_predict_multiclass(est, X, y, method):
@@ -1358,8 +1359,8 @@ def check_cross_val_predict_multiclass(est, X, y, method):
13581359

13591360
# Check actual outputs for several representations of y
13601361
for tg in [y, y + 1, y - 2, y.astype('str')]:
1361-
assert_array_equal(cross_val_predict(est, X, tg, method=method, cv=cv),
1362-
expected_predictions)
1362+
assert_allclose(cross_val_predict(est, X, tg, method=method, cv=cv),
1363+
expected_predictions)
13631364

13641365

13651366
def check_cross_val_predict_multilabel(est, X, y, method):
@@ -1406,7 +1407,7 @@ def check_cross_val_predict_multilabel(est, X, y, method):
14061407
cv_predict_output = cross_val_predict(est, X, tg, method=method, cv=cv)
14071408
assert_equal(len(cv_predict_output), len(expected_preds))
14081409
for i in range(len(cv_predict_output)):
1409-
assert_array_equal(cv_predict_output[i], expected_preds[i])
1410+
assert_allclose(cv_predict_output[i], expected_preds[i])
14101411

14111412

14121413
def check_cross_val_predict_with_method_binary(est):

0 commit comments

Comments
 (0)
0