FIX use _check_sample_weight to validate sample_weight · scikit-learn/scikit-learn@ede568e · GitHub
[go: up one dir, main page]

Skip to content

Commit ede568e

Browse files
committed
FIX use _check_sample_weight to validate sample_weight
1 parent ea46a7c commit ede568e

File tree

2 files changed

8 additions and 16 deletions

2 files changed

+8
-16
lines changed

sklearn/ensemble/weight_boosting.py

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
from ..utils.extmath import stable_cumsum
3939
from ..metrics import accuracy_score, r2_score
4040
from ..utils.validation import check_is_fitted
41+
from ..utils.validation import _check_sample_weight
4142
from ..utils.validation import has_fit_parameter
4243
from ..utils.validation import _num_samples
4344

@@ -117,20 +118,11 @@ def fit(self, X, y, sample_weight=None):
117118

118119
X, y = self._validate_data(X, y)
119120

120-
if sample_weight is None:
121-
# Initialize weights to 1 / n_samples
122-
sample_weight = np.empty(_num_samples(X), dtype=np.float64)
123-
sample_weight[:] = 1. / _num_samples(X)
124-
else:
125-
sample_weight = check_array(sample_weight, ensure_2d=False)
126-
# Normalize existing weights
127-
sample_weight = sample_weight / sample_weight.sum(dtype=np.float64)
128-
129-
# Check that the sample weights sum is positive
130-
if sample_weight.sum() <= 0:
131-
raise ValueError(
132-
"Attempting to fit with a non-positive "
133-
"weighted number of samples.")
121+
sample_weight = _check_sample_weight(sample_weight, X, np.float64)
122+
sample_weight /= sample_weight.sum()
123+
if sample_weight.sum() <= 0:
124+
raise ValueError("Attempting to fit with a non-positive weighted "
125+
"number of samples.")
134126

135127
# Check parameters
136128
self._validate_estimator()

sklearn/utils/validation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1043,8 +1043,8 @@ def _check_sample_weight(sample_weight, X, dtype=None):
10431043
if dtype is None:
10441044
dtype = [np.float64, np.float32]
10451045
sample_weight = check_array(
1046-
sample_weight, accept_sparse=False,
1047-
ensure_2d=False, dtype=dtype, order="C"
1046+
sample_weight, accept_sparse=False, ensure_2d=False, dtype=dtype,
1047+
order="C"
10481048
)
10491049
if sample_weight.ndim != 1:
10501050
raise ValueError("Sample weights must be 1D array or scalar")

0 commit comments

Comments
 (0)
0