8000 MAINT accelerate plot_partial_dependence.py (#21768) · scikit-learn/scikit-learn@845b1fa · GitHub
[go: up one dir, main page]

Skip to content

Commit 845b1fa

Browse files
authored
MAINT accelerate plot_partial_dependence.py (#21768)
1 parent e44d3cd commit 845b1fa

File tree

1 file changed

+9
-5
lines changed

1 file changed

+9
-5
lines changed

examples/inspection/plot_partial_dependence.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,10 @@
8080
est = make_pipeline(
8181
QuantileTransformer(),
8282
MLPRegressor(
83-
hidden_layer_sizes=(50, 50), learning_rate_init=0.01, early_stopping=True
83+
hidden_layer_sizes=(30, 15),
84+
learning_rate_init=0.01,
85+
early_stopping=True,
86+
random_state=0,
8487
),
8588
)
8689
est.fit(X_train, y_train)
@@ -145,7 +148,7 @@
145148

146149
print("Training HistGradientBoostingRegressor...")
147150
tic = time()
148-
est = HistGradientBoostingRegressor()
151+
est = HistGradientBoostingRegressor(random_state=0)
149152
est.fit(X_train, y_train)
150153
print(f"done in {time() - tic:.3f}s")
151154
print(f"Test R2 score: {est.score(X_test, y_test):.2f}")
@@ -233,8 +236,8 @@
233236
X_train,
234237
features,
235238
kind="average",
236-
n_jobs=3,
237-
grid_resolution=20,
239+
n_jobs=2,
240+
grid_resolution=10,
238241
ax=ax,
239242
)
240243
print(f"done in {time() - tic:.3f}s")
@@ -265,12 +268,13 @@
265268

266269
features = ("AveOccup", "HouseAge")
267270
pdp = partial_dependence(
268-
est, X_train, features=features, kind="average", grid_resolution=20
271+
est, X_train, features=features, kind="average", grid_resolution=10
269272
)
270273
XX, YY = np.meshgrid(pdp["values"][0], pdp["values"][1])
271274
Z = pdp.average[0].T
272275
ax = Axes3D(fig)
273276
fig.add_axes(ax)
277+
274278
surf = ax.plot_surface(XX, YY, Z, rstride=1, cstride=1, cmap=plt.cm.BuPu, edgecolor="k")
275279
ax.set_xlabel(features[0])
276280
ax.set_ylabel(features[1])

0 commit comments

Comments
 (0)
0