8000 fix for missing classes found in y #4327 · mannby/scikit-learn@906073d · GitHub
[go: up one dir, main page]

Skip to content

Commit 906073d

Browse files
trevorstephensClaes-Fredrik Mannby
authored andcommitted
fix for missing classes found in y scikit-learn#4327
1 parent c4db372 commit 906073d

File tree

2 files changed

+8
-0
lines changed

2 files changed

+8
-0
lines changed

sklearn/utils/class_weight.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,9 @@ def compute_class_weight(class_weight, classes, y):
4242
# Import error caused by circular imports.
4343
from ..preprocessing import LabelEncoder
4444

45+
if set(y) - set(classes):
46+
raise ValueError("classes should include all valid labels that can "
47+
"be in y")
4548
if class_weight is None or len(class_weight) == 0:
4649
# uniform class weights
4750
weight = np.ones(classes.shape[0], dtype=np.float64, order='C')

sklearn/utils/tests/test_class_weight.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,11 @@ def test_compute_class_weight_not_present():
3737
y = np.asarray([0, 0, 0, 1, 1, 2])
3838
assert_raises(ValueError, compute_class_weight, "auto", classes, y)
3939
assert_raises(ValueError, compute_class_weight, "balanced", classes, y)
40+
# Raise error when y has items not in classes
41+
classes = np.arange(2)
42+
assert_raises(ValueError, compute_class_weight, "auto", classes, y)
43+
assert_raises(ValueError, compute_class_weight, "balanced", classes, y)
44+
assert_raises(ValueError, compute_class_weight, {0: 1., 1: 2.}, classes, y)
4045

4146

4247
def test_compute_class_weight_dict():

0 commit comments

Comments
 (0)
0