TST add test for quantile HGBT · scikit-learn/scikit-learn@57505de · GitHub

Commit 57505de

TST add test for quantile HGBT

1 parent d61d2f4 commit 57505de
File tree

1 file changed: +36 −0 lines changed

sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py
Lines changed: 36 additions & 0 deletions
@@ -7,6 +7,7 @@
     HalfMultinomialLoss,
     HalfPoissonLoss,
     HalfSquaredError,
+    PinballLoss,
 )
 from sklearn.datasets import make_classification, make_regression
 from sklearn.datasets import make_low_rank_matrix
@@ -35,6 +36,7 @@
     "squared_error": HalfSquaredError,
     "absolute_error": AbsoluteError,
     "poisson": HalfPoissonLoss,
+    "quantile": PinballLoss,
     "binary_crossentropy": HalfBinomialLoss,
     "categorical_crossentropy": HalfMultinomialLoss,
 }
@@ -249,6 +251,40 @@ def test_absolute_error_sample_weight():
     gbdt.fit(X, y, sample_weight=sample_weight)
 
 
+@pytest.mark.parametrize("quantile", [0.2, 0.5, 0.8])
+def test_asymmetric_error(quantile):
+    """Test quantile regression for asymmetric distributed targets."""
+    n_samples = 10_000
+    rng = np.random.RandomState(42)
+    # take care that X @ coef + intercept > 0
+    X = np.concatenate(
+        (
+            np.abs(rng.randn(n_samples)[:, None]),
+            -rng.randint(2, size=(n_samples, 1)),
+        ),
+        axis=1,
+    )
+    intercept = 1.23
+    coef = np.array([0.5, -2])
+    # For an exponential distribution with rate lambda, i.e. density
+    # lambda * exp(-lambda * x), the quantile at level q is:
+    #   quantile(q) = -log(1 - q) / lambda
+    #   scale = 1 / lambda = -quantile(q) / log(1 - q)
+    y = rng.exponential(
+        scale=-(X @ coef + intercept) / np.log(1 - quantile), size=n_samples
+    )
+    model = HistGradientBoostingRegressor(
+        loss="quantile",
+        loss_param=quantile,
+    ).fit(X, y)
+    assert_allclose(np.mean(model.predict(X) > y), quantile, rtol=1e-2)
+
+    loss_true_quantile = model._loss(y, X @ coef + intercept)
+    loss_pred_quantile = model._loss(y, model.predict(X))
+    # we are overfitting
+    assert loss_pred_quantile <= loss_true_quantile
+
+
 @pytest.mark.parametrize("y", [([1.0, -2.0, 0.0]), ([0.0, 0.0, 0.0])])
 def test_poisson_y_positive(y):
     # Test that ValueError is raised if either one y_i < 0 or sum(y_i) <= 0.
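
For reference, the target construction in the test relies on the closed form of the exponential quantile: with CDF F(x) = 1 - exp(-lambda * x), solving F(x_q) = q gives x_q = -log(1 - q) / lambda, so choosing scale = 1 / lambda = -(X @ coef + intercept) / log(1 - quantile) makes the true conditional q-quantile exactly X @ coef + intercept. The snippet below is a standalone sketch (not part of the commit; all names are illustrative) that checks this identity numerically.

import numpy as np

q = 0.8
lam = 2.0  # rate parameter; scale = 1 / lam
rng = np.random.RandomState(0)
samples = rng.exponential(scale=1.0 / lam, size=1_000_000)

# Empirical q-quantile vs. the closed form -log(1 - q) / lam
empirical = np.quantile(samples, q)
analytic = -np.log(1.0 - q) / lam
print(empirical, analytic)  # both close to 0.8047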
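
The main assertion, np.mean(model.predict(X) > y) close to quantile, works because the pinball loss at level q, equal to q * (y - pred) when y >= pred and (1 - q) * (pred - y) otherwise, is minimized by the q-quantile of y. A minimal sketch of that property, again standalone and not part of the commit:

import numpy as np

def pinball(y, pred, q):
    # Pinball (quantile) loss at level q, averaged over samples.
    diff = y - pred
    return np.mean(np.where(diff >= 0, q * diff, (q - 1) * diff))

rng = np.random.RandomState(0)
y = rng.exponential(scale=1.0, size=10_000)
q = 0.8

# Scan constant predictions: the minimizer sits at the empirical q-quantile.
grid = np.linspace(0.0, 5.0, 2001)
best = grid[np.argmin([pinball(y, c, q) for c in grid])]
print(best, np.quantile(y, q))  # both close to -log(1 - 0.8), about 1.609

Note that this development-stage commit passes the level via loss_param=quantile; released versions of scikit-learn expose it as the quantile parameter of HistGradientBoostingRegressor, so the parameter name shown in the diff should not be read as the final API.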
