8000 ENH use utility _check_sample_weight in _BaseDiscreteNB (#16263) · thomasjpfan/scikit-learn@fd8e738 · GitHub
[go: up one dir, main page]

Skip to content

Commit fd8e738

Browse files
Batalexthomasjpfan
authored andcommitted
ENH use utility _check_sample_weight in _BaseDiscreteNB (scikit-learn#16263)
1 parent d05ccaf commit fd8e738

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

sklearn/naive_bayes.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -569,8 +569,9 @@ def partial_fit(self, X, y, classes=None, sample_weight=None):
569569
# We convert it to np.float64 to support sample_weight consistently
570570
Y = Y.astype(np.float64, copy=False)
571571
if sample_weight is not None:
572+
sample_weight = _check_sample_weight(sample_weight, X)
572573
sample_weight = np.atleast_2d(sample_weight)
573-
Y *= check_array(sample_weight).T
574+
Y *= sample_weight.T
574575

575576
class_prior = self.class_prior
576577

@@ -621,9 +622,9 @@ def fit(self, X, y, sample_weight=None):
621622
# this means we also don't have to cast X to floating point
622623
if sample_weight is not None:
623624
Y = Y.astype(np.float64, copy=False)
624-
sample_weight = np.asarray(sample_weight)
625+
sample_weight = _check_sample_weight(sample_weight, X)
625626
sample_weight = np.atleast_2d(sample_weight)
626-
Y *= check_array(sample_weight).T
627+
Y *= sample_weight.T
627628

628629
class_prior = self.class_prior
629630

0 commit comments

Comments
 (0)
0