[MRG+1] Return list instead of 3d array for MultiOutputClassifier.predict_proba · scikit-learn/scikit-learn@dd2e48c · GitHub

Commit dd2e48c

pjbull authored and raghavrv committed
[MRG+1] Return list instead of 3d array for MultiOutputClassifier.predict_proba (#8095)
* Return list instead of 3d array for MultiOutputClassifier.predict_proba
* Update flake8, docstring, variable name
  - Changed `rs` to `rng` to follow convention.
  - Made sure changes were flake8 approved.
  - Add `\` to continue docstring for `predict_proba` return value.
* Sub random.choice for np.random.choice
  `np.random.choice` isn't available in Numpy 1.6, so opt for the Python version instead.
* Make test labels deterministic
* Remove hanging chad...
* Add bug fix and API change to whats new
1 parent b685494 commit dd2e48c

File tree (3 files changed: +64 −8 lines)

  doc/whats_new.rst
  sklearn/multioutput.py
  sklearn/tests/test_multioutput.py

doc/whats_new.rst (+15)

@@ -152,6 +152,12 @@ Bug fixes
      wrong values when calling ``__call__``.
      :issue:`8087` by :user:`Alexis Mignon <AlexisMignon>`
 
+   - Fix :func:`sklearn.multioutput.MultiOutputClassifier.predict_proba` to
+     return a list of 2d arrays, rather than a 3d array. In the case where
+     different target columns had different numbers of classes, a `ValueError`
+     would be raised on trying to stack matrices with different dimensions.
+     :issue:`8093` by :user:`Peter Bull <pjbull>`.
+
 API changes summary
 -------------------
 
@@ -167,6 +173,15 @@ API changes summary
      needed for the perplexity calculation. :issue:`7954` by
      :user:`Gary Foreman <garyForeman>`.
 
+   - The :func:`sklearn.multioutput.MultiOutputClassifier.predict_proba`
+     function used to return a 3d array (``n_samples``, ``n_classes``,
+     ``n_outputs``). In the case where different target columns had different
+     numbers of classes, a `ValueError` would be raised on trying to stack
+     matrices with different dimensions. This function now returns a list of
+     arrays where the length of the list is ``n_outputs``, and each array is
+     (``n_samples``, ``n_classes``) for that particular output.
+     :issue:`8093` by :user:`Peter Bull <pjbull>`.
+
 .. _changes_0_18_1:
 
 Version 0.18.1
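
A minimal sketch of the behaviour described in the two changelog entries above; the data, label values, and estimator choice here are illustrative only and are not part of the commit. With two target columns that have different numbers of classes, predict_proba now returns one (n_samples, n_classes) array per output, whereas the old 3d-array behaviour raised a ValueError when the per-output matrices were stacked.

    import numpy as np
    from sklearn.linear_model import LogisticRegression
    from sklearn.multioutput import MultiOutputClassifier

    rng = np.random.RandomState(0)
    X = rng.normal(size=(20, 4))                       # toy features
    y1 = np.array(['a', 'b', 'a', 'b', 'a'] * 4)       # first output: 2 classes
    y2 = np.array(['d', 'e', 'f', 'e', 'd'] * 4)       # second output: 3 classes
    Y = np.column_stack([y1, y2])

    clf = MultiOutputClassifier(LogisticRegression()).fit(X, Y)
    proba = clf.predict_proba(X)

    # One (n_samples, n_classes) array per output; stacking these two arrays
    # into a single 3d array is impossible because their widths differ (2 vs 3).
    print(len(proba))        # 2
    print(proba[0].shape)    # (20, 2)
    print(proba[1].shape)    # (20, 3)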

sklearn/multioutput.py (+6 −4)

@@ -214,16 +214,18 @@ def predict_proba(self, X):
 
         Returns
         -------
-        T : (sparse) array-like, shape = (n_samples, n_classes, n_outputs)
-            The class probabilities of the samples for each of the outputs
+        p : array of shape = [n_samples, n_classes], or a list of n_outputs \
+            such arrays if n_outputs > 1.
+            The class probabilities of the input samples. The order of the
+            classes corresponds to that in the attribute `classes_`.
         """
         check_is_fitted(self, 'estimators_')
         if not hasattr(self.estimator, "predict_proba"):
             raise ValueError("The base estimator should implement"
                              "predict_proba method")
 
-        results = np.dstack([estimator.predict_proba(X) for estimator in
-                             self.estimators_])
+        results = [estimator.predict_proba(X) for estimator in
+                   self.estimators_]
         return results
 
     def score(self, X, y):
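
For code that indexed the old 3d return value, the migration is mechanical: proba[:, :, i] becomes proba[i]. A hedged sketch follows (toy data and variable names are illustrative, not from the commit); when every output happens to have the same number of classes, the old stacked layout can still be rebuilt explicitly with np.dstack, which is exactly what the updated test below does.

    import numpy as np
    from sklearn.datasets import make_classification
    from sklearn.ensemble import RandomForestClassifier
    from sklearn.multioutput import MultiOutputClassifier

    X, y = make_classification(n_samples=50, n_features=6, n_informative=4,
                               n_classes=3, random_state=0)
    Y = np.column_stack([y, (y + 1) % 3])     # two outputs, 3 classes each

    clf = MultiOutputClassifier(RandomForestClassifier(random_state=0)).fit(X, Y)
    proba = clf.predict_proba(X)

    # Old access pattern:  proba[:, :, i]  (slice of the stacked 3d array)
    # New access pattern:  proba[i]        (element of the returned list)
    assert proba[1].shape == (50, 3)

    # Only when all outputs share the same number of classes can the old
    # 3d layout be recovered by stacking the list along a third axis.
    stacked = np.dstack(proba)                # shape (50, 3, 2)
    assert np.array_equal(stacked[:, :, 1], proba[1])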

sklearn/tests/test_multioutput.py (+43 −4)

@@ -10,7 +10,7 @@
 from sklearn import datasets
 from sklearn.base import clone
 from sklearn.ensemble import GradientBoostingRegressor, RandomForestClassifier
-from sklearn.linear_model import Lasso
+from sklearn.linear_model import Lasso, LogisticRegression
 from sklearn.svm import LinearSVC
 from sklearn.multiclass import OneVsRestClassifier
 from sklearn.multioutput import MultiOutputRegressor, MultiOutputClassifier
@@ -118,17 +118,21 @@ def test_multi_output_classification():
     assert_equal((n_samples, n_outputs), predictions.shape)
 
     predict_proba = multi_target_forest.predict_proba(X)
-    assert_equal((n_samples, n_classes, n_outputs), predict_proba.shape)
 
-    assert_array_equal(np.argmax(predict_proba, axis=1), predictions)
+    assert len(predict_proba) == n_outputs
+    for class_probabilities in predict_proba:
+        assert_equal((n_samples, n_classes), class_probabilities.shape)
+
+    assert_array_equal(np.argmax(np.dstack(predict_proba), axis=1),
+                       predictions)
 
     # train the forest with each column and assert that predictions are equal
     for i in range(3):
         forest_ = clone(forest)  # create a clone with the same state
         forest_.fit(X, y[:, i])
         assert_equal(list(forest_.predict(X)), list(predictions[:, i]))
         assert_array_equal(list(forest_.predict_proba(X)),
-                           list(predict_proba[:, :, i]))
+                           list(predict_proba[i]))
 
 
 def test_multiclass_multioutput_estimator():
@@ -150,6 +154,41 @@ def test_multiclass_multioutput_estimator():
                      list(predictions[:, i]))
 
 
+def test_multiclass_multioutput_estimator_predict_proba():
+    seed = 542
+
+    # make test deterministic
+    rng = np.random.RandomState(seed)
+
+    # random features
+    X = rng.normal(size=(5, 5))
+
+    # random labels
+    y1 = np.array(['b', 'a', 'a', 'b', 'a']).reshape(5, 1)  # 2 classes
+    y2 = np.array(['d', 'e', 'f', 'e', 'd']).reshape(5, 1)  # 3 classes
+
+    Y = np.concatenate([y1, y2], axis=1)
+
+    clf = MultiOutputClassifier(LogisticRegression(random_state=seed))
+
+    clf.fit(X, Y)
+
+    y_result = clf.predict_proba(X)
+    y_actual = [np.array([[0.23481764, 0.76518236],
+                          [0.67196072, 0.32803928],
+                          [0.54681448, 0.45318552],
+                          [0.34883923, 0.65116077],
+                          [0.73687069, 0.26312931]]),
+                np.array([[0.5171785, 0.23878628, 0.24403522],
+                          [0.22141451, 0.64102704, 0.13755846],
+                          [0.16751315, 0.18256843, 0.64991843],
+                          [0.27357372, 0.55201592, 0.17441036],
+                          [0.65745193, 0.26062899, 0.08191907]])]
+
+    for i in range(len(y_actual)):
+        assert_almost_equal(y_result[i], y_actual[i])
+
+
 def test_multi_output_classification_sample_weights():
     # weighted classifier
     Xw = [[1, 2, 3], [4, 5, 6]]
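
One way to see where the hard-coded y_actual values in the new test come from: MultiOutputClassifier fits one cloned estimator per target column, so the expected arrays are simply the per-column predict_proba outputs of the base estimator. A rough sketch of regenerating them follows (the exact probabilities depend on the scikit-learn and NumPy versions in use):

    import numpy as np
    from sklearn.linear_model import LogisticRegression

    seed = 542
    rng = np.random.RandomState(seed)
    X = rng.normal(size=(5, 5))                  # same features as the test above
    y1 = np.array(['b', 'a', 'a', 'b', 'a'])     # 2 classes
    y2 = np.array(['d', 'e', 'f', 'e', 'd'])     # 3 classes

    # Fit the base estimator on each column independently, mirroring what
    # MultiOutputClassifier does internally, and collect the probabilities.
    per_column = [LogisticRegression(random_state=seed).fit(X, y).predict_proba(X)
                  for y in (y1, y2)]
    print(per_column[0].shape, per_column[1].shape)   # (5, 2) (5, 3)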
