8000 ENH/MNT Use _fit_and_score instead of _permutation_test_score · scikit-learn/scikit-learn@8cfebf2 · GitHub
[go: up one dir, main page]

Skip to content

Commit 8cfebf2

Browse files
committed
ENH/MNT Use _fit_and_score instead of _permutation_test_score
1 parent b18b161 commit 8cfebf2

File tree

1 file changed

+15
-16
lines changed

1 file changed

+15
-16
lines changed

sklearn/model_selection/_validation.py

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import warnings
1616
import numbers
1717
import time
18+
from itertools import chain
1819

1920
import numpy as np
2021
import scipy.sparse as sp
@@ -592,29 +593,27 @@ def permutation_test_score(estimator, X, y, groups=None, cv=None,
592593

593594
# We clone the estimator to make sure that all the folds are
594595
# independent, and that it is pickle-able.
595-
score = _permutation_test_score(clone(estimator), X, y, groups, cv, scorer)
596-
permutation_scores = Parallel(n_jobs=n_jobs, verbose=verbose)(
597-
delayed(_permutation_test_score)(
598-
clone(estimator), X, _shuffle(y, groups, random_state),
599-
groups, cv, scorer)
600-
for _ in range(n_permutations))
601-
permutation_scores = np.array(permutation_scores)
596+
score = cross_val_score(clone(estimator), X, y, groups=groups, cv=cv,
597+
scoring=scorer).mean()
598+
y_shuflled = (_shuffle(y=y, groups=groups, random_state=random_state)
599+
for _ in range(n_permutations))
600+
jobs = ((delayed(_fit_and_score)(clone(estimator), X, y_i, scorer, train, test,
601+
verbose, parameters=None, fit_params=None)
602+
for train, test in cv.split(X, y_i, groups=groups))
603+
for y_i in y_shuflled)
604+
out = Parallel(n_jobs=n_jobs, verbose=verbose)(chain.from_iterable(jobs))
605+
permutation_scores = zip(*out)[0]
606+
n_splits = len(permutation_scores) // n_permutations
607+
permutation_scores = np.array(permutation_scores).reshape(n_splits,
608+
n_permutations)
609+
permutation_scores = permutation_scores.mean(axis=0)
602610
pvalue = (np.sum(permutation_scores >= score) + 1.0) / (n_permutations + 1)
603611
return score, permutation_scores, pvalue
604612

605613

606614
permutation_test_score.__test__ = False # to avoid a pb with nosetests
607615

608616

609-
def _permutation_test_score(estimator, X, y, groups, cv, scorer):
610-
"""Auxiliary function for permutation_test_score"""
611-
avg_score = []
612-
for train, test in cv.split(X, y, groups):
613-
estimator.fit(X[train], y[train])
614-
avg_score.append(scorer(estimator, X[test], y[test]))
615-
return np.mean(avg_score)
616-
617-
618617
def _shuffle(y, groups, random_state):
619618
"""Return a shuffled copy of y eventually shuffle among same groups."""
620619
if groups is None:

0 commit comments

Comments
 (0)
0