MNT Standardize sample weights validation in BaseDecisionTree (#15495) · scikit-learn/scikit-learn@bcb8eda · GitHub
[go: up one dir, main page]

Skip to content
8000

Commit bcb8eda

Browse files
fbchow, salliewalecka
authored and committed
MNT Standardize sample weights validation in BaseDecisionTree (#15495)
* Standardize sample weights validation in BaseDecisionTree
  Co-authored-by: Sallie Walecka <sallie.walecka@gmail.com>
* Use DOUBLE var instead of np.float64 & refactored test
  Co-authored-by: Sallie Walecka <sallie.walecka@gmail.com>
1 parent ffc7a47 commit bcb8eda

File tree

2 files changed

+4
-14
lines changed

2 files changed

+4
-14
lines changed

sklearn/tree/_classes.py

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from ..utils import Bunch
3333
from ..utils import check_array
3434
from ..utils import check_random_state
35+
from ..utils.validation import _check_sample_weight
3536
from ..utils import compute_sample_weight
3637
from ..utils.multiclass import check_classification_targets
3738
from ..utils.validation import check_is_fitted
@@ -266,18 +267,7 @@ def fit(self, X, y, sample_weight=None, check_input=True,
266267
"or larger than 1").format(max_leaf_nodes))
267268

268269
if sample_weight is not None:
269-
if (getattr(sample_weight, "dtype", None) != DOUBLE or
270-
not sample_weight.flags.contiguous):
271-
sample_weight = np.ascontiguousarray(
272-
sample_weight, dtype=DOUBLE)
273-
if len(sample_weight.shape) > 1:
274-
raise ValueError("Sample weights array has more "
275-
"than one dimension: %d" %
276-
len(sample_weight.shape))
277-
if len(sample_weight) != n_samples:
278-
raise ValueError("Number of weights=%d does not match "
279-
"number of samples=%d" %
280-
(len(sample_weight), n_samples))
270+
sample_weight = _check_sample_weight(sample_weight, X, DOUBLE)
281271

282272
if expanded_class_weight is not None:
283273
if sample_weight is not None:

sklearn/tree/tests/test_tree.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
"""
44
import copy
55
import pickle
6-
from functools import partial
76
from itertools import product
87
import struct
98

@@ -1121,7 +1120,8 @@ def test_sample_weight_invalid():
11211120
clf.fit(X, y, sample_weight=sample_weight)
11221121

11231122
sample_weight = np.array(0)
1124-
with pytest.raises(ValueError):
1123+
expected_err = r"Singleton.* cannot be considered a valid collection"
1124+
with pytest.raises(TypeError, match=expected_err):
11251125
clf.fit(X, y, sample_weight=sample_weight)
11261126

11271127
sample_weight = np.ones(101)

0 commit comments

Comments
 (0)
0