8000 fix pplot example for uneven grids · scikit-learn/scikit-learn@eb4bd09 · GitHub
[go: up one dir, main page]

Skip to content

Commit eb4bd09

Browse files
trevorstephensogrisel
authored andcommitted
fix pplot example for uneven grids
1 parent 317dea8 commit eb4bd09

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

examples/ensemble/plot_partial_dependence.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -95,10 +95,10 @@ def main():
9595
fig = plt.figure()
9696

9797
target_feature = (1, 5)
98-
pdp, (x_axis, y_axis) = partial_dependence(clf, target_feature,
99-
X=X_train, grid_resolution=50)
100-
XX, YY = np.meshgrid(x_axis, y_axis)
101-
Z = pdp.T.reshape(XX.shape).T
98+
pdp, axes = partial_dependence(clf, target_feature,
99+
X=X_train, grid_resolution=50)
100+
XX, YY = np.meshgrid(axes[0], axes[1])
101+
Z = pdp[0].reshape(list(map(np.size, axes))).T
102102
ax = Axes3D(fig)
103103
surf = ax.plot_surface(XX, YY, Z, rstride=1, cstride=1, cmap=plt.cm.BuPu)
104104
ax.set_xlabel(names[target_feature[0]])

0 commit comments

Comments
 (0)
0