8000 ENH use utility _check_sample_weight in _BaseDiscreteNB (#16263) · scikit-learn/scikit-learn@84628b0 · GitHub
[go: up one dir, main page]

Skip to content

Commit 84628b0

Browse files
authored
ENH use utility _check_sample_weight in _BaseDiscreteNB (#16263)
1 parent 20a431f commit 84628b0

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

sklearn/naive_bayes.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -569,8 +569,9 @@ def partial_fit(self, X, y, classes=None, sample_weight=None):
569569
# We convert it to np.float64 to support sample_weight consistently
570570
Y = Y.astype(np.float64, copy=False)
571571
if sample_weight is not None:
572+
sample_weight = _check_sample_weight(sample_weight, X)
572573
sample_weight = np.atleast_2d(sample_weight)
573-
Y *= check_array(sample_weight).T
574+
Y *= sample_weight.T
574575

575576
class_prior = self.class_prior
576577

@@ -621,9 +622,9 @@ def fit(self, X, y, sample_weight=None):
621622
# this means we also don't have to cast X to floating point
622623
if sample_weight is not None:
623624
Y = Y.astype(np.float64, copy=False)
624-
sample_weight = np.asarray(sample_weight)
625+
sample_weight = _check_sample_weight(sample_weight, X)
625626
sample_weight = np.atleast_2d(sample_weight)
626-
Y *= check_array(sample_weight).T
627+
Y *= sample_weight.T
627628

628629
class_prior = self.class_prior
629630

0 commit comments

Comments
 (0)
0