8000 WIP simplify naive Bayes parametrization by amueller · Pull Request #1525 · scikit-learn/scikit-learn · GitHub

WIP simplify naive Bayes parametrization #1525

Closed
71 changes: 33 additions & 38 deletions sklearn/naive_bayes.py
@@ -25,7 +25,7 @@
 from .preprocessing import binarize, LabelBinarizer
 from .utils import array2d, atleast2d_or_csr
 from .utils.extmath import safe_sparse_dot, logsumexp
-from .utils import check_arrays
+from .utils import check_arrays, compute_class_weight

 __all__ = ['BernoulliNB', 'GaussianNB', 'MultinomialNB']

@@ -211,10 +211,6 @@ def fit(self, X, y, sample_weight=None, class_prior=None):
         sample_weight : array-like, shape = [n_samples], optional
             Weights applied to individual samples (1. for unweighted).

-        class_prior : array, shape [n_classes]
-            Custom prior probability per class.
-            Overrides the fit_prior parameter.
-
         Returns
         -------
         self : object
@@ -225,7 +221,6 @@ def fit(self, X, y, sample_weight=None, class_prior=None):
         labelbin = LabelBinarizer()
         Y = labelbin.fit_transform(y)
         self.classes_ = labelbin.classes_
-        n_classes = len(self.classes_)
         if Y.shape[1] == 1:
             Y = np.concatenate((1 - Y, Y), axis=1)
@@ -243,20 +238,22 @@ def fit(self, X, y, sample_weight=None, class_prior=None):
             warnings.warn('class_prior is deprecated in fit function and will '
                           'be removed in version 0.15. Use the `__init__` '
                           'parameter class_weight instead.')
+            class_weight = class_prior
         else:
-            class_prior = self.class_weight
-
-        if class_prior:
-            if len(class_prior) != n_classes:
-                raise ValueError("Number of priors must match number of"
-                                 " classes.")
-            self.class_log_prior_ = np.log(class_prior)
-        elif self.fit_prior:
-            # empirical prior, with sample_weight taken into account
-            y_freq = Y.sum(axis=0)
-            self.class_log_prior_ = np.log(y_freq) - np.log(y_freq.sum())
-        else:
-            self.class_log_prior_ = np.zeros(n_classes) - np.log(n_classes)
+            class_weight = self.class_weight
+
+        if self.fit_prior is not None:
+            warnings.warn('fit_prior is deprecated in fit function and will '
+                          'be removed in version 0.15. Use the `__init__` '
+                          'parameter class_weight instead.')
+            if self.fit_prior is False:
+                class_prior = None
+            else:
+                class_prior = 'auto'
+
+        class_weight = compute_class_weight(class_weight, self.classes_, y)
+        class_weight /= class_weight.sum()
+        self.class_log_prior_ = np.log(class_weight)

         # N_c_i is the count of feature i in all samples of class c.
         # N_c is the denominator.
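
However class_weight is resolved, the prior computation above now reduces to normalize-and-log: the resolved per-class weights are scaled to sum to one and their logarithm is stored. A minimal numeric sketch of those last two lines, with weights taken from the explicit-prior test further down:

    import numpy as np

    # resolved per-class weights, e.g. as returned by compute_class_weight
    class_weight = np.array([.5, .5])
    class_weight /= class_weight.sum()       # normalize to a probability distribution
    class_log_prior_ = np.log(class_weight)  # log-priors stored on the estimator
    print(np.exp(class_log_prior_))          # [ 0.5  0.5]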
@@ -295,13 +292,12 @@ class MultinomialNB(BaseDiscreteNB):
         Additive (Laplace/Lidstone) smoothing parameter
         (0 for no smoothing).

-    fit_prior : boolean
-        Whether to learn class prior probabilities or not.
-        If false, a uniform prior will be used.
-
-    class_weight : array-like, size=[n_classes,]
-        Prior probabilities of the classes. If specified the priors are not
-        adjusted according to the data.
+    class_weight : 'auto', None or array-like, default='auto'
+        Prior probabilities of the classes. 'auto' means class priors
+        are estimated from the data, None means uniform priors
+        over all classes.
+        If an array is given, it must have shape [n_classes,] and
+        explicitly specify the class priors.

     Attributes
     ----------
@@ -324,7 +320,7 @@ class MultinomialNB(BaseDiscreteNB):
     >>> from sklearn.naive_bayes import MultinomialNB
     >>> clf = MultinomialNB()
     >>> clf.fit(X, Y)
-    MultinomialNB(alpha=1.0, class_weight=None, fit_prior=True)
+    MultinomialNB(alpha=1.0, class_weight='auto', fit_prior=None)
     >>> print(clf.predict(X[2]))
     [3]

@@ -335,7 +331,7 @@ class MultinomialNB(BaseDiscreteNB):
     Tackling the poor assumptions of naive Bayes text classifiers, ICML.
     """

-    def __init__(self, alpha=1.0, fit_prior=True, class_weight=None):
+    def __init__(self, alpha=1.0, fit_prior=None, class_weight='auto'):
         self.alpha = alpha
         self.fit_prior = fit_prior
         self.class_weight = class_weight
@@ -373,13 +369,12 @@ class BernoulliNB(BaseDiscreteNB):
         Threshold for binarizing (mapping to booleans) of sample features.
         If None, input is presumed to already consist of binary vectors.

-    fit_prior : boolean
-        Whether to learn class prior probabilities or not.
-        If false, a uniform prior will be used.
-
-    class_weight : array-like, size=[n_classes,]
-        Prior probabilities of the classes. If specified the priors are not
-        adjusted according to the data.
+    class_weight : 'auto', None or array-like, default='auto'
+        Prior probabilities of the classes. 'auto' means class priors
+        are estimated from the data, None means uniform priors
+        over all classes.
+        If an array is given, it must have shape [n_classes,] and
+        explicitly specify the class priors.

     Attributes
     ----------
@@ -397,7 +392,7 @@ class BernoulliNB(BaseDiscreteNB):
     >>> from sklearn.naive_bayes import BernoulliNB
     >>> clf = BernoulliNB()
     >>> clf.fit(X, Y)
-    BernoulliNB(alpha=1.0, binarize=0.0, class_weight=None, fit_prior=True)
+    BernoulliNB(alpha=1.0, binarize=0.0, class_weight='auto', fit_prior=None)
     >>> print(clf.predict(X[2]))
     [3]

@@ -415,8 +410,8 @@ class BernoulliNB(BaseDiscreteNB):
     naive Bayes -- Which naive Bayes? 3rd Conf. on Email and Anti-Spam (CEAS).
     """

-    def __init__(self, alpha=1.0, binarize=.0, fit_prior=True,
-                 class_weight=None):
+    def __init__(self, alpha=1.0, binarize=.0, fit_prior=None,
+                 class_weight='auto'):
         self.alpha = alpha
         self.binarize = binarize
         self.fit_prior = fit_prior
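
Net effect of the naive_bayes.py changes: fit_prior and the fit-time class_prior argument fold into the single class_weight constructor parameter. A hedged usage sketch against this PR's branch (the proposed API, not a released scikit-learn interface):

    from sklearn.naive_bayes import BernoulliNB, MultinomialNB

    X = [[2, 1], [1, 3], [0, 4]]
    y = [0, 0, 1]

    MultinomialNB(class_weight='auto').fit(X, y)    # priors derived from the data
    MultinomialNB(class_weight=None).fit(X, y)      # uniform prior over all classes
    BernoulliNB(class_weight=[.5, .5]).fit(X, y)    # explicit per-class prior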
3 changes: 1 addition & 2 deletions sklearn/tests/test_common.py
@@ -721,8 +721,7 @@ def test_class_weight_auto_classifies():
             continue

         if name.endswith("NB"):
-            # NaiveBayes classifiers have a somewhat differnt interface.
-            # FIXME SOON!
+            # naive bayes classifiers don't work on this kind of data :(
             continue

         with warnings.catch_warnings(record=True):
9 changes: 4 additions & 5 deletions sklearn/tests/test_naive_bayes.py
@@ -138,11 +138,10 @@ def test_discretenb_predict_proba():

 def test_discretenb_uniform_prior():
     """Test whether discrete NB classes fit a uniform prior
-    when fit_prior=False and class_prior=None"""
+    when class_weight=None"""

     for cls in [BernoulliNB, MultinomialNB]:
-        clf = cls()
-        clf.set_params(fit_prior=False)
+        clf = cls(class_weight=None)
         clf.fit([[0], [0], [1]], [0, 0, 1])
         prior = np.exp(clf.class_log_prior_)
         assert_array_equal(prior, np.array([.5, .5]))
@@ -152,8 +151,8 @@ def test_discretenb_provide_prior():
     """Test whether discrete NB classes use provided prior"""

     for cls in [BernoulliNB, MultinomialNB]:
-        clf = cls()
-        clf.fit([[0], [0], [1]], [0, 0, 1], class_prior=[0.5, 0.5])
+        clf = cls(class_weight=[.5, .5])
+        clf.fit([[0], [0], [1]], [0, 0, 1])
         prior = np.exp(clf.class_log_prior_)
         assert_array_equal(prior, np.array([.5, .5]))

13 changes: 8 additions & 5 deletions sklearn/utils/class_weight.py
@@ -37,17 +37,20 @@ def compute_class_weight(class_weight, classes, y):
         weight = np.array([1.0 / np.sum(y == i) for i in classes],
                           dtype=np.float64, order='C')
         weight *= classes.shape[0] / np.sum(weight)
-    else:
+    elif isinstance(class_weight, dict):
         # user-defined dictionary
         weight = np.ones(classes.shape[0], dtype=np.float64, order='C')
-        if not isinstance(class_weight, dict):
-            raise ValueError("class_weight must be dict, 'auto', or None,"
-                             " got: %r" % class_weight)
         for c in class_weight:
             i = np.searchsorted(classes, c)
             if classes[i] != c:
                 raise ValueError("Class label %d not present." % c)
             else:
                 weight[i] = class_weight[c]
-
+    else:
+        # user-specified array or list
+        weight = np.array(class_weight)
+        if len(weight) != len(classes):
+            raise ValueError("The number of entries in class_weight is %d, "
+                             "which does not match the number of classes %d."
+                             % (len(weight), len(classes)))
     return weight
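
With the new else branch, compute_class_weight accepts a plain array or list of per-class weights in addition to 'auto' and a dict. A hedged sketch of the resulting modes, with return values worked out from the code above (assuming the sklearn.utils import path this PR uses in naive_bayes.py):

    import numpy as np
    from sklearn.utils import compute_class_weight

    classes = np.array([0, 1])
    y = np.array([0, 0, 1])

    compute_class_weight('auto', classes, y)        # inverse-frequency: [1/2, 1/1],
                                                    # rescaled to [2/3, 4/3]
    compute_class_weight({0: 1, 1: 3}, classes, y)  # dict keyed by class label -> [1., 3.]
    compute_class_weight([1, 3], classes, y)        # new: array/list passes through -> [1, 3]
    compute_class_weight([1, 2, 3], classes, y)     # raises ValueError: 3 entries, 2 classes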