8000 DOC Update warm start example in ensemble user guide (#28998) · scikit-learn/scikit-learn@c3d4c51 · GitHub
[go: up one dir, main page]

Skip to content

Commit c3d4c51

Browse files
authored
DOC Update warm start example in ensemble user guide (#28998)
1 parent 61281cf commit c3d4c51

File tree

1 file changed

+16
-1
lines changed

1 file changed

+16
-1
lines changed

doc/modules/ensemble.rst

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -603,7 +603,22 @@ fitted model.
603603

604604
::
605605

606-
>>> _ = est.set_params(n_estimators=200, warm_start=True) # set warm_start and new nr of trees
606+
>>> import numpy as np
607+
>>> from sklearn.metrics import mean_squared_error
608+
>>> from sklearn.datasets import make_friedman1
609+
>>> from sklearn.ensemble import GradientBoostingRegressor
610+
611+
>>> X, y = make_friedman1(n_samples=1200, random_state=0, noise=1.0)
612+
>>> X_train, X_test = X[:200], X[200:]
613+
>>> y_train, y_test = y[:200], y[200:]
614+
>>> est = GradientBoostingRegressor(
615+
... n_estimators=100, learning_rate=0.1, max_depth=1, random_state=0,
616+
... loss='squared_error'
617+
... )
618+
>>> est = est.fit(X_train, y_train) # fit with 100 trees
619+
>>> mean_squared_error(y_test, est.predict(X_test))
620+
5.00...
621+
>>> _ = est.set_params(n_estimators=200, warm_start=True) # set warm_start and increase num of trees
607622
>>> _ = est.fit(X_train, y_train) # fit additional 100 trees to est
608623
>>> mean_squared_error(y_test, est.predict(X_test))
609624
3.84...

0 commit comments

Comments
 (0)
0