|
38 | 38 | from ..utils.extmath import stable_cumsum
|
39 | 39 | from ..metrics import accuracy_score, r2_score
|
40 | 40 | from ..utils.validation import check_is_fitted
|
| 41 | +from ..utils.validation import _check_sample_weight |
41 | 42 | from ..utils.validation import has_fit_parameter
|
42 | 43 | from ..utils.validation import _num_samples
|
43 | 44 |
|
@@ -117,20 +118,11 @@ def fit(self, X, y, sample_weight=None):
|
117 | 118 |
|
118 | 119 | X, y = self._validate_data(X, y)
|
119 | 120 |
|
120 |
| - if sample_weight is None: |
121 |
| - # Initialize weights to 1 / n_samples |
122 |
| - sample_weight = np.empty(_num_samples(X), dtype=np.float64) |
123 |
| - sample_weight[:] = 1. / _num_samples(X) |
124 |
| - else: |
125 |
| - sample_weight = check_array(sample_weight, ensure_2d=False) |
126 |
| - # Normalize existing weights |
127 |
| - sample_weight = sample_weight / sample_weight.sum(dtype=np.float64) |
128 |
| - |
129 |
| - # Check that the sample weights sum is positive |
130 |
| - if sample_weight.sum() <= 0: |
131 |
| - raise ValueError( |
132 |
| - "Attempting to fit with a non-positive " |
133 |
| - "weighted number of samples.") |
| 121 | + sample_weight = _check_sample_weight(sample_weight, X, np.float64) |
| 122 | + sample_weight /= sample_weight.sum() |
| 123 | + if sample_weight.sum() <= 0: |
| 124 | + raise ValueError("Attempting to fit with a non-positive weighted " |
| 125 | + "number of samples.") |
134 | 126 |
|
135 | 127 | # Check parameters
|
136 | 128 | self._validate_estimator()
|
|
0 commit comments