8000 FIX Split data using _safe_split in _permutaion_test_score (#5697) · paulha/scikit-learn@7d3918d · GitHub
[go: up one dir, main page]

Skip to content

Commit 7d3918d

Browse files
Stijn Tonkpaulha
authored andcommitted
FIX Split data using _safe_split in _permutaion_test_score (scikit-learn#5697)
Squashed commits: [94fd9f4] split data using _safe_split in _permutaion_test_scorer [522053b] adding test case test_permutation_test_score_pandas() to check if permutation_test_score plays nice with pandas dataframe/series [21b23ce] running test_permutation_test_score_pandas on iris data to prevent warnings. [15a48bf] adding safe_indexing to _shuffle function [9ea5c9e] adding test case test_permutation_test_score_pandas() to check if permutation_test_score plays nice with pandas dataframe/series [3cf5e8f] split data using _safe_split in _permutaion_test_scorer to fix error when using Pandas DataFrame/Series
1 parent 612077f commit 7d3918d

File tree

3 files changed

+29
-6
lines changed

3 files changed

+29
-6
lines changed

sklearn/cross_validation.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1756,8 +1756,10 @@ def _permutation_test_score(estimator, X, y, cv, scorer):
17561756
"""Auxiliary function for permutation_test_score"""
17571757
avg_score = []
17581758
for train, test in cv:
1759-
estimator.fit(X[train], y[train])
1760-
avg_score.append(scorer(estimator, X[test], y[test]))
1759+
X_train, y_train = _safe_split(estimator, X, y, train)
1760+
X_test, y_test = _safe_split(estimator, X, y, test, train)
1761+
estimator.fit(X_train, y_train)
1762+
avg_score.append(scorer(estimator, X_test, y_test))
17611763
return np.mean(avg_score)
17621764

17631765

@@ -1770,7 +1772,7 @@ def _shuffle(y, labels, random_state):
17701772
for label in np.unique(labels):
17711773
this_mask = (labels == label)
17721774
ind[this_mask] = random_state.permutation(ind[this_mask])
1773-
return y[ind]
1775+
return safe_indexing(y, ind)
17741776

17751777

17761778
def check_cv(cv, X=None, y=None, classifier=False):

sklearn/model_selection/_validation.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -622,8 +622,10 @@ def _permutation_test_score(estimator, X, y, groups, cv, scorer):
622622
"""Auxiliary function for permutation_test_score"""
623623
avg_score = []
624624
for train, test in cv.split(X, y, groups):
625-
estimator.fit(X[train], y[train])
626-
avg_score.append(scorer(estimator, X[test], y[test]))
625+
X_train, y_train = _safe_split(estimator, X, y, train)
626+
X_test, y_test = _safe_split(estimator, X, y, test, train)
627+
estimator.fit(X_train, y_train)
628+
avg_score.append(scorer(estimator, X_test, y_test))
627629
return np.mean(avg_score)
628630

629631

@@ -636,7 +638,7 @@ def _shuffle(y, groups, random_state):
636638
for group in np.unique(groups):
637639
this_mask = (groups == group)
638640
indices[this_mask] = random_state.permutation(indices[this_mask])
639-
return y[indices]
641+
return safe_indexing(y, indices)
640642

641643

642644
def learning_curve(estimator, X, y, groups=None,

sklearn/model_selection/tests/test_validation.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -966,3 +966,22 @@ def test_score_memmap():
966966
break
967967
except WindowsError:
968968
sleep(1.)
969+
970+
971+
def test_permutation_test_score_pandas():
972+
# check permutation_test_score doesn't destroy pandas dataframe
973+
types = [(MockDataFrame, MockDataFrame)]
974+
try:
975+
from pandas import Series, DataFrame
976+
types.append((Series, DataFrame))
977+
except ImportError:
978+
pass
979+
for TargetType, InputFeatureType in types:
980+
# X dataframe, y series
981+
iris = load_iris()
982+
X, y = iris.data, iris.target
983+
X_df, y_ser = InputFeatureType(X), TargetType(y)
984+
check_df = lambda x: isinstance(x, InputFeatureType)
985+
check_series = lambda x: isinstance(x, TargetType)
986+
clf = CheckingClassifier(check_X=check_df, check_y=check_series)
987+
permutation_test_score(clf, X_df, y_ser)

0 commit comments

Comments
 (0)
0