24
24
from sklearn .utils .testing import assert_less
25
25
from sklearn .utils .testing import assert_array_almost_equal
26
26
from sklearn .utils .testing import assert_array_equal
27
+ from sklearn .utils .testing import assert_allclose
27
28
from sklearn .utils .mocking import CheckingClassifier , MockDataFrame
28
29
29
30
from sklearn .model_selection import cross_val_score , ShuffleSplit
@@ -1333,8 +1334,8 @@ def check_cross_val_predict_binary(est, X, y, method):
1333
1334
1334
1335
# Check actual outputs for several representations of y
1335
1336
for tg in [y , y + 1 , y - 2 , y .astype ('str' )]:
1336
- assert_array_equal (cross_val_predict (est , X , tg , method = method , cv = cv ),
1337
- expected_predictions )
1337
+ assert_allclose (cross_val_predict (est , X , tg , method = method , cv = cv ),
1338
+ expected_predictions )
1338
1339
1339
1340
1340
1341
def check_cross_val_predict_multiclass (est , X , y , method ):
@@ -1358,8 +1359,8 @@ def check_cross_val_predict_multiclass(est, X, y, method):
1358
1359
1359
1360
# Check actual outputs for several representations of y
1360
1361
for tg in [y , y + 1 , y - 2 , y .astype ('str' )]:
1361
- assert_array_equal (cross_val_predict (est , X , tg , method = method , cv = cv ),
1362
- expected_predictions )
1362
+ assert_allclose (cross_val_predict (est , X , tg , method = method , cv = cv ),
1363
+ expected_predictions )
1363
1364
1364
1365
1365
1366
def check_cross_val_predict_multilabel (est , X , y , method ):
@@ -1406,7 +1407,7 @@ def check_cross_val_predict_multilabel(est, X, y, method):
1406
1407
cv_predict_output = cross_val_predict (est , X , tg , method = method , cv = cv )
1407
1408
assert_equal (len (cv_predict_output ), len (expected_preds ))
1408
1409
for i in range (len (cv_predict_output )):
1409
- assert_array_equal (cv_predict_output [i ], expected_preds [i ])
1410
+ assert_allclose (cv_predict_output [i ], expected_preds [i ])
1410
1411
1411
1412
1412
1413
def check_cross_val_predict_with_method_binary (est ):
0 commit comments