10000 Merge pull request #3344 from ihaque/remove_dup_fit · scikit-learn/scikit-learn@1775095 · GitHub
[go: up one dir, main page]

Skip to content

Commit 1775095

Browse files
committed
Merge pull request #3344 from ihaque/remove_dup_fit
Remove duplicate GaussianNB.fit() code
2 parents 08b2902 + 2d12952 commit 1775095

File tree

2 files changed

+36
-23
lines changed

2 files changed

+36
-23
lines changed

sklearn/naive_bayes.py

Lines changed: 34 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -159,27 +159,7 @@ def fit(self, X, y):
159159
self : object
160160
Returns self.
161161
"""
162-
163-
X, y = check_arrays(X, y, sparse_format='dense')
164-
y = column_or_1d(y, warn=True)
165-
166-
n_samples, n_features = X.shape
167-
168-
self.classes_ = unique_y = np.unique(y)
169-
n_classes = unique_y.shape[0]
170-
171-
self.theta_ = np.zeros((n_classes, n_features))
172-
self.sigma_ = np.zeros((n_classes, n_features))
173-
self.class_prior_ = np.zeros(n_classes)
174-
self.class_count_ = np.zeros(n_classes)
175-
epsilon = 1e-9
176-
for i, y_i in enumerate(unique_y):
177-
Xi = X[y == y_i, :]
178-
self.theta_[i, :] = np.mean(Xi, axis=0)
179-
self.sigma_[i, :] = np.var(Xi, axis=0) + epsilon
180-
self.class_count_[i] = Xi.shape[0]
181-
self.class_prior_[:] = self.class_count_ / n_samples
182-
return self
162+
return self._partial_fit(X, y, np.unique(y), _refit=True)
183163

184164
@staticmethod
185165
def _update_mean_variance(n_past, mu, var, X):
@@ -270,10 +250,43 @@ def partial_fit(self, X, y, classes=None):
270250
self : object
271251
Returns self.
272252
"""
253+
return self._partial_fit(X, y, classes, _refit=False)
254+
255+
def _partial_fit(self, X, y, classes=None, _refit=False):
256+
"""Actual implementation of Gaussian NB fitting.
257+
258+
Parameters
259+
----------
260+
X : array-like, shape (n_samples, n_features)
261+
Training vectors, where n_samples is the number of samples and
262+
n_features is the number of features.
263+
264+
y : array-like, shape (n_samples,)
265+
Target values.
266+
267+
classes : array-like, shape (n_classes,)
268+
List of all the classes that can possibly appear in the y vector.
269+
270+
Must be provided at the first call to partial_fit, can be omitted
271+
in subsequent calls.
272+
273+
_refit: bool
274+
If true, act as though this were the first time we called
275+
_partial_fit (ie, throw away any past fitting and start over).
276+
277+
Returns
278+
-------
279+
self : object
280+
Returns self.
281+
"""
282+
273283
X, y = check_arrays(X, y, sparse_format='dense')
274284
y = column_or_1d(y, warn=True)
275285
epsilon = 1e-9
276286

287+
if _refit:
288+
self.classes_ = None
289+
277290
if _check_partial_fit_first_call(self, classes):
278291
# This is the first call to partial_fit:
279292
# initialize various cumulative counters

sklearn/utils/multiclass.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -310,8 +310,8 @@ def type_of_target(y):
310310
def _check_partial_fit_first_call(clf, classes=None):
311311
"""Private helper function for factorizing common classes param logic
312312
313-
Estimator that implement the ``partial_fit`` API need to be provided with
314-
the list of possible classes at the first call to partial fit.and
313+
Estimators that implement the ``partial_fit`` API need to be provided with
314+
the list of possible classes at the first call to partial_fit.
315315
316316
Subsequent calls to partial_fit should check that ``classes`` is still
317317
consistent with a previous value of ``clf.classes_`` when provided.

0 commit comments

Comments
 (0)
0