8000 MAINT Use _check_sample_weight in GaussianNB (#15480) · scikit-learn/scikit-learn@6b6fa5d · GitHub
[go: up one dir, main page]

Skip to content

Commit 6b6fa5d

Browse files
Marie Douriezrth
Marie Douriez
authored andcommitted
MAINT Use _check_sample_weight in GaussianNB (#15480)
1 parent 452d919 commit 6b6fa5d

File tree

1 file changed

+3
-5
lines changed

1 file changed

+3
-5
lines changed

sklearn/naive_bayes.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,18 +21,17 @@
2121

2222

2323
import numpy as np
24-
from scipy.sparse import issparse
2524

2625
from .base import BaseEstimator, ClassifierMixin
2726
from .preprocessing import binarize
2827
from .preprocessing import LabelBinarizer
2928
from .preprocessing import label_binarize
30-
from .utils import check_X_y, check_array, check_consistent_length
29+
from .utils import check_X_y, check_array, deprecated
3130
from .utils.extmath import safe_sparse_dot
3231
from .utils.fixes import logsumexp
3332
from .utils.multiclass import _check_partial_fit_first_call
3433
from .utils.validation import check_is_fitted, check_non_negative, column_or_1d
35-
from .utils import deprecated
34+
from .utils.validation import _check_sample_weight
3635

3736
__all__ = ['BernoulliNB', 'GaussianNB', 'MultinomialNB', 'ComplementNB',
3837
'CategoricalNB']
@@ -360,8 +359,7 @@ def _partial_fit(self, X, y, classes=None, _refit=False,
360359
"""
361360
X, y = check_X_y(X, y)
362361
if sample_weight is not None:
363-
sample_weight = check_array(sample_weight, ensure_2d=False)
364-
check_consistent_length(y, sample_weight)
362+
sample_weight = _check_sample_weight(sample_weight, X)
365363

366364
# If the ratio of data variance between dimensions is too small, it
367365
# will cause numerical errors. To address this, we artificially

0 commit comments

Comments
 (0)
0