8000 [MRG+1] Fix MultinomialNB and BernoulliNB alpha=0 bug (continuation) … · scikit-learn/scikit-learn@b4b5de8 · GitHub
[go: up one dir, main page]

Skip to content

Commit b4b5de8

Browse files
herilalainajmschrei
authored andcommitted
[MRG+1] Fix MultinomialNB and BernoulliNB alpha=0 bug (continuation) (#9131)
* Fix #5814 * Fix pep8 in naive_bayes.py:716 * Fix sparse matrix incompatibility * Fix python 2.7 problem in test_naive_bayes * Make sure the values are probabilities before log transform * Improve docstring of `_safe_logprob` * Clip alpha solution * Clip alpha solution * Clip alpha in fit and partial_fit * Add what's new entry * Add test * Remove .project * Replace assert method * Update what's new * Format float into %.1e * Update ValueError msg
1 parent d9b525a commit b4b5de8

File tree

3 files changed

+75
-8
lines changed

3 files changed

+75
-8
lines changed

doc/whats_new.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,10 @@ Bug fixes
259259
- Fixed a bug where :func:`linear_model.RANSACRegressor.fit` may run until
260260
``max_iter`` if finds a large inlier group early. :issue:`8251` by :user:`aivision2020`.
261261

262+
- Fixed a bug where :class:`sklearn.naive_bayes.MultinomialNB` and :class:`sklearn.naive_bayes.BernoulliNB`
263+
failed when `alpha=0`. :issue:`5814` by :user:`Yichuan Liu <yl565>` and
264+
:user:`Herilalaina Rakotoarison <herilalaina>`.
265+
262266
- Fixed a bug where :func:`datasets.make_moons` gives an
263267
incorrect result when ``n_samples`` is odd.
264268
:issue:`8198` by :user:`Josh Levy <levy5674>`.

sklearn/naive_bayes.py

Lines changed: 22 additions & 7 deletions
< 9E88 td data-grid-cell-id="diff-9a6f9b7534ccf863fd2609da993e500e5df5e804fc7775857dff8b5b0b6838de-796-811-1" data-selected="false" role="gridcell" style="background-color:var(--bgColor-default);text-align:center" tabindex="-1" valign="top" class="focusable-grid-cell diff-line-number position-relative diff-line-number-neutral left-side">811
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
# (parts based on earlier work by Mathieu Blondel)
1616
#
1717
# License: BSD 3 clause
18+
import warnings
1819

1920
from abc import ABCMeta, abstractmethod
2021

@@ -436,6 +437,8 @@ def _joint_log_likelihood(self, X):
436437
joint_log_likelihood = np.array(joint_log_likelihood).T
437438
return joint_log_likelihood
438439

440+
_ALPHA_MIN = 1e-10
441+
439442

440443
class BaseDiscreteNB(BaseNB):
441444
"""Abstract base class for naive Bayes on discrete/categorical data
@@ -460,6 +463,16 @@ def _update_class_log_prior(self, class_prior=None):
460463
else:
461464
self.class_log_prior_ = np.zeros(n_classes) - np.log(n_classes)
462465

466+
def _check_alpha(self):
467+
if self.alpha < 0:
468+
raise ValueError('Smoothing parameter alpha = %.1e. '
469+
'alpha should be > 0.' % self.alpha)
470+
if self.alpha < _ALPHA_MIN:
471+
warnings.warn('alpha too small will result in numeric errors, '
472+
'setting alpha = %.1e' % _ALPHA_MIN)
473+
return _ALPHA_MIN
474+
return self.alpha
475+
463476
def partial_fit(self, X, y, classes=None, sample_weight=None):
464477
"""Incremental fit on a batch of samples.
465478
@@ -538,7 +551,8 @@ def partial_fit(self, X, y, classes=None, sample_weight=None):
538551
# be called by the user explicitly just once after several consecutive
539552
# calls to partial_fit and prior any call to predict[_[log_]proba]
540553
# to avoid computing the smooth log probas at each call to partial fit
541-
self._update_feature_log_prob()
554+
alpha = self._check_alpha()
555+
self._update_feature_log_prob(alpha)
542556
self._update_class_log_prior(class_prior=class_prior)
543557
return self
544558

@@ -588,7 +602,8 @@ def fit(self, X, y, sample_weight=None):
588602
self.feature_count_ = np.zeros((n_effective_classes, n_features),
589603
dtype=np.float64)
590604
self._count(X, Y)
591-
self._update_feature_log_prob()
605+
alpha = self._check_alpha()
606+
self._update_feature_log_prob(alpha)
592607
self._update_class_log_prior(class_prior=class_prior)
593608
return self
594609

@@ -694,9 +709,9 @@ def _count(self, X, Y):
694709
self.feature_count_ += safe_sparse_dot(Y.T, X)
695710
self.class_count_ += Y.sum(axis=0)
696711

697-
def _update_feature_log_prob(self):
712+
def _update_feature_log_prob(self, alpha):
698713
"""Apply smoothing to raw counts and recompute log probabilities"""
699-
smoothed_fc = self.feature_count_ + self.alpha
714+
smoothed_fc = self.feature_count_ + alpha
700715
smoothed_cc = smoothed_fc.sum(axis=1)
701716

702717
self.feature_log_prob_ = (np.log(smoothed_fc) -
@@ -796,10 +811,10 @@ def _count(self, X, Y):
796
self.feature_count_ += safe_sparse_dot(Y.T, X)
797812
self.class_count_ += Y.sum(axis=0)
798813

799-
def _update_feature_log_prob(self):
814+
def _update_feature_log_prob(self, alpha):
800815
"""Apply smoothing to raw counts and recompute log probabilities"""
801-
smoothed_fc = self.feature_count_ + self.alpha
802-
smoothed_cc = self.class_count_ + self.alpha * 2
816+
smoothed_fc = self.feature_count_ + alpha
817+
smoothed_cc = self.class_count_ + alpha * 2
803818

804819
self.feature_log_prob_ = (np.log(smoothed_fc) -
805820
np.log(smoothed_cc.reshape(-1, 1)))

sklearn/tests/test_naive_bayes.py

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@
1414
from sklearn.utils.testing import assert_array_almost_equal
1515
from sklearn.utils.testing import assert_equal
1616
from sklearn.utils.testing import assert_raises
17+
from sklearn.utils.testing import assert_raise_message
1718
from sklearn.utils.testing import assert_greater
19+
from sklearn.utils.testing import assert_warns
1820

1921
from sklearn.naive_bayes import GaussianNB, BernoulliNB, MultinomialNB
2022

@@ -480,7 +482,7 @@ def test_feature_log_prob_bnb():
480482
denom = np.tile(np.log(clf.class_count_ + 2.0), (X.shape[1], 1)).T
481483

482484
# Check manual estimate matches
483-
assert_array_equal(clf.feature_log_prob_, (num - denom))
485+
assert_array_almost_equal(clf.feature_log_prob_, (num - denom))
484486

485487

486488
def test_bnb():
@@ -536,3 +538,49 @@ def test_naive_bayes_scale_invariance():
536538
for f in [1E-10, 1, 1E10]]
537539
assert_array_equal(labels[0], labels[1])
538540
assert_array_equal(labels[1], labels[2])
541+
542+
543+
def test_alpha():
544+
# Setting alpha=0 should not output nan results when p(x_i|y_j)=0 is a case
545+
X = np.array([[1, 0], [1, 1]])
546+
y = np.array([0, 1])
547+
nb = BernoulliNB(alpha=0.)
548+
assert_warns(UserWarning, nb.partial_fit, X, y, classes=[0, 1])
549+
assert_warns(UserWarning, nb.fit, X, y)
550+
prob = np.array([[1, 0], [0, 1]])
551+
assert_array_almost_equal(nb.predict_proba(X), prob)
552+
553+
nb = MultinomialNB(alpha=0.)
554+
assert_warns(UserWarning, nb.partial_fit, X, y, classes=[0, 1])
555+
assert_warns(UserWarning, nb.fit, X, y)
556+
prob = np.array([[2./3, 1./3], [0, 1]])
557+
assert_array_almost_equal(nb.predict_proba(X), prob)
558+
559+
# Test sparse X
560+
X = scipy.sparse.csr_matrix(X)
561+
nb = BernoulliNB(alpha=0.)
562+
assert_warns(UserWarning, nb.fit, X, y)
563+
prob = np.array([[1, 0], [0, 1]])
564+
assert_array_almost_equal(nb.predict_proba(X), prob)
565+
566+
nb = MultinomialNB(alpha=0.)
567+
assert_warns(UserWarning, nb.fit, X, y)
568+
prob = np.array([[2./3, 1./3], [0, 1]])
569+
assert_array_almost_equal(nb.predict_proba(X), prob)
570+
571+
# Test for alpha < 0
572+
X = np.array([[1, 0], [1, 1]])
573+
y = np.array([0, 1])
574+
expected_msg = ('Smoothing parameter alpha = -1.0e-01. '
575+
'alpha should be > 0.')
576+
b_nb = BernoulliNB(alpha=-0.1)
577+
m_nb = MultinomialNB(alpha=-0.1)
578+
assert_raise_message(ValueError, expected_msg, b_nb.fit, X, y)
579+
assert_raise_message(ValueError, expected_msg, m_nb.fit, X, y)
580+
581+
b_nb = BernoulliNB(alpha=-0.1)
582+
m_nb = MultinomialNB(alpha=-0.1)
583+
assert_raise_message(ValueError, expected_msg, b_nb.partial_fit,
584+
X, y, classes=[0, 1])
585+
assert_raise_message(ValueError, expected_msg, m_nb.partial_fit,
586+
X, y, classes=[0, 1])

0 commit comments

Comments
 (0)
0