@@ -46,6 +46,7 @@

 from sklearn.linear_model import Ridge, LogisticRegression
 from sklearn.linear_model import PassiveAggressiveClassifier
+from sklearn.ensemble import RandomForestClassifier
 from sklearn.neighbors import KNeighborsClassifier
 from sklearn.svm import SVC
 from sklearn.cluster import KMeans
@@ -915,54 +916,101 @@ def test_cross_val_predict_sparse_prediction():
     assert_array_almost_equal(preds_sparse, preds)


-def check_cross_val_predict_with_method(est):
-    iris = load_iris()
-    X, y = iris.data, iris.target
-    X, y = shuffle(X, y, random_state=0)
-    classes = len(set(y))
+def check_cross_val_predict_with_method(est, X, y, methods):
+    kfold = KFold(X.shape[0])

-    kfold = KFold(len(iris.target))
-
-    methods = ['decision_function', 'predict_proba', 'predict_log_proba']
     for method in methods:
         predictions = cross_val_predict(est, X, y, method=method)
-        assert_equal(len(predictions), len(y))

-        expected_predictions = np.zeros([len(y), classes])
+        if isinstance(predictions, list):
+            assert_equal(len(predictions), y.shape[1])
+            for i in range(y.shape[1]):
+                assert_equal(len(predictions[i]), len(y))
+            expected_predictions = [np.zeros([len(y), len(set(y[:, i]))])
+                                    for i in range(y.shape[1])]
+        else:
+            assert_equal(len(predictions), len(y))
+            expected_predictions = np.zeros_like(predictions)
         func = getattr(est, method)

         # Naive loop (should be same as cross_val_predict):
         for train, test in kfold.split(X, y):
             est.fit(X[train], y[train])
-            expected_predictions[test] = func(X[test])
+            preds = func(X[test])
+            if isinstance(predictions, list):
+                for i_label in range(y.shape[1]):
+                    expected_predictions[i_label][test] = preds[i_label]
+            else:
+                expected_predictions[test] = func(X[test])

         predictions = cross_val_predict(est, X, y, method=method,
                                         cv=kfold)
-        assert_array_almost_equal(expected_predictions, predictions)
+        assert_array_equal_maybe_list(expected_predictions, predictions)

         # Test alternative representations of y
         predictions_y1 = cross_val_predict(est, X, y + 1, method=method,
                                            cv=kfold)
-        assert_array_equal(predictions, predictions_y1)
+        assert_array_equal_maybe_list(predictions, predictions_y1)

         predictions_y2 = cross_val_predict(est, X, y - 2, method=method,
                                            cv=kfold)
-        assert_array_equal(predictions, predictions_y2)
+        assert_array_equal_maybe_list(predictions, predictions_y2)

         predictions_ystr = cross_val_predict(est, X, y.astype('str'),
                                              method=method, cv=kfold)
-        assert_array_equal(predictions, predictions_ystr)
+        assert_array_equal_maybe_list(predictions, predictions_ystr)
+
+
+def assert_array_equal_maybe_list(x, y):
+    # If x and y are lists of arrays, compare arrays individually.
+    if isinstance(x, list):
+        for i in range(len(x)):
+            assert_array_equal(x[i], y[i])
+    else:
+        assert_array_equal(x, y)


 def test_cross_val_predict_with_method():
-    check_cross_val_predict_with_method(LogisticRegression())
+    iris = load_iris()
+    X, y = iris.data, iris.target
+    X, y = shuffle(X, y, random_state=0)
+    methods = ['decision_function', 'predict_proba', 'predict_log_proba']
+    check_cross_val_predict_with_method(LogisticRegression(), X, y, methods)


 def test_gridsearchcv_cross_val_predict_with_method():
+    iris = load_iris()
+    X, y = iris.data, iris.target
+    X, y = shuffle(X, y, random_state=0)
     est = GridSearchCV(LogisticRegression(random_state=42),
                        {'C': [0.1, 1]},
                        cv=2)
-    check_cross_val_predict_with_method(est)
+    methods = ['decision_function', 'predict_proba', 'predict_log_proba']
+    check_cross_val_predict_with_method(est, X, y, methods)
+
+
+def test_cross_val_predict_with_method_multilabel_ovr():
+    # OVR does multilabel predictions, but only arrays of
+    # binary indicator columns. The output of predict_proba
+    # is a 2D array with shape (n_samples, n_labels).
+    X, y = make_multilabel_classification(n_samples=100, n_labels=3,
                                          n_classes=4, n_features=5,
                                          random_state=42)
+    est = OneVsRestClassifier(LogisticRegression(random_state=0))
+    check_cross_val_predict_with_method(
+        est, X, y, methods=['predict_proba', 'decision_function'])
+
+
+def test_cross_val_predict_with_method_multilabel_rf():
+    # The RandomForest allows anything for the contents of the labels.
+    # Output of predict_proba is a list of outputs of predict_proba
+    # for each individual label.
+    X, y = make_multilabel_classification(n_samples=100, n_labels=3,
                                          n_classes=4, n_features=5,
                                          random_state=42)
+    y[:, 0] += y[:, 1]  # Put three classes in the first column
+    est = RandomForestClassifier(n_estimators=5, random_state=0)
+    check_cross_val_predict_with_method(est, X, y, methods=['predict_proba'])


 def get_expected_predictions(X, y, cv, classes, est, method):
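
The branching on isinstance(predictions, list) in the helper above reflects a real difference in predict_proba output formats, spelled out in the comments of the two new multilabel tests. As a rough standalone sketch of that difference (illustration only, not part of the patch): OneVsRestClassifier returns a single 2D probability array for an indicator matrix y, whereas a multi-output RandomForestClassifier returns a list with one probability array per label column.

    from sklearn.datasets import make_multilabel_classification
    from sklearn.ensemble import RandomForestClassifier
    from sklearn.linear_model import LogisticRegression
    from sklearn.multiclass import OneVsRestClassifier

    X, y = make_multilabel_classification(n_samples=100, n_labels=3,
                                          n_classes=4, n_features=5,
                                          random_state=42)

    # OneVsRestClassifier: one 2D array, one probability column per label.
    ovr = OneVsRestClassifier(LogisticRegression()).fit(X, y)
    proba_ovr = ovr.predict_proba(X)
    print(type(proba_ovr), proba_ovr.shape)  # ndarray of shape (100, 4)

    # Multi-output RandomForestClassifier: a list with one
    # (n_samples, n_classes_i) probability array per label column.
    rf = RandomForestClassifier(n_estimators=5, random_state=0).fit(X, y)
    proba_rf = rf.predict_proba(X)
    print(type(proba_rf), len(proba_rf), proba_rf[0].shape)

This is why the patch compares expected and actual predictions with assert_array_equal_maybe_list rather than assert_array_equal directly.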