@@ -648,36 +648,11 @@ def test_stratified_kfold_preserve_order(): # see #2372
648
648
649
649
650
650
def test_stratified_kfold_preserve_order_with_digits (): # see #2372
651
- # The digits samples are dependent as they are apparently grouped
652
- # by authors although we don't have any information on the groups
653
- # segment locations for this data. We can highlight this fact be
654
- # computing k-fold cross-validation with and without shuffling: we
655
- # observer that the shuffling case makes the IID assumption and is
656
- # therefore too optimistic: it estimates a much higher accuracy
657
- # (around 0.965) than than the non shuffling variant (around
658
- # 0.905).
659
-
651
+ # A regression test, taken from
652
+ # http://nbviewer.ipython.org/urls/raw.github.com/ogrisel/notebooks/master/Non%2520IID%2520cross-validation.ipynb
660
653
digits = load_digits ()
661
- X , y = digits .data [:800 ], digits .target [:800 ]
662
- model = SVC (C = 10 , gamma = 0.005 )
663
- n = len (y )
664
-
665
- cv = cval .KFold (n , 5 , shuffle = False )
666
- assert_greater (0.91 , cval .cross_val_score (model , X , y , cv = cv ).mean ())
667
-
668
- cv = cval .KFold (n , 5 , shuffle = True , random_state = 0 )
669
- assert_greater (cval .cross_val_score (model , X , y , cv = cv ).mean (), 0.95 )
670
-
671
- cv = cval .KFold (n , 5 , shuffle = True , random_state = 1 )
672
- assert_greater (cval .cross_val_score (model , X , y , cv = cv ).mean (), 0.95 )
673
-
674
- cv = cval .KFold (n , 5 , shuffle = True , random_state = 2 )
675
- assert_greater (cval .cross_val_score (model , X , y , cv = cv ).mean (), 0.95 )
676
-
677
- # Similarly, StratifiedKFold should try to shuffle the data as few
678
- # as possible (while respecting the balanced class constraints)
679
- # and thus be able to detect the dependency by not overestimating
680
- # the CV score either:
654
+ X , y = digits .data , digits .target
681
655
656
+ model = SVC (C = 10 , gamma = 0.005 )
682
657
cv = cval .StratifiedKFold (y , 5 )
683
- assert_greater ( 0.91 , cval .cross_val_score (model , X , y , cv = cv ).mean ())
658
+ assert cval .cross_val_score (model , X , y , cv = cv , n_jobs = - 1 ).mean () < 0.91
0 commit comments