8000 [MRG+1] correct comparison in GaussianNB for 'priors' (#10005) · scikit-learn/scikit-learn@e41c4d5 · GitHub
[go: up one dir, main page]

Skip to content

Commit e41c4d5

Browse files
gxydTomDLT
authored andcommitted
[MRG+1] correct comparison in GaussianNB for 'priors' (#10005)
1 parent 8472350 commit e41c4d5

File tree

2 files changed

+13
-1
lines changed

2 files changed

+13
-1
lines changed

sklearn/naive_bayes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -374,7 +374,7 @@ def _partial_fit(self, X, y, classes=None, _refit=False,
374374
raise ValueError('Number of priors must match number of'
375375
' classes.')
376376
# Check that the sum is 1
377-
if priors.sum() != 1.0:
377+
if not np.isclose(priors.sum(), 1.0):
378378
raise ValueError('The sum of the priors should be 1.')
379379
# Check that the prior are non-negative
380380
if (priors < 0).any():

sklearn/tests/test_naive_bayes.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,18 @@ def test_gnb_priors():
114114
assert_array_almost_equal(clf.class_prior_, np.array([0.3, 0.7]))
115115

116116

117+
def test_gnb_priors_sum_isclose():
118+
# test whether the class prior sum is properly tested"""
119+
X = np.array([[-1, -1], [-2, -1], [-3, -2], [-4, -5], [-5, -4],
120+
[1, 1], [2, 1], [3, 2], [4, 4], [5, 5]])
121+
priors = np.array([0.08, 0.14, 0.03, 0.16, 0.11, 0.16, 0.07, 0.14,
122+
0.11, 0.0])
123+
Y = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
124+
clf = GaussianNB(priors)
125+
# smoke test for issue #9633
126+
clf.fit(X, Y)
127+
128+
117129
def test_gnb_wrong_nb_priors():
118130
""" Test whether an error is raised if the number of prior is different
119131
from the number of class"""

0 commit comments

Comments
 (0)
0