8000 ENH better error message in HGBRT with feature names (#25092) · scikit-learn/scikit-learn@c557080 · GitHub
[go: up one dir, main page]

Skip to content

Commit c557080

Browse files
ogriseljeremiedbb
andauthored
ENH better error message in HGBRT with feature names (#25092)
Co-authored-by: jeremie du boisberranger <jeremiedbb@yahoo.fr>
1 parent 5e25f8e commit c557080

File tree

2 files changed

+26
-11
lines changed

2 files changed

+26
-11
lines changed

sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -270,18 +270,21 @@ def _check_categories(self, X):
270270
if missing.any():
271271
categories = categories[~missing]
272272

273+
if hasattr(self, "feature_names_in_"):
274+
feature_name = f"'{self.feature_names_in_[f_idx]}'"
275+
else:
276+
feature_name = f"at index {f_idx}"
277+
273278
if categories.size > self.max_bins:
274279
raise ValueError(
275-
f"Categorical feature at index {f_idx} is "
276-
"expected to have a "
277-
f"cardinality <= {self.max_bins}"
280+
f"Categorical feature {feature_name} is expected to "
281+
f"have a cardinality <= {self.max_bins}"
278282
)
279283

280284
if (categories >= self.max_bins).any():
281285
raise ValueError(
282-
f"Categorical feature at index {f_idx} is "
283-
"expected to be encoded with "
284-
f"values < {self.max_bins}"
286+
f"Categorical feature {feature_name} is expected to "
287+
f"be encoded with values < {self.max_bins}"
285288
)
286289
else:
287290
categories = None

sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1141,20 +1141,32 @@ def test_categorical_spec_no_categories(Est, categorical_features, as_array):
11411141
@pytest.mark.parametrize(
11421142
"Est", (HistGradientBoostingClassifier, HistGradientBoostingRegressor)
11431143
)
1144-
def test_categorical_bad_encoding_errors(Est):
1144+
@pytest.mark.parametrize(
1145+
"use_pandas, feature_name", [(False, "at index 0"), (True, "'f0'")]
1146+
)
1147+
def test_categorical_bad_encoding_errors(Est, use_pandas, feature_name):
11451148
# Test errors when categories are encoded incorrectly
11461149

11471150
gb = Est(categorical_features=[True], max_bins=2)
11481151

1149-
X = np.array([[0, 1, 2]]).T
1152+
if use_pandas:
1153+
pd = pytest.importorskip("pandas")
1154+
X = pd.DataFrame({"f0": [0, 1, 2]})
1155+
else:
1156+
X = np.array([[0, 1, 2]]).T
11501157
y = np.arange(3)
1151-
msg = "Categorical feature at index 0 is expected to have a cardinality <= 2"
1158+
msg = f"Categorical feature {feature_name} is expected to have a cardinality <= 2"
11521159
with pytest.raises(ValueError, match=msg):
11531160
gb.fit(X, y)
11541161

1155-
X = np.array([[0, 2]]).T
1162+
if use_pandas:
1163+
X = pd.DataFrame({"f0": [0, 2]})
1164+
else:
1165+
X = np.array([[0, 2]]).T
11561166
y = np.arange(2)
1157-
msg = "Categorical feature at index 0 is expected to be encoded with values < 2"
1167+
msg = (
1168+
f"Categorical feature {feature_name} is expected to be encoded with values < 2"
1169+
)
11581170
with pytest.raises(ValueError, match=msg):
11591171
gb.fit(X, y)
11601172

0 commit comments

Comments
 (0)
0