[DOC] Speed up `plot_gradient_boosting_quantile.py` example (#21666) · scikit-learn/scikit-learn@b7cf0de · GitHub
Commit b7cf0de
Maren Westermann (marenwestermann) authored and TomDLT committed
[DOC] Speed up plot_gradient_boosting_quantile.py example (#21666)
Co-authored-by: Maren Westermann <maren.westermann@free-now.com>
Co-authored-by: Tom Dupré la Tour <tom.dupre-la-tour@m4x.org>
1 parent 96be105 commit b7cf0de

File tree

1 file changed: +24 −24 lines


examples/ensemble/plot_gradient_boosting_quantile.py

Lines changed: 24 additions & 24 deletions
@@ -62,7 +62,7 @@ def f(x):
 all_models = {}
 common_params = dict(
     learning_rate=0.05,
-    n_estimators=250,
+    n_estimators=200,
     max_depth=2,
     min_samples_leaf=9,
     min_samples_split=9,
@@ -97,7 +97,7 @@ def f(x):
 fig = plt.figure(figsize=(10, 10))
 plt.plot(xx, f(xx), "g:", linewidth=3, label=r"$f(x) = x\,\sin(x)$")
 plt.plot(X_test, y_test, "b.", markersize=10, label="Test observations")
-plt.plot(xx, y_med, "r-", label="Predicted median", color="orange")
+plt.plot(xx, y_med, "r-", label="Predicted median")
 plt.plot(xx, y_pred, "r-", label="Predicted mean")
 plt.plot(xx, y_upper, "k-")
 plt.plot(xx, y_lower, "k-")
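
Context for the hunk above: in Matplotlib, an explicit `color` keyword argument takes precedence over the color encoded in a format string such as "r-", so the old call drew the median curve in orange even though the shorthand says red; removing the redundant keyword presumably resolves that conflict. A minimal illustration, not part of the example itself:

import matplotlib.pyplot as plt
import numpy as np

xx = np.linspace(0, 10, 100)
# The explicit color kwarg overrides the "r" (red) in the format string,
# so this curve renders orange despite the red shorthand:
plt.plot(xx, np.sin(xx), "r-", color="orange")
plt.plot(xx, np.cos(xx), "r-")  # rendered red, as the shorthand says
plt.show()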
@@ -224,25 +224,24 @@ def coverage_fraction(y, y_low, y_high):
 # underfit and could not adapt to sinusoidal shape of the signal.
 #
 # The hyper-parameters of the model were approximately hand-tuned for the
-# median regressor and there is no reason than the same hyper-parameters are
+# median regressor and there is no reason that the same hyper-parameters are
 # suitable for the 5th percentile regressor.
 #
 # To confirm this hypothesis, we tune the hyper-parameters of a new regressor
 # of the 5th percentile by selecting the best model parameters by
 # cross-validation on the pinball loss with alpha=0.05:

 # %%
-from sklearn.model_selection import RandomizedSearchCV
+from sklearn.experimental import enable_halving_search_cv  # noqa
+from sklearn.model_selection import HalvingRandomSearchCV
 from sklearn.metrics import make_scorer
 from pprint import pprint

-
 param_grid = dict(
-    learning_rate=[0.01, 0.05, 0.1],
-    n_estimators=[100, 150, 200, 250, 300],
-    max_depth=[2, 5, 10, 15, 20],
-    min_samples_leaf=[1, 5, 10, 20, 30, 50],
-    min_samples_split=[2, 5, 10, 20, 30, 50],
+    learning_rate=[0.05, 0.1, 0.2],
+    max_depth=[2, 5, 10],
+    min_samples_leaf=[1, 5, 10, 20],
+    min_samples_split=[5, 10, 20, 30, 50],
 )
 alpha = 0.05
 neg_mean_pinball_loss_05p_scorer = make_scorer(
@@ -251,20 +250,22 @@ def coverage_fraction(y, y_low, y_high):
     greater_is_better=False,  # maximize the negative loss
 )
 gbr = GradientBoostingRegressor(loss="quantile", alpha=alpha, random_state=0)
-search_05p = RandomizedSearchCV(
+search_05p = HalvingRandomSearchCV(
     gbr,
     param_grid,
-    n_iter=10,  # increase this if computational budget allows
+    resource="n_estimators",
+    max_resources=250,
+    min_resources=50,
     scoring=neg_mean_pinball_loss_05p_scorer,
     n_jobs=2,
     random_state=0,
 ).fit(X_train, y_train)
 pprint(search_05p.best_params_)

 # %%
-# We observe that the search procedure identifies that deeper trees are needed
-# to get a good fit for the 5th percentile regressor. Deeper trees are more
-# expressive and less likely to underfit.
+# We observe that the hyper-parameters that were hand-tuned for the median
+# regressor are in the same range as the hyper-parameters suitable for the 5th
+# percentile regressor.
 #
 # Let's now tune the hyper-parameters for the 95th percentile regressor. We
 # need to redefine the `scoring` metric used to select the best model, along
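
To see the new tuning step end to end, here is a minimal, self-contained sketch of what the code amounts to after this commit. The synthetic training data is a stand-in for the example's `X_train`/`y_train`, and the `mean_pinball_loss` import is an assumption: the scorer needs it, but the import does not appear in this diff.

import numpy as np
from pprint import pprint
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.experimental import enable_halving_search_cv  # noqa
from sklearn.metrics import make_scorer, mean_pinball_loss
from sklearn.model_selection import HalvingRandomSearchCV

# Synthetic stand-in for the example's f(x) = x * sin(x) training data.
rng = np.random.RandomState(0)
X_train = rng.uniform(0, 10, size=(200, 1))
y_train = X_train.ravel() * np.sin(X_train.ravel()) + rng.normal(scale=0.5, size=200)

alpha = 0.05
# Pinball loss at alpha=0.05, negated so that greater is better.
neg_mean_pinball_loss_05p_scorer = make_scorer(
    mean_pinball_loss,
    alpha=alpha,
    greater_is_better=False,
)

param_grid = dict(
    learning_rate=[0.05, 0.1, 0.2],
    max_depth=[2, 5, 10],
    min_samples_leaf=[1, 5, 10, 20],
    min_samples_split=[5, 10, 20, 30, 50],
)
gbr = GradientBoostingRegressor(loss="quantile", alpha=alpha, random_state=0)
search_05p = HalvingRandomSearchCV(
    gbr,
    param_grid,
    # Successive halving treats the number of boosting stages as the budgeted
    # resource: every candidate is first fit with 50 trees, and only the
    # surviving configurations are refit with up to 250.
    resource="n_estimators",
    max_resources=250,
    min_resources=50,
    scoring=neg_mean_pinball_loss_05p_scorer,
    random_state=0,
).fit(X_train, y_train)
pprint(search_05p.best_params_)

Spending the tree budget adaptively in this way, rather than giving every sampled candidate a full fit as `RandomizedSearchCV` with `n_iter=10` did, is presumably what makes the example faster, in line with the commit title.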
@@ -286,15 +287,14 @@ def coverage_fraction(y, y_low, y_high):
 pprint(search_95p.best_params_)

 # %%
-# This time, shallower trees are selected and lead to a more constant piecewise
-# and therefore more robust estimation of the 95th percentile. This is
-# beneficial as it avoids overfitting the large outliers of the log-normal
-# additive noise.
-#
-# We can confirm this intuition by displaying the predicted 90% confidence
-# interval comprised by the predictions of those two tuned quantile regressors:
-# the prediction of the upper 95th percentile has a much coarser shape than the
-# prediction of the lower 5th percentile:
+# The result shows that the hyper-parameters for the 95th percentile regressor
+# identified by the search procedure are roughly in the same range as the hand-
+# tuned hyper-parameters for the median regressor and the hyper-parameters
+# identified by the search procedure for the 5th percentile regressor. However,
+# the hyper-parameter searches did lead to an improved 90% confidence interval
+# that is comprised by the predictions of those two tuned quantile regressors.
+# Note that the prediction of the upper 95th percentile has a much coarser shape
+# than the prediction of the lower 5th percentile because of the outliers:
 y_lower = search_05p.predict(xx)
 y_upper = search_95p.predict(xx)
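
The hunk headers above reference the example's `coverage_fraction` helper, whose body is not part of this diff. A plausible minimal implementation, consistent with how the 90% confidence interval is checked in the surrounding prose (the actual helper in the example file may differ in detail):

import numpy as np

def coverage_fraction(y, y_low, y_high):
    # Fraction of true targets that fall inside the predicted interval.
    return np.mean(np.logical_and(y >= y_low, y <= y_high))

# Usage on held-out data; for well-calibrated 5th/95th percentile regressors
# the result should be close to the nominal 0.90:
# coverage_fraction(y_test, search_05p.predict(X_test), search_95p.predict(X_test))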

0 commit comments