8000 ENH cross_val_predict and _fit_and_predict handle multi-label · scikit-learn/scikit-learn@ca2ee0b · GitHub
[go: up one dir, main page]

Skip to content

Commit ca2ee0b

Browse files
author
Stephen Hoover
committed
ENH cross_val_predict and _fit_and_predict handle multi-label
Modify the `cross_val_predict` and `_fit_and_predict` functions so that they handle multi-label (and multi-class multi-label) classification problems with the `predict_proba`, `predict_log_proba`, and `decision_function` methods. There are two different kinds of multi-label outputs from scikit-learn estimators. The `OneVsRestClassifier` handles multi-label tasks with binary indicator target arrays (no multi-class targets); it outputs 2D arrays from `predict_proba` and the related methods. The `RandomForestClassifier` handles multi-class multi-label problems; it outputs a list of 2D arrays (one per label) from `predict_proba` and the related methods. Recognize the RandomForest-like outputs by type-checking for a list. Lists of 2D arrays require slightly different code for keeping track of indices.
1 parent 195de6a commit ca2ee0b

File tree

3 files changed

+126
-28
lines changed

3 files changed

+126
-28
lines changed

doc/whats_new.rst

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,10 @@ Enhancements
171171
removed by setting it to `None`.
172172
:issue:`7674` by :user:`Yichuan Liu <yl565>`.
173173

174+
- Added ability for :func:`model_selection.cross_val_predict` to handle multi-label
175+
(and multi-class multi-label) targets with `predict_proba`-type methods.
176+
:issue:`8773` by :user:`Stephen Hoover <stephen-hoover>`.
177+
174178
Bug fixes
175179
.........
176180
- Fixed a bug where :class:`sklearn.ensemble.IsolationForest` uses an
@@ -5066,4 +5070,4 @@ David Huard, Dave Morrill, Ed Schofield, Travis Oliphant, Pearu Peterson.
50665070
.. _Anish Shah: https://github.com/AnishShah
50675071

50685072
.. _Neeraj Gangwar: http://neerajgangwar.in
5069-
.. _Arthur Mensch: https://amensch.fr
5073+
.. _Arthur Mensch: https://amensch.fr

sklearn/model_selection/_validation.py

Lines changed: 56 additions & 10 deletions
447
Original file line numberDiff line numberDiff line change
@@ -393,9 +393,18 @@ def cross_val_predict(estimator, X, y=None, groups=None, cv=None, n_jobs=1,
393393
raise AttributeError('{} not implemented in estimator'
394394
.format(method))
395395

396-
if method in ['decision_function', 'predict_proba', 'predict_log_proba']:
397-
le = LabelEncoder()
398-
y = le.fit_transform(y)
396+
do_manual_encoding = method in ['decision_function', 'predict_proba',
397+
'predict_log_proba']
398+
if do_manual_encoding:
399+
y = np.asarray(y)
400+
if y.ndim == 1:
401+
le = LabelEncoder()
402+
y = le.fit_transform(y)
403+
elif y.ndim == 2:
404+
y_enc = np.zeros_like(y, dtype=np.int)
405+
for i_label in range(y.shape[1]):
406+
y_enc[:, i_label] = LabelEncoder().fit_transform(y[:, i_label])
407+
y = y_enc
399408

400409
# We clone the estimator to make sure that all the folds are
401410
# independent, and that it is pickle-able.
@@ -419,9 +428,20 @@ def cross_val_predict(estimator, X, y=None, groups=None, cv=None, n_jobs=1,
419428
# Check for sparse predictions
420429
if sp.issparse(predictions[0]):
421430
predictions = sp.vstack(predictions, format=predictions[0].format)
431+
elif do_manual_encoding and isinstance(predictions[0], list):
432+
n_labels = y.shape[1]
433+
concat_pred = []
434+
for i_label in range(n_labels):
435+
label_preds = np.concatenate([p[i_label] for p in predictions])
436+
concat_pred.append(label_preds)
437+
predictions = concat_pred
422438
else:
423439
predictions = np.concatenate(predictions)
424-
return predictions[inv_test_indices]
440+
441+
if do_manual_encoding and isinstance(predictions, list):
442+
return [p[inv_test_indices] for p in predictions]
443+
else:
444+
return predictions[inv_test_indices]
425445

426446

427
def _fit_and_predict(estimator, X, y, train, test, verbose, fit_params,
@@ -480,16 +500,42 @@ def _fit_and_predict(estimator, X, y, train, test, verbose, fit_params,
480500
func = getattr(estimator, method)
481501
predictions = func(X_test)
482502
if method in ['decision_function', 'predict_proba', 'predict_log_proba']:
483-
n_classes = len(set(y))
484-
predictions_ = np.zeros((X_test.shape[0], n_classes))
485-
if method == 'decision_function' and len(estimator.classes_) == 2:
486-
predictions_[:, estimator.classes_[-1]] = predictions
503+
is_dec_func = (method == 'decision_function')
504+
if isinstance(predictions, list):
505+
predictions = [_enforce_prediction_order(
506+
estimator.classes_[i_label], predictions[i_label],
507+
n_classes=len(set(y[:, i_label])),
508+
one_col_if_binary=is_dec_func)
509+
for i_label in range(len(predictions))]
487510
else:
488-
predictions_[:, estimator.classes_] = predictions
489-
predictions = predictions_
511+
# A 2D y array should be a binary label indicator matrix
512+
n_classes = len(set(y)) if y.ndim == 1 else y.shape[1]
513+
predictions = _enforce_prediction_order(
514+
estimator.classes_, predictions, n_classes, is_dec_func)
515+
490516
return predictions, test
491517

492518

519+
def _enforce_prediction_order(classes, predictions, n_classes,
520+
one_col_if_binary=False):
521+
"""Ensure that prediction arrays have correct column order
522+
523+
When doing cross-validation, if one or more classes are
524+
not present in the subset of data used for training,
525+
then the output prediction array might not have the same
526+
columns as other folds. Use the list of class names
527+
(assumed to be integers) to enforce the correct column order.
528+
"""
529+
predictions_ = np.zeros((predictions.shape[0], n_classes),
530+
dtype=predictions.dtype)
531+
if one_col_if_binary and len(classes) == 2:
532+
predictions_[:, classes[-1]] = predictions
533+
else:
534+
predictions_[:, classes] = predictions
535+
predictions = predictions_
536+
return predictions
537+
538+
493539
def _check_is_permutation(indices, n_samples):
494540
"""Check whether indices is a reordering of the array np.arange(n_samples)
495541

sklearn/model_selection/tests/test_validation.py

Lines changed: 65 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646

4747
from sklearn.linear_model import Ridge, LogisticRegression
4848
from sklearn.linear_model import PassiveAggressiveClassifier
49+
from sklearn.ensemble import RandomForestClassifier
4950
from sklearn.neighbors import KNeighborsClassifier
5051
from sklearn.svm import SVC
5152
from sklearn.cluster import KMeans
@@ -915,54 +916,101 @@ def test_cross_val_predict_sparse_prediction():
915916
assert_array_almost_equal(preds_sparse, preds)
916917

917918

918-
def check_cross_val_predict_with_method(est):
919-
iris = load_iris()
920-
X, y = iris.data, iris.target
921-
X, y = shuffle(X, y, random_state=0)
922-
classes = len(set(y))
919+
def check_cross_val_predict_with_method(est, X, y, methods):
920+
kfold = KFold(X.shape[0])
923921

924-
kfold = KFold(len(iris.target))
925-
926-
methods = ['decision_function', 'predict_proba', 'predict_log_proba']
927922
for method in methods:
928923
predictions = cross_val_predict(est, X, y, method=method)
929-
assert_equal(len(predictions), len(y))
930924

931-
expected_predictions = np.zeros([len(y), classes])
925+
if isinstance(predictions, list):
926+
assert_equal(len(predictions), y.shape[1])
927+
for i in range(y.shape[1]):
928+
assert_equal(len(predictions[i]), len(y))
929+
expected_predictions = [np.zeros([len(y), len(set(y[:, i]))])
930+
for i in range(y.shape[1])]
931+
else:
932+
assert_equal(len(predictions), len(y))
933+
expected_predictions = np.zeros_like(predictions)
932934
func = getattr(est, method)
933935

934936
# Naive loop (should be same as cross_val_predict):
935937
for train, test in kfold.split(X, y):
936938
est.fit(X[train], y[train])
937-
expected_predictions[test] = func(X[test])
939+
preds = func(X[test])
940+
if isinstance(predictions, list):
941+
for i_label in range(y.shape[1]):
942+
expected_predictions[i_label][test] = preds[i_label]
943+
else:
944+
expected_predictions[test] = func(X[test])
938945

939946
predictions = cross_val_predict(est, X, y, method=method,
940947
cv=kfold)
941-
assert_array_almost_equal(expected_predictions, predictions)
948+
assert_array_equal_maybe_list(expected_predictions, predictions)
942949

943950
# Test alternative representations of y
944951
predictions_y1 = cross_val_predict(est, X, y + 1, method=method,
945952
cv=kfold)
946-
assert_array_equal(predictions, predictions_y1)
953+
assert_array_equal_maybe_list(predictions, predictions_y1)
947954

948955
predictions_y2 = cross_val_predict(est, X, y - 2, method=method,
949956
cv=kfold)
950-
assert_array_equal(predictions, predictions_y2)
957+
assert_array_equal_maybe_list(predictions, predictions_y2)
951958

952959
predictions_ystr = cross_val_predict(est, X, y.astype('str'),
953960
method=method, cv=kfold)
954-
assert_array_equal(predictions, predictions_ystr)
961+
assert_array_equal_maybe_list(predictions, predictions_ystr)
962+
963+
964+
def assert_array_equal_maybe_list(x, y):
    """Assert equality of two arrays, or of two lists of arrays.

    If ``x`` is a list, ``x`` and ``y`` are compared element by element
    with ``assert_array_equal``; otherwise they are compared directly.

    Raises
    ------
    AssertionError
        If the arrays differ, or if ``x`` is a list and ``y`` does not
        have the same number of elements.
    """
    if isinstance(x, list):
        # Without this check, extra arrays in ``y`` would go unnoticed and
        # a shorter ``y`` would surface as an unrelated IndexError.
        if len(x) != len(y):
            raise AssertionError(
                "Lists differ in length: {} != {}".format(len(x), len(y)))
        for x_item, y_item in zip(x, y):
            assert_array_equal(x_item, y_item)
    else:
        assert_array_equal(x, y)
955971

956972

957973
def test_cross_val_predict_with_method():
958-
check_cross_val_predict_with_method(LogisticRegression())
974+
iris = load_iris()
975+
X, y = iris.data, iris.target
976+
X, y = shuffle(X, y, random_state=0)
977+
methods = ['decision_function', 'predict_proba', 'predict_log_proba']
978+
check_cross_val_predict_with_method(LogisticRegression(), X, y, methods)
959979

960980

961981
def test_gridsearchcv_cross_val_predict_with_method():
982+
iris = load_iris()
983+
X, y = iris.data, iris.target
984+
X, y = shuffle(X, y, random_state=0)
962985
est = GridSearchCV(LogisticRegression(random_state=42),
963986
{'C': [0.1, 1]},
964987
cv=2)
965-
check_cross_val_predict_with_method(est)
988+
methods = ['decision_function', 'predict_proba', 'predict_log_proba']
989+
check_cross_val_predict_with_method(est, X, y, methods)
990+
991+
992+
def test_cross_val_predict_with_method_multilabel_ovr():
    """Check cross_val_predict on multi-label targets with a one-vs-rest model."""
    # OVR does multilabel predictions, but only arrays of
    # binary indicator columns. The output of predict_proba
    # is a 2D array with shape (n_samples, n_labels).
    X, y = make_multilabel_classification(n_samples=100, n_labels=3,
                                          n_classes=4, n_features=5,
                                          random_state=42)
    est = OneVsRestClassifier(LogisticRegression(random_state=0))
    check_cross_val_predict_with_method(
        est, X, y, methods=['predict_proba', 'decision_function'])
1002+
1003+
1004+
def test_cross_val_predict_with_method_multilabel_rf():
    """Check cross_val_predict on multi-class multi-label targets with a forest."""
    # The RandomForest allows anything for the contents of the labels.
    # Output of predict_proba is a list of outputs of predict_proba
    # for each individual label.
    X, y = make_multilabel_classification(n_samples=100, n_labels=3,
                                          n_classes=4, n_features=5,
                                          random_state=42)
    y[:, 0] += y[:, 1]  # Put three classes in the first column
    est = RandomForestClassifier(n_estimators=5, random_state=0)
    check_cross_val_predict_with_method(est, X, y, methods=['predict_proba'])
9661014

9671015

9681016
def get_expected_predictions(X, y, cv, classes, est, method):

0 commit comments

Comments
 (0)
0