Revert "ENH cross_val_predict now handles multi-output predict_proba … · xhluca/scikit-learn@c6aa3ea · GitHub

Commit c6aa3ea

Authored by Xing
Revert "ENH cross_val_predict now handles multi-output predict_proba (scikit-learn#8773)"
This reverts commit 0eaceb6.
1 parent 40bed9f commit c6aa3ea

File tree: 3 files changed, +86 -292 lines changed

doc/whats_new/v0.21.rst

Lines changed: 0 additions & 4 deletions
@@ -430,10 +430,6 @@ Support for Python 3.4 and below has been officially dropped.
   making ``shuffle=True`` ineffective.
   :issue:`13124` by :user:`Hanmin Qin <qinhanmin2014>`.

-- |Fix| Added ability for :func:`model_selection.cross_val_predict` to handle
-  multi-label (and multioutput-multiclass) targets with ``predict_proba``-type
-  methods. :issue:`8773` by :user:`Stephen Hoover <stephen-hoover>`.
-
 - |Fix| Fixed an issue in :func:`~model_selection.cross_val_predict` where
   `method="predict_proba"` returned always `0.0` when one of the classes was
   excluded in a cross-validation fold.
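
For context (not part of this commit), below is a minimal sketch of the usage described by the removed changelog entry; the MultiOutputClassifier/RandomForestClassifier setup and the toy data are illustrative assumptions. Against the parent commit, where scikit-learn#8773 was still in place, such a call was expected to return a list with one probability array per output column of y; after this revert, predict_proba-type methods no longer handle 2D targets.

from sklearn.datasets import make_multilabel_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.multioutput import MultiOutputClassifier
from sklearn.model_selection import cross_val_predict

# Illustrative multi-label problem: y is a 2D indicator matrix of shape
# (n_samples, n_outputs).
X, y = make_multilabel_classification(n_samples=100, n_classes=3,
                                      random_state=0)
clf = MultiOutputClassifier(RandomForestClassifier(n_estimators=10,
                                                   random_state=0))
# With scikit-learn#8773 in place this returned a list of per-output
# probability arrays; after this revert, only 1D targets are supported
# for predict_proba-type methods.
proba = cross_val_predict(clf, X, y, cv=3, method='predict_proba')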

sklearn/model_selection/_validation.py

Lines changed: 50 additions & 97 deletions
@@ -757,20 +757,9 @@ def cross_val_predict(estimator, X, y=None, groups=None, cv='warn',
 
     cv = check_cv(cv, y, classifier=is_classifier(estimator))
 
-    # If classification methods produce multiple columns of output,
-    # we need to manually encode classes to ensure consistent column ordering.
-    encode = method in ['decision_function', 'predict_proba',
-                        'predict_log_proba']
-    if encode:
-        y = np.asarray(y)
-        if y.ndim == 1:
-            le = LabelEncoder()
-            y = le.fit_transform(y)
-        elif y.ndim == 2:
-            y_enc = np.zeros_like(y, dtype=np.int)
-            for i_label in range(y.shape[1]):
-                y_enc[:, i_label] = LabelEncoder().fit_transform(y[:, i_label])
-            y = y_enc
+    if method in ['decision_function', 'predict_proba', 'predict_log_proba']:
+        le = LabelEncoder()
+        y = le.fit_transform(y)
 
     # We clone the estimator to make sure that all the folds are
     # independent, and that it is pickle-able.
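
As an aside (not part of the diff), a small sketch of why the targets are label-encoded for these methods: the removed comment notes that methods producing multiple output columns need a consistent column ordering, and the padding logic further down indexes columns by estimator.classes_, which only works if the classes are the integers 0..n_classes-1. LabelEncoder guarantees exactly that.

import numpy as np
from sklearn.preprocessing import LabelEncoder

y = np.array(['cat', 'dog', 'cat', 'bird'])
le = LabelEncoder()
y_enc = le.fit_transform(y)
print(y_enc)        # [1 2 1 0] -- integer codes usable as column indices
print(le.classes_)  # ['bird' 'cat' 'dog']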
@@ -791,26 +780,12 @@ def cross_val_predict(estimator, X, y=None, groups=None, cv='warn',
     inv_test_indices = np.empty(len(test_indices), dtype=int)
     inv_test_indices[test_indices] = np.arange(len(test_indices))
 
+    # Check for sparse predictions
     if sp.issparse(predictions[0]):
         predictions = sp.vstack(predictions, format=predictions[0].format)
-    elif encode and isinstance(predictions[0], list):
-        # `predictions` is a list of method outputs from each fold.
-        # If each of those is also a list, then treat this as a
-        # multioutput-multiclass task. We need to separately concatenate
-        # the method outputs for each label into an `n_labels` long list.
-        n_labels = y.shape[1]
-        concat_pred = []
-        for i_label in range(n_labels):
-            label_preds = np.concatenate([p[i_label] for p in predictions])
-            concat_pred.append(label_preds)
-        predictions = concat_pred
     else:
         predictions = np.concatenate(predictions)
-
-    if isinstance(predictions, list):
-        return [p[inv_test_indices] for p in predictions]
-    else:
-        return predictions[inv_test_indices]
+    return predictions[inv_test_indices]
 
 
 def _fit_and_predict(estimator, X, y, train, test, verbose, fit_params,
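
To make the reordering above concrete, here is a tiny standalone sketch (with made-up indices, not part of the commit) of the inverse-permutation trick used by cross_val_predict: predictions come back concatenated in fold order, and inv_test_indices maps them back to the original sample order.

import numpy as np

test_indices = np.array([2, 3, 0, 1, 4, 5])      # hypothetical fold order
predictions = np.array([20, 30, 0, 10, 40, 50])  # one prediction per test sample
inv_test_indices = np.empty(len(test_indices), dtype=int)
inv_test_indices[test_indices] = np.arange(len(test_indices))
print(predictions[inv_test_indices])  # [ 0 10 20 30 40 50]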
@@ -869,76 +844,54 @@ def _fit_and_predict(estimator, X, y, train, test, verbose, fit_params,
     func = getattr(estimator, method)
     predictions = func(X_test)
     if method in ['decision_function', 'predict_proba', 'predict_log_proba']:
-        if isinstance(predictions, list):
-            predictions = [_enforce_prediction_order(
-                estimator.classes_[i_label], predictions[i_label],
-                n_classes=len(set(y[:, i_label])), method=method)
-                for i_label in range(len(predictions))]
-        else:
-            # A 2D y array should be a binary label indicator matrix
-            n_classes = len(set(y)) if y.ndim == 1 else y.shape[1]
-            predictions = _enforce_prediction_order(
-                estimator.classes_, predictions, n_classes, method)
+        n_classes = len(set(y))
+        if n_classes != len(estimator.classes_):
+            recommendation = (
+                'To fix this, use a cross-validation '
+                'technique resulting in properly '
+                'stratified folds')
+            warnings.warn('Number of classes in training fold ({}) does '
+                          'not match total number of classes ({}). '
+                          'Results may not be appropriate for your use case. '
+                          '{}'.format(len(estimator.classes_),
+                                      n_classes, recommendation),
+                          RuntimeWarning)
+            if method == 'decision_function':
+                if (predictions.ndim == 2 and
+                        predictions.shape[1] != len(estimator.classes_)):
+                    # This handles the case when the shape of predictions
+                    # does not match the number of classes used to train
+                    # it with. This case is found when sklearn.svm.SVC is
+                    # set to `decision_function_shape='ovo'`.
+                    raise ValueError('Output shape {} of {} does not match '
+                                     'number of classes ({}) in fold. '
+                                     'Irregular decision_function outputs '
+                                     'are not currently supported by '
+                                     'cross_val_predict'.format(
+                                         predictions.shape, method,
+                                         len(estimator.classes_),
+                                         recommendation))
+                if len(estimator.classes_) <= 2:
+                    # In this special case, `predictions` contains a 1D array.
+                    raise ValueError('Only {} class/es in training fold, this '
+                                     'is not supported for decision_function '
+                                     'with imbalanced folds. {}'.format(
+                                         len(estimator.classes_),
+                                         recommendation))
+
+            float_min = np.finfo(predictions.dtype).min
+            default_values = {'decision_function': float_min,
+                              'predict_log_proba': float_min,
+                              'predict_proba': 0.0}
+            predictions_for_all_classes = np.full((_num_samples(predictions),
+                                                   n_classes),
+                                                  default_values[method],
+                                                  predictions.dtype)
+            predictions_for_all_classes[:, estimator.classes_] = predictions
+            predictions = predictions_for_all_classes
     return predictions, test
 
 
-def _enforce_prediction_order(classes, predictions, n_classes, method):
-    """Ensure that prediction arrays have correct column order
-
-    When doing cross-validation, if one or more classes are
-    not present in the subset of data used for training,
-    then the output prediction array might not have the same
-    columns as other folds. Use the list of class names
-    (assumed to be integers) to enforce the correct column order.
-
-    Note that `classes` is the list of classes in this fold
-    (a subset of the classes in the full training set)
-    and `n_classes` is the number of classes in the full training set.
-    """
-    if n_classes != len(classes):
-        recommendation = (
-            'To fix this, use a cross-validation '
-            'technique resulting in properly '
-            'stratified folds')
-        warnings.warn('Number of classes in training fold ({}) does '
-                      'not match total number of classes ({}). '
-                      'Results may not be appropriate for your use case. '
-                      '{}'.format(len(classes), n_classes, recommendation),
-                      RuntimeWarning)
-        if method == 'decision_function':
-            if (predictions.ndim == 2 and
-                    predictions.shape[1] != len(classes)):
-                # This handles the case when the shape of predictions
-                # does not match the number of classes used to train
-                # it with. This case is found when sklearn.svm.SVC is
-                # set to `decision_function_shape='ovo'`.
-                raise ValueError('Output shape {} of {} does not match '
-                                 'number of classes ({}) in fold. '
-                                 'Irregular decision_function outputs '
-                                 'are not currently supported by '
-                                 'cross_val_predict'.format(
-                                     predictions.shape, method, len(classes)))
-            if len(classes) <= 2:
-                # In this special case, `predictions` contains a 1D array.
-                raise ValueError('Only {} class/es in training fold, but {} '
-                                 'in overall dataset. This '
-                                 'is not supported for decision_function '
-                                 'with imbalanced folds. {}'.format(
-                                     len(classes), n_classes, recommendation))
-
-        float_min = np.finfo(predictions.dtype).min
-        default_values = {'decision_function': float_min,
-                          'predict_log_proba': float_min,
-                          'predict_proba': 0}
-        predictions_for_all_classes = np.full((_num_samples(predictions),
-                                               n_classes),
-                                              default_values[method],
-                                              dtype=predictions.dtype)
-        predictions_for_all_classes[:, classes] = predictions
-        predictions = predictions_for_all_classes
-    return predictions
-
-
 def _check_is_permutation(indices, n_samples):
     """Check whether indices is a reordering of the array np.arange(n_samples)
 
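
Finally, again not part of the commit, a hedged sketch of the fold-imbalance handling that the reverted-to code above implements: with a plain KFold split on sorted labels, every training fold misses one class, so cross_val_predict emits the RuntimeWarning shown in the diff and pads the missing class's probability column with the method's default value (0.0 for predict_proba). The toy data and the LogisticRegression choice are illustrative assumptions.

import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import KFold, cross_val_predict

# 12 samples, 3 classes, deliberately ordered so an unshuffled KFold(3) leaves
# one class out of every training fold.
X = np.random.RandomState(0).randn(12, 3)
y = np.array([0] * 4 + [1] * 4 + [2] * 4)

proba = cross_val_predict(LogisticRegression(solver='lbfgs'),
                          X, y, cv=KFold(n_splits=3),
                          method='predict_proba')
print(proba.shape)  # (12, 3): one column per class in the full dataset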

Comments (0)