8000 split data using _safe_split in _permutaion_test_scorer to fix error… · scikit-learn/scikit-learn@3cf5e8f · GitHub
[go: up one dir, main page]

Skip to content

Commit 3cf5e8f

Browse files
author
Stijn Tonk
committed
split data using _safe_split in _permutaion_test_scorer to fix error when using Pandas DataFrame/Series
1 parent d97d13e commit 3cf5e8f

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

sklearn/cross_validation.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1756,8 +1756,10 @@ def _permutation_test_score(estimator, X, y, cv, scorer):
17561756
"""Auxiliary function for permutation_test_score"""
17571757
avg_score = []
17581758
for train, test in cv:
1759-
estimator.fit(X[train], y[train])
1760-
avg_score.append(scorer(estimator, X[test], y[test]))
1759+
X_train, y_train = _safe_split(estimator, X, y, train)
1760+
X_test, y_test = _safe_split(estimator, X, y, test, train)
1761+
estimator.fit(X_train, y_train)
1762+
avg_score.append(scorer(estimator, X_test, y_test))
17611763
return np.mean(avg_score)
17621764

17631765

0 commit comments

Comments
 (0)
0