8000 FIX fix faulty test in `cross_validate` that used the wrong estimator… · dolfly/scikit-learn@b2fe974 · GitHub
[go: up one dir, main page]

Skip to content

Commit b2fe974

Browse files
FIX fix faulty test in cross_validate that used the wrong estimator (scikit-learn#25456)
1 parent 9b53739 commit b2fe974

File tree

1 file changed

+26
-12
lines changed

1 file changed

+26
-12
lines changed

sklearn/model_selection/tests/test_validation.py

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -426,8 +426,9 @@ def test_cross_validate():
426426
train_r2_scores = []
427427
test_r2_scores = []
428428
fitted_estimators = []
429+
429430
for train, test in cv.split(X, y):
430-
est = clone(reg).fit(X[train], y[train])
431+
est = clone(est).fit(X[train], y[train])
431432
train_mse_scores.append(mse_scorer(est, X[train], y[train]))
432433
train_r2_scores.append(r2_scorer(est, X[train], y[train]))
433434
test_mse_scores.append(mse_scorer(est, X[test], y[test]))
@@ -448,11 +449,14 @@ def test_cross_validate():
448449
fitted_estimators,
449450
)
450451

451-
check_cross_validate_single_metric(est, X, y, scores)
452-
check_cross_validate_multi_metric(est, X, y, scores)
452+
# To ensure that the test does not suffer from
453+
# large statistical fluctuations due to slicing small datasets,
454+
# we pass the cross-validation instance
455+
check_cross_validate_single_metric(est, X, y, scores, cv)
456+
check_cross_validate_multi_metric(est, X, y, scores, cv)
453457

454458

455-
def check_cross_validate_single_metric(clf, X, y, scores):
459+
def check_cross_validate_single_metric(clf, X, y, scores, cv):
456460
(
457461
train_mse_scores,
458462
test_mse_scores,
@@ -465,12 +469,22 @@ def check_cross_validate_single_metric(clf, X, y, scores):
465469
# Single metric passed as a string
466470
if return_train_score:
467471
mse_scores_dict = cross_validate(
468-
clf, X, y, scoring="neg_mean_squared_error", return_train_score=True
472+
clf,
473+
X,
474+
y,
475+
scoring="neg_mean_squared_error",
476+
return_train_score=True,
477+
cv=cv,
469478
)
470479
assert_array_almost_equal(mse_scores_dict["train_score"], train_mse_scores)
471480
else:
472481
mse_scores_dict = cross_validate(
473-
clf, X, y, scoring="neg_mean_squared_error", return_train_score=False
482+
clf,
483+
X,
484+
y,
485+
scoring="neg_mean_squared_error",
486+
return_train_score=False,
487+
cv=cv,
474488
)
475489
assert isinstance(mse_scores_dict, dict)
476490
assert len(mse_scores_dict) == dict_len
@@ -480,27 +494,27 @@ def check_cross_validate_single_metric(clf, X, y, scores):
480494
if return_train_score:
481495
# It must be True by default - deprecated
482496
r2_scores_dict = cross_validate(
483-
clf, X, y, scoring=["r2"], return_train_score=True
497+
clf, X, y, scoring=["r2"], return_train_score=True, cv=cv
484498
)
485499
assert_array_almost_equal(r2_scores_dict["train_r2"], train_r2_scores, True)
486500
else:
487501
r2_scores_dict = cross_validate(
488-
clf, X, y, scoring=["r2"], return_train_score=False
502+
clf, X, y, scoring=["r2"], return_train_score=False, cv=cv
489503
)
490504
assert isinstance(r2_scores_dict, dict)
491505
assert len(r2_scores_dict) == dict_len
492506
assert_array_almost_equal(r2_scores_dict["test_r2"], test_r2_scores)
493507

494508
# Test return_estimator option
495509
mse_scores_dict = cross_validate(
496-
clf, X, y, scoring="neg_mean_squared_error", return_estimator=True
510+
clf, X, y, scoring="neg_mean_squared_error", return_estimator=True, cv=cv
497511
)
498512
for k, est in enumerate(mse_scores_dict["estimator"]):
499513
assert_almost_equal(est.coef_, fitted_estimators[k].coef_)
500514
assert_almost_equal(est.intercept_, fitted_estimators[k].intercept_)
501515

502516

503-
def check_cross_validate_multi_metric(clf, X, y, scores):
517+
def check_cross_validate_multi_metric(clf, X, y, scores, cv):
504518
# Test multimetric evaluation when scoring is a list / dict
505519
(
506520
train_mse_scores,
@@ -541,15 +555,15 @@ def custom_scorer(clf, X, y):
541555
if return_train_score:
542556
# return_train_score must be True by default - deprecated
543557
cv_results = cross_validate(
544-
clf, X, y, scoring=scoring, return_train_score=True
558+
clf, X, y, scoring=scoring, return_train_score=True, cv=cv
545559
)
546560
assert_array_almost_equal(cv_results["train_r2"], train_r2_scores)
547561
assert_array_almost_equal(
548562
cv_results["train_neg_mean_squared_error"], train_mse_scores
549563
)
550564
else:
551565
cv_results = cross_validate(
552-
clf, X, y, scoring=scoring, return_train_score=False
566+
clf, X, y, scoring=scoring, return_train_score=False, cv=cv
553567
)
554568
assert isinstance(cv_results, dict)
555569
assert set(cv_results.keys()) == (

0 commit comments

Comments
 (0)
0