|
32 | 32 | from ..utils import Bunch
|
33 | 33 | from ..utils import check_array
|
34 | 34 | from ..utils import check_random_state
|
| 35 | +from ..utils.validation import _check_sample_weight |
35 | 36 | from ..utils import compute_sample_weight
|
36 | 37 | from ..utils.multiclass import check_classification_targets
|
37 | 38 | from ..utils.validation import check_is_fitted
|
@@ -266,18 +267,7 @@ def fit(self, X, y, sample_weight=None, check_input=True,
|
266 | 267 | "or larger than 1").format(max_leaf_nodes))
|
267 | 268 |
|
268 | 269 | if sample_weight is not None:
|
269 |
| - if (getattr(sample_weight, "dtype", None) != DOUBLE or |
270 |
| - not sample_weight.flags.contiguous): |
271 |
| - sample_weight = np.ascontiguousarray( |
272 |
| - sample_weight, dtype=DOUBLE) |
273 |
| - if len(sample_weight.shape) > 1: |
274 |
| - raise ValueError("Sample weights array has more " |
275 |
| - "than one dimension: %d" % |
276 |
| - len(sample_weight.shape)) |
277 |
| - if len(sample_weight) != n_samples: |
278 |
| - raise ValueError("Number of weights=%d does not match " |
279 |
| - "number of samples=%d" % |
280 |
| - (len(sample_weight), n_samples)) |
| 270 | + sample_weight = _check_sample_weight(sample_weight, X, DOUBLE) |
281 | 271 |
|
282 | 272 | if expanded_class_weight is not None:
|
283 | 273 | if sample_weight is not None:
|
|
0 commit comments