@@ -426,8 +426,9 @@ def test_cross_validate():
426
426
train_r2_scores = []
427
427
test_r2_scores = []
428
428
fitted_estimators = []
429
+
429
430
for train , test in cv .split (X , y ):
430
- est = clone (reg ).fit (X [train ], y [train ])
431
+ est = clone (est ).fit (X [train ], y [train ])
431
432
train_mse_scores .append (mse_scorer (est , X [train ], y [train ]))
432
433
train_r2_scores .append (r2_scorer (est , X [train ], y [train ]))
433
434
test_mse_scores .append (mse_scorer (est , X [test ], y [test ]))
@@ -448,11 +449,14 @@ def test_cross_validate():
448
449
fitted_estimators ,
449
450
)
450
451
451
- check_cross_validate_single_metric (est , X , y , scores )
452
- check_cross_validate_multi_metric (est , X , y , scores )
452
+ # To ensure that the test does not suffer from
453
+ # large statistical fluctuations due to slicing small datasets,
454
+ # we pass the cross-validation instance
455
+ check_cross_validate_single_metric (est , X , y , scores , cv )
456
+ check_cross_validate_multi_metric (est , X , y , scores , cv )
453
457
454
458
455
- def check_cross_validate_single_metric (clf , X , y , scores ):
459
+ def check_cross_validate_single_metric (clf , X , y , scores , cv ):
456
460
(
457
461
train_mse_scores ,
458
462
test_mse_scores ,
@@ -465,12 +469,22 @@ def check_cross_validate_single_metric(clf, X, y, scores):
465
469
# Single metric passed as a string
466
470
if return_train_score :
467
471
mse_scores_dict = cross_validate (
468
- clf , X , y , scoring = "neg_mean_squared_error" , return_train_score = True
472
+ clf ,
473
+ X ,
474
+ y ,
475
+ scoring = "neg_mean_squared_error" ,
476
+ return_train_score = True ,
477
+ cv = cv ,
469
478
)
470
479
assert_array_almost_equal (mse_scores_dict ["train_score" ], train_mse_scores )
471
480
else :
472
481
mse_scores_dict = cross_validate (
473
- clf , X , y , scoring = "neg_mean_squared_error" , return_train_score = False
482
+ clf ,
483
+ X ,
484
+ y ,
485
+ scoring = "neg_mean_squared_error" ,
486
+ return_train_score = False ,
487
+ cv = cv ,
474
488
)
475
489
assert isinstance (mse_scores_dict , dict )
476
490
assert len (mse_scores_dict ) == dict_len
@@ -480,27 +494,27 @@ def check_cross_validate_single_metric(clf, X, y, scores):
480
494
if return_train_score :
481
495
# It must be True by default - deprecated
482
496
r2_scores_dict = cross_validate (
483
- clf , X , y , scoring = ["r2" ], return_train_score = True
497
+ clf , X , y , scoring = ["r2" ], return_train_score = True , cv = cv
484
498
)
485
499
assert_array_almost_equal (r2_scores_dict ["train_r2" ], train_r2_scores , True )
486
500
else :
487
501
r2_scores_dict = cross_validate (
488
- clf , X , y , scoring = ["r2" ], return_train_score = False
502
+ clf , X , y , scoring = ["r2" ], return_train_score = False , cv = cv
489
503
)
490
504
assert isinstance (r2_scores_dict , dict )
491
505
assert len (r2_scores_dict ) == dict_len
492
506
assert_array_almost_equal (r2_scores_dict ["test_r2" ], test_r2_scores )
493
507
494
508
# Test return_estimator option
495
509
mse_scores_dict = cross_validate (
496
- clf , X , y , scoring = "neg_mean_squared_error" , return_estimator = True
510
+ clf , X , y , scoring = "neg_mean_squared_error" , return_estimator = True , cv = cv
497
511
)
498
512
for k , est in enumerate (mse_scores_dict ["estimator" ]):
499
513
assert_almost_equal (est .coef_ , fitted_estimators [k ].coef_ )
500
514
assert_almost_equal (est .intercept_ , fitted_estimators [k ].intercept_ )
501
515
502
516
503
- def check_cross_validate_multi_metric (clf , X , y , scores ):
517
+ def check_cross_validate_multi_metric (clf , X , y , scores , cv ):
504
518
# Test multimetric evaluation when scoring is a list / dict
505
519
(
506
520
train_mse_scores ,
@@ -541,15 +555,15 @@ def custom_scorer(clf, X, y):
541
555
if return_train_score :
542
556
# return_train_score must be True by default - deprecated
543
557
cv_results = cross_validate (
544
- clf , X , y , scoring = scoring , return_train_score = True
558
+ clf , X , y , scoring = scoring , return_train_score = True , cv = cv
545
559
)
546
560
assert_array_almost_equal (cv_results ["train_r2" ], train_r2_scores )
547
561
assert_array_almost_equal (
548
562
cv_results ["train_neg_mean_squared_error" ], train_mse_scores
549
563
)
550
564
else :
551
565
cv_results = cross_validate (
552
- clf , X , y , scoring = scoring , return_train_score = False
566
+ clf , X , y , scoring = scoring , return_train_score = False , cv = cv
553
567
)
554
568
assert isinstance (cv_results , dict )
555
569
assert set (cv_results .keys ()) == (
0 commit comments