diff --git a/sklearn/cross_validation.py b/sklearn/cross_validation.py
index fa7c7f210bc05..b53de5d51be98 100644
--- a/sklearn/cross_validation.py
+++ b/sklearn/cross_validation.py
@@ -1042,14 +1042,23 @@ def cross_val_predict(estimator, X, y=None, cv=None, n_jobs=1,
                                                       train, test, verbose,
                                                       fit_params)
                             for train, test in cv)
-    p = np.concatenate([p for p, _ in preds_blocks])
+
+    preds = [p for p, _ in preds_blocks]
     locs = np.concatenate([loc for _, loc in preds_blocks])
     if not _check_is_partition(locs, _num_samples(X)):
         raise ValueError('cross_val_predict only works for partitions')
-    preds = p.copy()
-    preds[locs] = p
-    return preds
+    # Build the inverse permutation of `locs` so the stacked predictions can
+    # be restored to input order by indexing instead of item assignment.
+    inv_locs = np.empty(len(locs), dtype=int)
+    inv_locs[locs] = np.arange(len(locs))
+
+    # Check for sparse predictions
+    if sp.issparse(preds[0]):
+        preds = sp.vstack(preds, format=preds[0].format)
+    else:
+        preds = np.concatenate(preds)
+    return preds[inv_locs]
 
 
 def _fit_and_predict(estimator, X, y, train, test, verbose, fit_params):
     """Fit estimator and predict values for a given dataset split.
 
diff --git a/sklearn/tests/test_cross_validation.py b/sklearn/tests/test_cross_validation.py
index b33e2b4c279d5..5db7c2192d312 100644
--- a/sklearn/tests/test_cross_validation.py
+++ b/sklearn/tests/test_cross_validation.py
@@ -4,6 +4,7 @@
 
 import numpy as np
 from scipy.sparse import coo_matrix
+from scipy.sparse import csr_matrix
 from scipy import stats
 
 from sklearn.utils.testing import assert_true
@@ -25,14 +26,16 @@
 from sklearn.datasets import load_boston
 from sklearn.datasets import load_digits
 from sklearn.datasets import load_iris
+from sklearn.datasets import make_multilabel_classification
 from sklearn.metrics import explained_variance_score
 from sklearn.metrics import make_scorer
 from sklearn.metrics import precision_score
 
 from sklearn.externals import six
 from sklearn.externals.six.moves import zip
 
 from sklearn.linear_model import Ridge
+from sklearn.multiclass import OneVsRestClassifier
 from sklearn.neighbors import KNeighborsClassifier
 from sklearn.svm import SVC
 from sklearn.cluster import KMeans
@@ -1094,3 +1097,19 @@ def test_check_is_partition():
 
     p[0] = 23
     assert_false(cval._check_is_partition(p, 100))
+
+
+def test_cross_val_predict_sparse_prediction():
+    # check that cross_val_predict gives same result for sparse and dense input
+    X, y = make_multilabel_classification(n_classes=2, n_labels=1,
+                                          allow_unlabeled=False,
+                                          return_indicator=True,
+                                          random_state=1)
+    X_sparse = csr_matrix(X)
+    y_sparse = csr_matrix(y)
+    classif = OneVsRestClassifier(SVC(kernel='linear'))
+    preds = cval.cross_val_predict(classif, X, y, cv=10)
+    preds_sparse = cval.cross_val_predict(classif, X_sparse, y_sparse, cv=10)
+    # densify the result of the sparse call so it can be compared element-wise
+    preds_sparse = preds_sparse.toarray()
+    assert_array_almost_equal(preds_sparse, preds)