8000 DOC Add var_ attribute and deprecate sigma_ in GaussianNB (#18842) · scikit-learn/scikit-learn@dfc5e16 · GitHub
[go: up one dir, main page]

Skip to content

Commit dfc5e16

Browse files
hongshaoyangjeremiedbbglemaitre
authored
DOC Add var_ attribute and deprecate sigma_ in GaussianNB (#18842)
Co-authored-by: Jérémie du Boisberranger <34657725+jeremiedbb@users.noreply.github.com> Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com>
1 parent 2d95acf commit dfc5e16

File tree

3 files changed

+47
-13
lines changed

3 files changed

+47
-13
lines changed

doc/whats_new/v1.0.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,14 @@ Changelog
6565
- |Enhancement| Validate user-supplied gram matrix passed to linear models
6666
via the `precompute` argument. :pr:`19004` by :user:`Adam Midvidy <amidvidy>`.
6767

68+
:mod:`sklearn.naive_bayes`
69+
..........................
70+
71+
- |API| The attribute ``sigma_`` is now deprecated in
72+
:class:`naive_bayes.GaussianNB` and will be removed in 1.2.
73+
Use ``var_`` instead.
74+
:pr:`18842` by :user:`Hong Shao Yang <hongshaoyang>`.
75+
6876
Code and Documentation Contributors
6977
-----------------------------------
7078

sklearn/naive_bayes.py

Lines changed: 25 additions & 8 deletions

Original file line numberDiff line numberDiff line change
@@ -154,7 +154,16 @@ class labels known to the classifier
154154
absolute additive value to variances
155155
156156
sigma_ : ndarray of shape (n_classes, n_features)
157-
variance of each feature per class
157+
Variance of each feature per class.
158+
159+
.. deprecated:: 1.0
160+
`sigma_` is deprecated in 1.0 and will be removed in 1.2.
161+
Use `var_` instead.
162+
163+
var_ : ndarray of shape (n_classes, n_features)
164+
Variance of each feature per class.
165+
166+
.. versionadded:: 1.0
158167
159168
theta_ : ndarray of shape (n_classes, n_features)
160169
mean of each feature per class
@@ -377,7 +386,7 @@ def _partial_fit(self, X, y, classes=None, _refit=False,
377386
n_features = X.shape[1]
378387
n_classes = len(self.classes_)
379388
self.theta_ = np.zeros((n_classes, n_features))
380-
self.sigma_ = np.zeros((n_classes, n_features))
389+
self.var_ = np.zeros((n_classes, n_features))
381390

382391
self.class_count_ = np.zeros(n_classes, dtype=np.float64)
383392
@@ -405,7 +414,7 @@ def _partial_fit(self, X, y, classes=None, _refit=False,
405414
msg = "Number of features %d does not match previous data %d."
406415
raise ValueError(msg % (X.shape[1], self.theta_.shape[1]))
407416
# Put epsilon back in each time
408-
self.sigma_[:, :] -= self.epsilon_
417+
self.var_[:, :] -= self.epsilon_
409418

410419
classes = self.classes_
411420

@@ -429,14 +438,14 @@ def _partial_fit(self, X, y, classes=None, _refit=False,
429438
N_i = X_i.shape[0]
430439

431440
new_theta, new_sigma = self._update_mean_variance(
432-
self.class_count_[i], self.theta_[i, :], self.sigma_[i, :],
441+
self.class_count_[i], self.theta_[i, :], self.var_[i, :],
433442
X_i, sw_i)
434443

435444
self.theta_[i, :] = new_theta
436-
self.sigma_[i, :] = new_sigma
445+
self.var_[i, :] = new_sigma
437446
self.class_count_[i] += N_i
438447

439-
self.sigma_[:, :] += self.epsilon_
448+
self.var_[:, :] += self.epsilon_
440449

441450
# Update if only no priors is provided
442451
if self.priors is None:
@@ -449,14 +458,22 @@ def _joint_log_likelihood(self, X):
449458
joint_log_likelihood = []
450459
for i in range(np.size(self.classes_)):
451460
jointi = np.log(self.class_prior_[i])
452-
n_ij = - 0.5 * np.sum(np.log(2. * np.pi * self.sigma_[i, :]))
461+
n_ij = - 0.5 * np.sum(np.log(2. * np.pi * self.var_[i, :]))
453462
n_ij -= 0.5 * np.sum(((X - self.theta_[i, :]) ** 2) /
454-
(self.sigma_[i, :]), 1)
463+
(self.var_[i, :]), 1)
455464
joint_log_likelihood.append(jointi + n_ij)
456465

457466
joint_log_likelihood = np.array(joint_log_likelihood).T
458467
return joint_log_likelihood
459468

469+
@deprecated( # type: ignore
470+
"Attribute sigma_ was deprecated in 1.0 and will be removed in"
471+
"1.2. Use var_ instead."
472+
)
473+
@property
474+
def sigma_(self):
475+
return self.var_
476+
460477

461478
_ALPHA_MIN = 1e-10
462479

sklearn/tests/test_naive_bayes.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,15 @@ def test_gnb():
5858
assert_raises(ValueError, GaussianNB().partial_fit, X, y, classes=[0, 1])
5959

6060

61+
# TODO remove in 1.2 once sigma_ attribute is removed (GH #18842)
62+
def test_gnb_var():
63+
clf = GaussianNB()
64+
clf.fit(X, y)
65+
66+
with pytest.warns(FutureWarning, match="Attribute sigma_ was deprecated"):
67+
assert_array_equal(clf.sigma_, clf.var_)
68+
69+
6170
def test_gnb_prior():
6271
# Test whether class priors are properly set.
6372
clf = GaussianNB().fit(X, y)
@@ -76,7 +85,7 @@ def test_gnb_sample_weight():
7685
clf_sw = GaussianNB().fit(X, y, sw)
7786

7887
assert_array_almost_equal(clf.theta_, clf_sw.theta_)
79-
assert_array_almost_equal(clf.sigma_, clf_sw.sigma_)
88+
assert_array_almost_equal(clf.var_, clf_sw.var_)
8089

8190
# Fitting twice with half sample-weights should result
8291
# in same result as fitting once with full weights
@@ -86,7 +95,7 @@ def test_gnb_sample_weight():
8695
clf2.partial_fit(X, y, sample_weight=sw / 2)
8796

8897
assert_array_almost_equal(clf1.theta_, clf2.theta_)
89-
assert_array_almost_equal(clf1.sigma_, clf2.sigma_)
98+
assert_array_almost_equal(clf1.var_, clf2.var_)
9099

91100
# Check that duplicate entries and correspondingly increased sample
92101
# weights yield the same result
@@ -97,7 +106,7 @@ def test_gnb_sample_weight():
97106
clf_sw = GaussianNB().fit(X, y, sample_weight)
98107

99108
assert_array_almost_equal(clf_dupl.theta_, clf_sw.theta_)
100-
assert_array_almost_equal(clf_dupl.sigma_, clf_sw.sigma_)
109+
assert_array_almost_equal(clf_dupl.var_, clf_sw.var_)
101110

102111

103112
def test_gnb_neg_priors():
@@ -174,13 +183,13 @@ def test_gnb_partial_fit():
174183
clf = GaussianNB().fit(X, y)
175184
clf_pf = GaussianNB().partial_fit(X, y, np.unique(y))
176185
assert_array_almost_equal(clf.theta_, clf_pf.theta_)
177-
assert_array_almost_equal(clf.sigma_, clf_pf.sigma_)
186+
assert_array_almost_equal(clf.var_, clf_pf.var_)
178187
assert_array_almost_equal(clf.class_prior_, clf_pf.class_prior_)
179188

180189
clf_pf2 = GaussianNB().partial_fit(X[0::2, :], y[0::2], np.unique(y))
181190
clf_pf2.partial_fit(X[1::2], y[1::2])
182191
assert_array_almost_equal(clf.theta_, clf_pf2.theta_)
183-
assert_array_almost_equal(clf.sigma_, clf_pf2.sigma_)
192+
assert_array_almost_equal(clf.var_, clf_pf2.var_)
184193
assert_array_almost_equal(clf.class_prior_, clf_pf2.class_prior_)
185194

186195

0 commit comments

Comments
 (0)
0