@@ -159,27 +159,7 @@ def fit(self, X, y):
159
159
self : object
160
160
Returns self.
161
161
"""
162
-
163
- X , y = check_arrays (X , y , sparse_format = 'dense' )
164
- y = column_or_1d (y , warn = True )
165
-
166
- n_samples , n_features = X .shape
167
-
168
- self .classes_ = unique_y = np .unique (y )
169
- n_classes = unique_y .shape [0 ]
170
-
171
- self .theta_ = np .zeros ((n_classes , n_features ))
172
- self .sigma_ = np .zeros ((n_classes , n_features ))
173
- self .class_prior_ = np .zeros (n_classes )
174
- self .class_count_ = np .zeros (n_classes )
175
- epsilon = 1e-9
176
- for i , y_i in enumerate (unique_y ):
177
- Xi = X [y == y_i , :]
178
- self .theta_ [i , :] = np .mean (Xi , axis = 0 )
179
- self .sigma_ [i , :] = np .var (Xi , axis = 0 ) + epsilon
180
- self .class_count_ [i ] = Xi .shape [0 ]
181
- self .class_prior_ [:] = self .class_count_ / n_samples
182
- return self
162
+ return self ._partial_fit (X , y , np .unique (y ), _refit = True )
183
163
184
164
@staticmethod
185
165
def _update_mean_variance (n_past , mu , var , X ):
@@ -270,10 +250,43 @@ def partial_fit(self, X, y, classes=None):
270
250
self : object
271
251
Returns self.
272
252
"""
253
+ return self ._partial_fit (X , y , classes , _refit = False )
254
+
255
+ def _partial_fit (self , X , y , classes = None , _refit = False ):
256
+ """Actual implementation of Gaussian NB fitting.
257
+
258
+ Parameters
259
+ ----------
260
+ X : array-like, shape (n_samples, n_features)
261
+ Training vectors, where n_samples is the number of samples and
262
+ n_features is the number of features.
263
+
264
+ y : array-like, shape (n_samples,)
265
+ Target values.
266
+
267
+ classes : array-like, shape (n_classes,), optional (default=None)
268
+ List of all the classes that can possibly appear in the y vector.
269
+
270
+ Must be provided at the first call to partial_fit, can be omitted
271
+ in subsequent calls.
272
+
273
+ _refit : bool, optional (default=False)
274
+ If True, act as though this were the first time we called
275
+ _partial_fit (i.e., throw away any past fitting and start over).
276
+
277
+ Returns
278
+ -------
279
+ self : object
280
+ Returns self.
281
+ """
282
+
273
283
X , y = check_arrays (X , y , sparse_format = 'dense' )
274
284
y = column_or_1d (y , warn = True )
275
285
epsilon = 1e-9
276
286
287
+ if _refit :
288
+ self .classes_ = None
289
+
277
290
if _check_partial_fit_first_call (self , classes ):
278
291
# This is the first call to partial_fit:
279
292
# initialize various cumulative counters
0 commit comments