8000 [MRG + 1] Fix the cross_val_predict function for method='predict_prob… · scikit-learn/scikit-learn@fd84a56 · GitHub
[go: up one dir, main page]

Skip to content

Commit fd84a56

Browse files
dalmiajnothman
authored andcommitted
[MRG + 1] Fix the cross_val_predict function for method='predict_proba' (#7889)
Handle the case where different CV splits have different sets of classes present.
1 parent 4910e11 commit fd84a56

File tree

3 files changed

+94
-1
lines changed

3 files changed

+94
-1
lines changed

doc/whats_new.rst

+4
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,10 @@ Enhancements
110110
attributes, ``n_skips_*``.
111111
:issue:`7914` by :user:`Michael Horrell <mthorrell>`.
112112

113+
- :func:`model_selection.cross_val_predict` now returns output of the
114+
correct shape for all values of the argument ``method``.
115+
:issue:`7863` by :user:`Aman Dalmia <dalmia>`.
116+
113117
- Fix a bug where :class:`sklearn.feature_selection.SelectFdr` did not
114118
exactly implement Benjamini-Hochberg procedure. It formerly may have
115119
selected fewer features than it should.

sklearn/model_selection/_validation.py

+16-1
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from ..metrics.scorer import check_scoring
2929
from ..exceptions import FitFailedWarning
3030
from ._split import check_cv
31+
from ..preprocessing import LabelEncoder
3132

3233
__all__ = ['cross_val_score', 'cross_val_predict', 'permutation_test_score',
3334
'learning_curve', 'validation_curve']
@@ -364,7 +365,9 @@ def cross_val_predict(estimator, X, y=None, groups=None, cv=None, n_jobs=1,
364365
as in '2*n_jobs'
365366
366367
method : string, optional, default: 'predict'
367-
Invokes the passed method name of the passed estimator.
368+
Invokes the passed method name of the passed estimator. For
369+
method='predict_proba', the columns correspond to the classes
370+
in sorted order.
368371
369372
Returns
370373
-------
@@ -390,6 +393,10 @@ def cross_val_predict(estimator, X, y=None, groups=None, cv=None, n_jobs=1,
390393
raise AttributeError('{} not implemented in estimator'
391394
.format(method))
392395

396+
if method in ['decision_function', 'predict_proba', 'predict_log_proba']:
397+
le = LabelEncoder()
398+
y = le.fit_transform(y)
399+
393400
# We clone the estimator to make sure that all the folds are
394401
# independent, and that it is pickle-able.
395402
parallel = Parallel(n_jobs=n_jobs, verbose=verbose,
@@ -472,6 +479,14 @@ def _fit_and_predict(estimator, X, y, train, test, verbose, fit_params,
472479
estimator.fit(X_train, y_train, **fit_params)
473480
func = getattr(estimator, method)
474481
predictions = func(X_test)
482+
if method in ['decision_function', 'predict_proba', 'predict_log_proba']:
483+
n_classes = len(set(y))
484+
predictions_ = np.zeros((X_test.shape[0], n_classes))
485+
if method == 'decision_function' and len(estimator.classes_) == 2:
486+
predictions_[:, estimator.classes_[-1]] = predictions
487+
else:
488+
predictions_[:, estimator.classes_] = predictions
489+
predictions = predictions_
475490
return predictions, test
476491

477492

sklearn/model_selection/tests/test_validation.py

+74
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
from sklearn.cluster import KMeans
5252

5353
from sklearn.preprocessing import Imputer
54+
from sklearn.preprocessing import LabelEncoder
5455
from sklearn.pipeline import Pipeline
5556

5657
from sklearn.externals.six.moves import cStringIO as StringIO
@@ -940,6 +941,79 @@ def test_cross_val_predict_with_method():
940941
cv=kfold)
941942
assert_array_almost_equal(expected_predictions, predictions)
942943

944+
# Test alternative representations of y
945+
predictions_y1 = cross_val_predict(est, X, y + 1, method=method,
946+
cv=kfold)
947+
assert_array_equal(predictions, predictions_y1)
948+
949+
predictions_y2 = cross_val_predict(est, X, y - 2, method=method,
950+
cv=kfold)
951+
assert_array_equal(predictions, predictions_y2)
952+
953+
predictions_ystr = cross_val_predict(est, X, y.astype('str'),
954+
method=method, cv=kfold)
955+
assert_array_equal(predictions, predictions_ystr)
956+
957+
958+
def get_expected_predictions(X, y, cv, classes, est, method):
959+
960+
expected_predictions = np.zeros([len(y), classes])
961+
func = getattr(est, method)
962+
963+
for train, test in cv.split(X, y):
964+
est.fit(X[train], y[train])
965+
expected_predictions_ = func(X[test])
966+
# To avoid 2 dimensional indexing
967+
exp_pred_test = np.zeros((len(test), classes))
968+
if method is 'decision_function' and len(est.classes_) == 2:
969+
exp_pred_test[:, est.classes_[-1]] = expected_predictions_
970+
else:
971+
exp_pred_test[:, est.classes_] = expected_predictions_
972+
expected_predictions[test] = exp_pred_test
973+
974+
return expected_predictions
975+
976+
977+
def test_cross_val_predict_class_subset():
978+
979+
X = np.arange(8).reshape(4, 2)
980+
y = np.array([0, 0, 1, 2])
981+
classes = 3
982+
983+
kfold3 = KFold(n_splits=3)
984+
kfold4 = KFold(n_splits=4)
985+
986+
le = LabelEncoder()
987+
988+
methods = ['decision_function', 'predict_proba', 'predict_log_proba']
989+
for method in methods:
990+
est = LogisticRegression()
991+
992+
# Test with n_splits=3
993+
predictions = cross_val_predict(est, X, y, method=method,
994+
cv=kfold3)
995+
996+
# Runs a naive loop (should be same as cross_val_predict):
997+
expected_predictions = get_expected_predictions(X, y, kfold3, classes,
998+
est, method)
999+
assert_array_almost_equal(expected_predictions, predictions)
1000+
1001+
# Test with n_splits=4
1002+
predictions = cross_val_predict(est, X, y, method=method,
1003+
cv=kfold4)
1004+
expected_predictions = get_expected_predictions(X, y, kfold4, classes,
1005+
est, method)
1006+
assert_array_almost_equal(expected_predictions, predictions)
1007+
1008+
# Testing unordered labels
1009+
y = [1, 1, -4, 6]
1010+
predictions = cross_val_predict(est, X, y, method=method,
1011+
cv=kfold3)
1012+
y = le.fit_transform(y)
1013+
expected_predictions = get_expected_predictions(X, y, kfold3, classes,
1014+
est, method)
1015+
assert_array_almost_equal(expected_predictions, predictions)
1016+
9431017

9441018
def test_score_memmap():
9451019
# Ensure a scalar score of memmap type is accepted

0 commit comments

Comments
 (0)
0