@@ -757,20 +757,9 @@ def cross_val_predict(estimator, X, y=None, groups=None, cv='warn',
757
757
758
758
cv = check_cv (cv , y , classifier = is_classifier (estimator ))
759
759
760
- # If classification methods produce multiple columns of output,
761
- # we need to manually encode classes to ensure consistent column ordering.
762
- encode = method in ['decision_function' , 'predict_proba' ,
763
- 'predict_log_proba' ]
764
- if encode :
765
- y = np .asarray (y )
766
- if y .ndim == 1 :
767
- le = LabelEncoder ()
768
- y = le .fit_transform (y )
769
- elif y .ndim == 2 :
770
- y_enc = np .zeros_like (y , dtype = np .int )
771
- for i_label in range (y .shape [1 ]):
772
- y_enc [:, i_label ] = LabelEncoder ().fit_transform (y [:, i_label ])
773
- y = y_enc
760
+ if method in ['decision_function' , 'predict_proba' , 'predict_log_proba' ]:
761
+ le = LabelEncoder ()
762
+ y = le .fit_transform (y )
774
763
775
764
# We clone the estimator to make sure that all the folds are
776
765
# independent, and that it is pickle-able.
@@ -791,26 +780,12 @@ def cross_val_predict(estimator, X, y=None, groups=None, cv='warn',
791
780
inv_test_indices = np .empty (len (test_indices ), dtype = int )
792
781
inv_test_indices [test_indices ] = np .arange (len (test_indices ))
793
782
783
+ # Check for sparse predictions
794
784
if sp .issparse (predictions [0 ]):
795
785
predictions = sp .vstack (predictions , format = predictions [0 ].format )
796
- elif encode and isinstance (predictions [0 ], list ):
797
- # `predictions` is a list of method outputs from each fold.
798
- # If each of those is also a list, then treat this as a
799
- # multioutput-multiclass task. We need to separately concatenate
800
- # the method outputs for each label into an `n_labels` long list.
801
- n_labels = y .shape [1 ]
802
- concat_pred = []
803
- for i_label in range (n_labels ):
804
- label_preds = np .concatenate ([p [i_label ] for p in predictions ])
805
- concat_pred .append (label_preds )
806
- predictions = concat_pred
807
786
else :
808
787
predictions = np .concatenate (predictions )
809
-
810
- if isinstance (predictions , list ):
811
- return [p [inv_test_indices ] for p in predictions ]
812
- else :
813
- return predictions [inv_test_indices ]
788
+ return predictions [inv_test_indices ]
814
789
815
790
816
791
def _fit_and_predict (estimator , X , y , train , test , verbose , fit_params ,
@@ -869,76 +844,54 @@ def _fit_and_predict(estimator, X, y, train, test, verbose, fit_params,
869
844
func = getattr (estimator , method )
870
845
predictions = func (X_test )
871
846
if method in ['decision_function' , 'predict_proba' , 'predict_log_proba' ]:
872
- if isinstance (predictions , list ):
873
- predictions = [_enforce_prediction_order (
874
- estimator .classes_ [i_label ], predictions [i_label ],
875
- n_classes = len (set (y [:, i_label ])), method = method )
876
- for i_label in range (len (predictions ))]
877
- else :
878
- # A 2D y array should be a binary label indicator matrix
879
- n_classes = len (set (y )) if y .ndim == 1 else y .shape [1 ]
880
- predictions = _enforce_prediction_order (
881
- estimator .classes_ , predictions , n_classes , method )
847
+ n_classes = len (set (y ))
848
+ if n_classes != len (estimator .classes_ ):
849
+ recommendation = (
850
+ 'To fix this, use a cross-validation '
851
+ 'technique resulting in properly '
852
+ 'stratified folds' )
853
+ warnings .warn ('Number of classes in training fold ({}) does '
854
+ 'not match total number of classes ({}). '
855
+ 'Results may not be appropriate for your use case. '
856
+ '{}' .format (len (estimator .classes_ ),
857
+ n_classes , recommendation ),
858
+ RuntimeWarning )
859
+ if method == 'decision_function' :
860
+ if (predictions .ndim == 2 and
861
+ predictions .shape [1 ] != len (estimator .classes_ )):
862
+ # This handles the case when the shape of predictions
863
+ # does not match the number of classes used to train
864
+ # it with. This case is found when sklearn.svm.SVC is
865
+ # set to `decision_function_shape='ovo'`.
866
+ raise ValueError ('Output shape {} of {} does not match '
867
+ 'number of classes ({}) in fold. '
868
+ 'Irregular decision_function outputs '
869
+ 'are not currently supported by '
870
+ 'cross_val_predict' .format (
871
+ predictions .shape , method ,
872
+ len (estimator .classes_ ),
873
+ recommendation ))
874
+ if len (estimator .classes_ ) <= 2 :
875
+ # In this special case, `predictions` contains a 1D array.
876
+ raise ValueError ('Only {} class/es in training fold, this '
877
+ 'is not supported for decision_function '
878
+ 'with imbalanced folds. {}' .format (
879
+ len (estimator .classes_ ),
880
+ recommendation ))
881
+
882
+ float_min = np .finfo (predictions .dtype ).min
883
+ default_values = {'decision_function' : float_min ,
884
+ 'predict_log_proba' : float_min ,
885
+ 'predict_proba' : 0.0 }
886
+ predictions_for_all_classes = np .full ((_num_samples (predictions ),
887
+ n_classes ),
888
+ default_values [method ],
889
+ predictions .dtype )
890
+ predictions_for_all_classes [:, estimator .classes_ ] = predictions
891
+ predictions = predictions_for_all_classes
882
892
return predictions , test
883
893
884
894
885
- def _enforce_prediction_order (classes , predictions , n_classes , method ):
886
- """Ensure that prediction arrays have correct column order
887
-
888
- When doing cross-validation, if one or more classes are
889
- not present in the subset of data used for training,
890
- then the output prediction array might not have the same
891
- columns as other folds. Use the list of class names
892
- (assumed to be integers) to enforce the correct column order.
893
-
894
- Note that `classes` is the list of classes in this fold
895
- (a subset of the classes in the full training set)
896
- and `n_classes` is the number of classes in the full training set.
897
- """
898
- if n_classes != len (classes ):
899
- recommendation = (
900
- 'To fix this, use a cross-validation '
901
- 'technique resulting in properly '
902
- 'stratified folds' )
903
- warnings .warn ('Number of classes in training fold ({}) does '
904
- 'not match total number of classes ({}). '
905
- 'Results may not be appropriate for your use case. '
906
- '{}' .format (len (classes ), n_classes , recommendation ),
907
- RuntimeWarning )
908
- if method == 'decision_function' :
909
- if (predictions .ndim == 2 and
910
- predictions .shape [1 ] != len (classes )):
911
- # This handles the case when the shape of predictions
912
- # does not match the number of classes used to train
913
- # it with. This case is found when sklearn.svm.SVC is
914
- # set to `decision_function_shape='ovo'`.
915
- raise ValueError ('Output shape {} of {} does not match '
916
- 'number of classes ({}) in fold. '
917
- 'Irregular decision_function outputs '
918
- 'are not currently supported by '
919
- 'cross_val_predict' .format (
920
- predictions .shape , method , len (classes )))
921
- if len (classes ) <= 2 :
922
- # In this special case, `predictions` contains a 1D array.
923
- raise ValueError ('Only {} class/es in training fold, but {} '
924
- 'in overall dataset. This '
925
- 'is not supported for decision_function '
926
- 'with imbalanced folds. {}' .format (
927
- len (classes ), n_classes , recommendation ))
928
-
929
- float_min = np .finfo (predictions .dtype ).min
930
- default_values = {'decision_function' : float_min ,
931
- 'predict_log_proba' : float_min ,
932
- 'predict_proba' : 0 }
933
- predictions_for_all_classes = np .full ((_num_samples (predictions ),
934
- n_classes ),
935
- default_values [method ],
936
- dtype = predictions .dtype )
937
- predictions_for_all_classes [:, classes ] = predictions
938
- predictions = predictions_for_all_classes
939
- return predictions
940
-
941
-
942
895
def _check_is_permutation (indices , n_samples ):
943
896
"""Check whether indices is a reordering of the array np.arange(n_samples)
944
897
0 commit comments