[DOC] Speed up `plot_gradient_boosting_quantile.py` example (#21666) · scikit-learn/scikit-learn@cc534f8 · GitHub
Commit cc534f8

[DOC] Speed up plot_gradient_boosting_quantile.py example (#21666)

Authored by Maren Westermann and Tom Dupré la Tour

Co-authored-by: Maren Westermann <maren.westermann@free-now.com>
Co-authored-by: Tom Dupré la Tour <tom.dupre-la-tour@m4x.org>

1 parent 3f88680 commit cc534f8

File tree

1 file changed: +24 −24 lines changed

examples/ensemble/plot_gradient_boosting_quantile.py

Lines changed: 24 additions & 24 deletions
@@ -62,7 +62,7 @@ def f(x):
 all_models = {}
 common_params = dict(
     learning_rate=0.05,
-    n_estimators=250,
+    n_estimators=200,
     max_depth=2,
     min_samples_leaf=9,
     min_samples_split=9,
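For context on this first hunk: common_params is shared by the quantile models trained earlier in the example, so lowering n_estimators from 250 to 200 shaves time off every fit. A minimal runnable sketch of that usage, with toy training data assumed (not part of this diff):

import numpy as np
from sklearn.ensemble import GradientBoostingRegressor

# Toy stand-in for the example's training data (an assumption for this sketch).
rng = np.random.RandomState(0)
X_train = rng.uniform(0, 10, size=(200, 1))
y_train = X_train.ravel() * np.sin(X_train.ravel()) + rng.lognormal(sigma=0.5, size=200)

common_params = dict(
    learning_rate=0.05,
    n_estimators=200,  # reduced from 250 by this commit
    max_depth=2,
    min_samples_leaf=9,
    min_samples_split=9,
)
all_models = {}
for alpha in [0.05, 0.5, 0.95]:
    # loss="quantile" makes each model minimize the pinball loss at level alpha.
    gbr = GradientBoostingRegressor(loss="quantile", alpha=alpha, **common_params)
    all_models["q %1.2f" % alpha] = gbr.fit(X_train, y_train)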
@@ -97,7 +97,7 @@ def f(x):
 fig = plt.figure(figsize=(10, 10))
 plt.plot(xx, f(xx), "g:", linewidth=3, label=r"$f(x) = x\,\sin(x)$")
 plt.plot(X_test, y_test, "b.", markersize=10, label="Test observations")
-plt.plot(xx, y_med, "r-", label="Predicted median", color="orange")
+plt.plot(xx, y_med, "r-", label="Predicted median")
 plt.plot(xx, y_pred, "r-", label="Predicted mean")
 plt.plot(xx, y_upper, "k-")
 plt.plot(xx, y_lower, "k-")
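A side note on this hunk: in Matplotlib an explicit color keyword silently overrides the color encoded in the "r-" format string, so the removed color="orange" was what actually set the line's color. A quick check of that precedence:

import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 50)
# The color kwarg takes precedence over the red requested by "r-".
line, = plt.plot(x, np.sin(x), "r-", color="orange")
print(line.get_color())  # prints "orange"

Dropping the kwarg removes the contradiction and leaves the format string in charge.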
@@ -224,25 +224,24 @@ def coverage_fraction(y, y_low, y_high):
 # underfit and could not adapt to sinusoidal shape of the signal.
 #
 # The hyper-parameters of the model were approximately hand-tuned for the
-# median regressor and there is no reason than the same hyper-parameters are
+# median regressor and there is no reason that the same hyper-parameters are
 # suitable for the 5th percentile regressor.
 #
 # To confirm this hypothesis, we tune the hyper-parameters of a new regressor
 # of the 5th percentile by selecting the best model parameters by
 # cross-validation on the pinball loss with alpha=0.05:

 # %%
-from sklearn.model_selection import RandomizedSearchCV
+from sklearn.experimental import enable_halving_search_cv  # noqa
+from sklearn.model_selection import HalvingRandomSearchCV
 from sklearn.metrics import make_scorer
 from pprint import pprint

-
 param_grid = dict(
-    learning_rate=[0.01, 0.05, 0.1],
-    n_estimators=[100, 150, 200, 250, 300],
-    max_depth=[2, 5, 10, 15, 20],
-    min_samples_leaf=[1, 5, 10, 20, 30, 50],
-    min_samples_split=[2, 5, 10, 20, 30, 50],
+    learning_rate=[0.05, 0.1, 0.2],
+    max_depth=[2, 5, 10],
+    min_samples_leaf=[1, 5, 10, 20],
+    min_samples_split=[5, 10, 20, 30, 50],
 )
 alpha = 0.05
 neg_mean_pinball_loss_05p_scorer = make_scorer(
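The scorer construction at the end of this hunk is cut off by the hunk boundary; in the example it wraps sklearn.metrics.mean_pinball_loss, presumably along these lines:

from sklearn.metrics import make_scorer, mean_pinball_loss

alpha = 0.05
neg_mean_pinball_loss_05p_scorer = make_scorer(
    mean_pinball_loss,
    alpha=alpha,  # extra kwargs are forwarded to the metric
    greater_is_better=False,  # the search maximizes the negated loss
)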
@@ -251,20 +250,22 @@ def coverage_fraction(y, y_low, y_high):
     greater_is_better=False,  # maximize the negative loss
 )
 gbr = GradientBoostingRegressor(loss="quantile", alpha=alpha, random_state=0)
-search_05p = RandomizedSearchCV(
+search_05p = HalvingRandomSearchCV(
     gbr,
     param_grid,
-    n_iter=10,  # increase this if computational budget allows
+    resource="n_estimators",
+    max_resources=250,
+    min_resources=50,
     scoring=neg_mean_pinball_loss_05p_scorer,
     n_jobs=2,
     random_state=0,
 ).fit(X_train, y_train)
 pprint(search_05p.best_params_)

 # %%
-# We observe that the search procedure identifies that deeper trees are needed
-# to get a good fit for the 5th percentile regressor. Deeper trees are more
-# expressive and less likely to underfit.
+# We observe that the hyper-parameters that were hand-tuned for the median
+# regressor are in the same range as the hyper-parameters suitable for the 5th
+# percentile regressor.
 #
 # Let's now tune the hyper-parameters for the 95th percentile regressor. We
 # need to redefine the `scoring` metric used to select the best model, along
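The switch from RandomizedSearchCV to HalvingRandomSearchCV is what buys the speed-up: successive halving starts every candidate with few boosting iterations and only grows n_estimators for the candidates that survive each rung, instead of fitting every candidate at full size. A self-contained sketch of the pattern, with toy data and a reduced grid assumed:

import numpy as np
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.experimental import enable_halving_search_cv  # noqa
from sklearn.metrics import make_scorer, mean_pinball_loss
from sklearn.model_selection import HalvingRandomSearchCV

rng = np.random.RandomState(0)
X = rng.uniform(0, 10, size=(200, 1))
y = X.ravel() * np.sin(X.ravel()) + rng.lognormal(sigma=0.5, size=200)

search = HalvingRandomSearchCV(
    GradientBoostingRegressor(loss="quantile", alpha=0.05, random_state=0),
    param_distributions=dict(max_depth=[2, 5, 10], min_samples_leaf=[1, 5, 10, 20]),
    resource="n_estimators",  # the resource grown between successive rungs
    max_resources=250,        # surviving candidates end with up to 250 trees
    min_resources=50,         # every candidate starts with only 50 trees
    scoring=make_scorer(mean_pinball_loss, alpha=0.05, greater_is_better=False),
    random_state=0,
).fit(X, y)
print(search.best_params_)

Note that the resource ("n_estimators") must not also appear in param_distributions, which is why the commit drops it from the grid in the previous hunk.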
@@ -286,15 +287,14 @@ def coverage_fraction(y, y_low, y_high):
 pprint(search_95p.best_params_)

 # %%
-# This time, shallower trees are selected and lead to a more constant piecewise
-# and therefore more robust estimation of the 95th percentile. This is
-# beneficial as it avoids overfitting the large outliers of the log-normal
-# additive noise.
-#
-# We can confirm this intuition by displaying the predicted 90% confidence
-# interval comprised by the predictions of those two tuned quantile regressors:
-# the prediction of the upper 95th percentile has a much coarser shape than the
-# prediction of the lower 5th percentile:
+# The result shows that the hyper-parameters for the 95th percentile regressor
+# identified by the search procedure are roughly in the same range as the hand-
+# tuned hyper-parameters for the median regressor and the hyper-parameters
+# identified by the search procedure for the 5th percentile regressor. However,
+# the hyper-parameter searches did lead to an improved 90% confidence interval
+# that is comprised by the predictions of those two tuned quantile regressors.
+# Note that the prediction of the upper 95th percentile has a much coarser shape
+# than the prediction of the lower 5th percentile because of the outliers:
 y_lower = search_05p.predict(xx)
 y_upper = search_95p.predict(xx)
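Finally, the coverage_fraction helper named in the hunk headers is defined earlier in the example, outside this diff; it presumably measures how many observations fall inside the predicted interval, along these lines:

import numpy as np

def coverage_fraction(y, y_low, y_high):
    # Fraction of observations inside the [y_low, y_high] band.
    return np.mean(np.logical_and(y >= y_low, y <= y_high))

# For a well-calibrated 90% interval built from the tuned 5th and 95th
# percentile regressors, this should land near 0.9 on held-out data.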

0 commit comments
