|
21 | 21 |
|
22 | 22 |
|
23 | 23 | import numpy as np
|
24 |
| -from scipy.sparse import issparse |
25 | 24 |
|
26 | 25 | from .base import BaseEstimator, ClassifierMixin
|
27 | 26 | from .preprocessing import binarize
|
28 | 27 | from .preprocessing import LabelBinarizer
|
29 | 28 | from .preprocessing import label_binarize
|
30 |
| -from .utils import check_X_y, check_array, check_consistent_length |
| 29 | +from .utils import check_X_y, check_array, deprecated |
31 | 30 | from .utils.extmath import safe_sparse_dot
|
32 | 31 | from .utils.fixes import logsumexp
|
33 | 32 | from .utils.multiclass import _check_partial_fit_first_call
|
34 | 33 | from .utils.validation import check_is_fitted, check_non_negative, column_or_1d
|
35 |
| -from .utils import deprecated |
| 34 | +from .utils.validation import _check_sample_weight |
36 | 35 |
|
37 | 36 | __all__ = ['BernoulliNB', 'GaussianNB', 'MultinomialNB', 'ComplementNB',
|
38 | 37 | 'CategoricalNB']
|
@@ -360,8 +359,7 @@ def _partial_fit(self, X, y, classes=None, _refit=False,
|
360 | 359 | """
|
361 | 360 | X, y = check_X_y(X, y)
|
362 | 361 | if sample_weight is not None:
|
363 |
| - sample_weight = check_array(sample_weight, ensure_2d=False) |
364 |
| - check_consistent_length(y, sample_weight) |
| 362 | + sample_weight = _check_sample_weight(sample_weight, X) |
365 | 363 |
|
366 | 364 | # If the ratio of data variance between dimensions is too small, it
|
367 | 365 | # will cause numerical errors. To address this, we artificially
|
|
0 commit comments