diff --git a/sklearn/utils/class_weight.py b/sklearn/utils/class_weight.py
index 9644fb6102bc5..b2ba15807c2fc 100644
--- a/sklearn/utils/class_weight.py
+++ b/sklearn/utils/class_weight.py
@@ -42,6 +42,9 @@ def compute_class_weight(class_weight, classes, y):
     # Import error caused by circular imports.
     from ..preprocessing import LabelEncoder
 
+    if set(y) - set(classes):
+        raise ValueError("classes should include all valid labels that can "
+                         "be in y")
     if class_weight is None or len(class_weight) == 0:
         # uniform class weights
         weight = np.ones(classes.shape[0], dtype=np.float64, order='C')
diff --git a/sklearn/utils/tests/test_class_weight.py b/sklearn/utils/tests/test_class_weight.py
index 66abac2c5b39f..7b6eb82b51fde 100644
--- a/sklearn/utils/tests/test_class_weight.py
+++ b/sklearn/utils/tests/test_class_weight.py
@@ -37,6 +37,11 @@ def test_compute_class_weight_not_present():
     y = np.asarray([0, 0, 0, 1, 1, 2])
     assert_raises(ValueError, compute_class_weight, "auto", classes, y)
     assert_raises(ValueError, compute_class_weight, "balanced", classes, y)
+    # Raise error when y has items not in classes
+    classes = np.arange(2)
+    assert_raises(ValueError, compute_class_weight, "auto", classes, y)
+    assert_raises(ValueError, compute_class_weight, "balanced", classes, y)
+    assert_raises(ValueError, compute_class_weight, {0: 1., 1: 2.}, classes, y)
 
 
 def test_compute_class_weight_dict():