ENH break ties in OvO using scores · deepatdotnet/scikit-learn@7379aae · GitHub

Commit 7379aae

amueller authored and larsmans committed
ENH break ties in OvO using scores
1 parent 437644e commit 7379aae

File tree

sklearn/multiclass.py
sklearn/tests/test_multiclass.py

2 files changed, +45 -2 lines changed

sklearn/multiclass.py

Lines changed: 13 additions & 1 deletion

@@ -304,16 +304,28 @@ def predict_ovo(estimators, classes, X):
     n_samples = X.shape[0]
     n_classes = classes.shape[0]
     votes = np.zeros((n_samples, n_classes))
+    scores = np.zeros((n_samples, n_classes))

     k = 0
     for i in range(n_classes):
         for j in range(i + 1, n_classes):
             pred = estimators[k].predict(X)
+            score = _predict_binary(estimators[k], X)
+            scores[:, i] += score
+            scores[:, j] -= score
             votes[pred == 0, i] += 1
             votes[pred == 1, j] += 1
             k += 1
+    # find all places with maximum votes per sample
+    maxima = votes == np.max(votes, axis=1)[:, np.newaxis]

-    return classes[votes.argmax(axis=1)]
+    # if there are ties, use scores to break them
+    if np.any(maxima.sum(axis=1) > 1):
+        scores[~maxima] = -np.inf
+        prediction = scores.argmax(axis=1)
+    else:
+        prediction = votes.argmax(axis=1)
+    return classes[prediction]


 class OneVsOneClassifier(BaseEstimator, ClassifierMixin, MetaEstimatorMixin):
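
Taken together, the hunk keeps a signed score total per class next to the vote matrix and, whenever two or more classes tie on votes, restricts the argmax to the tied classes by masking the others with -inf. Below is a self-contained NumPy sketch of that masking step; it is not part of the commit, and the votes and scores values are invented purely for illustration.

# Standalone sketch of the tie-breaking logic above (NumPy only).
# The votes/scores values here are made up for illustration.
import numpy as np

votes = np.array([[1., 1., 1.],    # sample 0: three-way tie
                  [2., 1., 0.]])   # sample 1: clear winner, class 0
scores = np.array([[-0.3, 0.8, -0.5],
                   [1.2, 0.1, -1.3]])

# find all places with maximum votes per sample
maxima = votes == np.max(votes, axis=1)[:, np.newaxis]

if np.any(maxima.sum(axis=1) > 1):
    # mask out classes that did not reach the vote maximum, then let the
    # aggregated scores decide among the tied ones (copy keeps scores intact)
    masked = scores.copy()
    masked[~maxima] = -np.inf
    prediction = masked.argmax(axis=1)
else:
    prediction = votes.argmax(axis=1)

print(prediction)  # [1 0] -- the tie in sample 0 goes to its highest-scoring class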

sklearn/tests/test_multiclass.py

Lines changed: 32 additions & 1 deletion

@@ -14,7 +14,8 @@
 from sklearn.multiclass import OutputCodeClassifier
 from sklearn.svm import LinearSVC
 from sklearn.naive_bayes import MultinomialNB
-from sklearn.linear_model import LinearRegression, Lasso, ElasticNet, Ridge
+from sklearn.linear_model import (LinearRegression, Lasso, ElasticNet, Ridge,
+                                  Perceptron)
 from sklearn.tree import DecisionTreeClassifier
 from sklearn.grid_search import GridSearchCV
 from sklearn.pipeline import Pipeline
@@ -260,6 +261,36 @@ def test_ovo_gridsearch():
     assert_true(best_C in Cs)


+def test_ovo_ties():
+    # test that ties are broken using the decision function, not defaulting to
+    # the smallest label
+    X = np.array([[1, 2], [2, 1], [-2, 1], [-2, -1]])
+    y = np.array([2, 0, 1, 2])
+    multi_clf = OneVsOneClassifier(Perceptron())
+    ovo_prediction = multi_clf.fit(X, y).predict(X)
+
+    # recalculate votes to make sure we have a tie
+    predictions = np.vstack([clf.predict(X) for clf in multi_clf.estimators_])
+    scores = np.vstack([clf.decision_function(X)
+                        for clf in multi_clf.estimators_])
+    # classifiers are in order 0-1, 0-2, 1-2
+    # aggregate votes:
+    votes = np.zeros((4, 3))
+    votes[np.arange(4), predictions[0]] += 1
+    votes[np.arange(4), 2 * predictions[1]] += 1
+    votes[np.arange(4), 1 + predictions[2]] += 1
+    # for the first point, there is one vote per class
+    assert_array_equal(votes[0, :], 1)
+    # for the rest, there is no tie and the prediction is the argmax
+    assert_array_equal(np.argmax(votes[1:], axis=1), ovo_prediction[1:])
+    # for the tie, the prediction is the class with the highest score
+    assert_equal(ovo_prediction[0], 1)
+    # score for one is greater than score for zero
+    assert_greater(scores[2, 0] - scores[0, 0], scores[0, 0] + scores[1, 0])
+    # score for one is greater than score for two
+    assert_greater(scores[2, 0] - scores[0, 0], -scores[1, 0] - scores[2, 0])
+
+
 def test_ecoc_exceptions():
     ecoc = OutputCodeClassifier(LinearSVC(random_state=0))
     assert_raises(ValueError, ecoc.predict, [])
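
For completeness, here is a minimal usage sketch (not part of the commit) that mirrors the new test's setup: it fits OneVsOneClassifier on the same four-sample, three-class toy problem and inspects the per-pair decision values that the tie-breaking step aggregates. Variable names are illustrative, and the exact Perceptron outputs depend on the fitted weights.

import numpy as np
from sklearn.multiclass import OneVsOneClassifier
from sklearn.linear_model import Perceptron

# Same toy data as test_ovo_ties: sample 0 is constructed to receive
# one vote per class from the three pairwise classifiers.
X = np.array([[1, 2], [2, 1], [-2, 1], [-2, -1]])
y = np.array([2, 0, 1, 2])

ovo = OneVsOneClassifier(Perceptron()).fit(X, y)

print(len(ovo.estimators_))   # 3 pairwise classifiers, in class order 0-1, 0-2, 1-2
print(ovo.predict(X))         # the vote tie on sample 0 is now broken by summed scores

# Per-pair decision values that feed the tie-breaking aggregation
pairwise_scores = np.vstack([clf.decision_function(X)
                             for clf in ovo.estimators_])
print(pairwise_scores.shape)  # (3, 4): one row per pair, one column per sample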

0 commit comments