8000 split data using _safe_split in _permutaion_test_scorer · scikit-learn/scikit-learn@94fd9f4 · GitHub
[go: up one dir, main page]

Skip to content

Commit 94fd9f4

Browse files
Stijn TonkStijn Tonk
Stijn Tonk
authored and
Stijn Tonk
committed
split data using _safe_split in _permutaion_test_scorer
1 parent 522053b commit 94fd9f4

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

sklearn/model_selection/_validation.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -622,8 +622,10 @@ def _permutation_test_score(estimator, X, y, groups, cv, scorer):
622622
"""Auxiliary function for permutation_test_score"""
623623
avg_score = []
624624
for train, test in cv.split(X, y, groups):
625-
estimator.fit(X[train], y[train])
626-
avg_score.append(scorer(estimator, X[test], y[test]))
625+
X_train, y_train = _safe_split(estimator, X, y, train)
626+
X_test, y_test = _safe_split(estimator, X, y, test, train)
627+
estimator.fit(X_train, y_train)
628+
avg_score.append(scorer(estimator, X_test, y_test))
627629
return np.mean(avg_score)
628630

629631

0 commit comments

Comments
 (0)
0