OneVsOneClassifier.partial_fit checks classes are valid subset (#9156) · raghavrv/scikit-learn@0d5d842 · GitHub
[go: up one dir, main page]

Skip to content

Commit 0d5d842

Browse files
authored
OneVsOneClassifier.partial_fit checks classes are valid subset (scikit-learn#9156)
1 parent b4b5de8 commit 0d5d842

File tree

3 files changed

+51
-16
lines changed

3 files changed

+51
-16
lines changed

doc/whats_new.rst

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -404,10 +404,15 @@ Bug fixes
404404
reused the same estimator for each parameter value.
405405
:issue:`7365` by :user:`Aleksandr Sandrovskii <Sundrique>`.
406406

407+
- :class:`multiclass.OneVsOneClassifier`'s ``partial_fit`` now ensures all
408+
classes are provided up-front. :issue:`6250` by
409+
:user:`Asish Panda <kaichogami>`.
410+
407411
- Fixed an integer overflow bug in :func:`metrics.confusion_matrix` and
408412
hence :func:`metrics.cohen_kappa_score`. :issue:`8354`, :issue:`7929`
409413
by `Joel Nothman`_ and :user:`Jon Crall <Erotemic>`.
410414

415+
411416
API changes summary
412417
-------------------
413418

sklearn/multiclass.py

Lines changed: 11 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -257,7 +257,7 @@ def partial_fit(self, X, y, classes=None):
257257
self.label_binarizer_ = LabelBinarizer(sparse_output=True)
258258
self.label_binarizer_.fit(self.classes_)
259259

260-
if np.setdiff1d(y, self.classes_):
260+
if len(np.setdiff1d(y, self.classes_)):
261261
raise ValueError(("Mini-batch contains {0} while classes " +
262262
"must be subset of {1}").format(np.unique(y),
263263
self.classes_))
@@ -429,9 +429,11 @@ def _partial_fit_ovo_binary(estimator, X, y, i, j):
429429

430430
cond = np.logical_or(y == i, y == j)
431431
y = y[cond]
432-
y_binary = np.zeros_like(y)
433-
y_binary[y == j] = 1
434-
return _partial_fit_binary(estimator, X[cond], y_binary)
432+
if len(y) != 0:
433+
y_binary = np.zeros_like(y)
434+
y_binary[y == j] = 1
435+
return _partial_fit_binary(estimator, X[cond], y_binary)
436+
return estimator
435437

436438

437439
class OneVsOneClassifier(BaseEstimator, ClassifierMixin, MetaEstimatorMixin):
@@ -544,6 +546,11 @@ def partial_fit(self, X, y, classes=None):
544546
range(self.n_classes_ *
545547
(self.n_classes_ - 1) // 2)]
546548

549+
if len(np.setdiff1d(y, self.classes_)):
550+
raise ValueError("Mini-batch contains {0} while it "
551+
"must be subset of {1}".format(np.unique(y),
552+
self.classes_))
553+
547554
X, y = check_X_y(X, y, accept_sparse=['csr', 'csc'])
548555
check_classification_targets(y)
549556
combinations = itertools.combinations(range(self.n_classes_), 2)

sklearn/tests/test_multiclass.py

Lines changed: 35 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,9 @@
11
import numpy as np
22
import scipy.sparse as sp
33

4-
from sklearn.utils.testing import assert_array_equal, assert_raises_regex
4+
from re import escape
5+
6+
from sklearn.utils.testing import assert_array_equal
57
from sklearn.utils.testing import assert_equal
68
from sklearn.utils.testing import assert_almost_equal
79
from sklearn.utils.testing import assert_true
@@ -10,6 +12,7 @@
1012
from sklearn.utils.testing import assert_warns
1113
from sklearn.utils.testing import assert_greater
1214
from sklearn.utils.testing import assert_raise_message
15+
from sklearn.utils.testing import assert_raises_regexp
1316
from sklearn.multiclass import OneVsRestClassifier
1417
from sklearn.multiclass import OneVsOneClassifier
1518
from sklearn.multiclass import OutputCodeClassifier
@@ -118,9 +121,9 @@ def test_ovr_partial_fit_exceptions():
118121
# A new class value which was not in the first call of partial_fit
119122
# It should raise ValueError
120123
y1 = [5] + y[7:-1]
121-
assert_raises_regex(ValueError, "Mini-batch contains \[.+\] while classes"
122-
" must be subset of \[.+\]",
123-
ovr.partial_fit, X=X[7:], y=y1)
124+
assert_raises_regexp(ValueError, "Mini-batch contains \[.+\] while classes"
125+
" must be subset of \[.+\]",
126+
ovr.partial_fit, X=X[7:], y=y1)
124127

125128

126129
def test_ovr_ovo_regressor():
@@ -493,7 +496,8 @@ def test_ovo_fit_predict():
493496

494497

495498
def test_ovo_partial_fit_predict():
496-
X, y = shuffle(iris.data, iris.target)
499+
temp = datasets.load_iris()
500+
X, y = temp.data, temp.target
497501
ovo1 = OneVsOneClassifier(MultinomialNB())
498502
ovo1.partial_fit(X[:100], y[:100], np.unique(y))
499503
ovo1.partial_fit(X[100:], y[100:])
@@ -506,17 +510,36 @@ def test_ovo_partial_fit_predict():
506510
assert_greater(np.mean(y == pred1), 0.65)
507511
assert_almost_equal(pred1, pred2)
508512

509-
# Test when mini-batches don't have all target classes
513+
# Test when mini-batches have binary target classes
510514
ovo1 = OneVsOneClassifier(MultinomialNB())
511-
ovo1.partial_fit(iris.data[:60], iris.target[:60], np.unique(iris.target))
512-
ovo1.partial_fit(iris.data[60:], iris.target[60:])
513-
pred1 = ovo1.predict(iris.data)
515+
ovo1.partial_fit(X[:60], y[:60], np.unique(y))
516+
ovo1.partial_fit(X[60:], y[60:])
517+
pred1 = ovo1.predict(X)
514518
ovo2 = OneVsOneClassifier(MultinomialNB())
515-
pred2 = ovo2.fit(iris.data, iris.target).predict(iris.data)
519+
pred2 = ovo2.fit(X, y).predict(X)
516520

517521
assert_almost_equal(pred1, pred2)
518-
assert_equal(len(ovo1.estimators_), len(np.unique(iris.target)))
519-
assert_greater(np.mean(iris.target == pred1), 0.65)
522+
assert_equal(len(ovo1.estimators_), len(np.unique(y)))
523+
assert_greater(np.mean(y == pred1), 0.65)
524+
525+
ovo = OneVsOneClassifier(MultinomialNB())
526+
X = np.random.rand(14, 2)
527+
y = [1, 1, 2, 3, 3, 0, 0, 4, 4, 4, 4, 4, 2, 2]
528+
ovo.partial_fit(X[:7], y[:7], [0, 1, 2, 3, 4])
529+
ovo.partial_fit(X[7:], y[7:])
530+
pred = ovo.predict(X)
531+
ovo2 = OneVsOneClassifier(MultinomialNB())
532+
pred2 = ovo2.fit(X, y).predict(X)
533+
assert_almost_equal(pred, pred2)
534+
535+
# raises error when mini-batch does not have classes from all_classes
536+
ovo = OneVsOneClassifier(MultinomialNB())
537+
error_y = [0, 1, 2, 3, 4, 5, 2]
538+
message_re = escape("Mini-batch contains {0} while "
539+
"it must be subset of {1}".format(np.unique(error_y),
540+
np.unique(y)))
541+
assert_raises_regexp(ValueError, message_re, ovo.partial_fit, X[:7],
542+
error_y, np.unique(y))
520543

521544
# test partial_fit only exists if estimator has it:
522545
ovr = OneVsOneClassifier(SVC())

0 commit comments

Comments
 (0)
0