8000 Significant speedups to plot_surface function in mplot3d. Thanks to … · ieebon/matplotlib@8ff9020 · GitHub
[go: up one dir, main page]

Skip to content

Commit 8ff9020

Browse files
committed
Significant speedups to plot_surface function in mplot3d. Thanks to Justin Peel!
svn path=/trunk/matplotlib/; revision=8878
1 parent ac35c68 commit 8ff9020

File tree

1 file changed

+32
-33
lines changed

1 file changed

+32
-33
lines changed

lib/mpl_toolkits/mplot3d/axes3d.py

Lines changed: 32 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -696,7 +696,6 @@ def plot_surface(self, X, Y, Z, *args, **kwargs):
696696
had_data = self.has_data()
697697

698698
rows, cols = Z.shape
699-
tX, tY, tZ = np.transpose(X), np.transpose(Y), np.transpose(Z)
700699
rstride = kwargs.pop('rstride', 10)
701700
cstride = kwargs.pop('cstride', 10)
702701

@@ -719,35 +718,36 @@ def plot_surface(self, X, Y, Z, *args, **kwargs):
719718
fcolors = self._shade_colors_lightsource(Z, cmap, lightsource)
720719

721720
polys = []
722-
normals = []
721+
# Only need these vectors to shade if there is no cmap
722+
if cmap is None and shade :
723+
totpts = int(np.ceil(float(rows - 1) / rstride) *
724+
np.ceil(float(cols - 1) / cstride))
725+
v1 = np.empty((totpts, 3))
726+
v2 = np.empty((totpts, 3))
727+
# This indexes the vertex points
728+
which_pt = 0
729+
730+
723731
#colset contains the data for coloring: either average z or the facecolor
724732
colset = []
725-
for rs in np.arange(0, rows-1, rstride):
726-
for cs in np.arange(0, cols-1, cstride):
733+
for rs in xrange(0, rows-1, rstride):
734+
for cs in xrange(0, cols-1, cstride):
727735
ps = []
728-
corners = []
729-
for a, ta in [(X, tX), (Y, tY), (Z, tZ)]:
730-
ztop = a[rs][cs:min(cols, cs+cstride+1)]
731-
zleft = ta[min(cols-1, cs+cstride)][rs:min(rows, rs+rstride+1)]
732-
zbase = a[min(rows-1, rs+rstride)][cs:min(cols, cs+cstride+1):]
733-
zbase = zbase[::-1]
734-
zright = ta[cs][rs:min(rows, rs+rstride+1):]
735-
zright = zright[::-1]
736-
corners.append([ztop[0], ztop[-1], zbase[0], zbase[-1]])
736+
for a in (X, Y, Z) :
737+
ztop = a[rs,cs:min(cols, cs+cstride+1)]
738+
zleft = a[rs+1:min(rows, rs+rstride+1),
739+
min(cols-1, cs+cstride)]
740+
zbase = a[min(rows-1, rs+rstride), cs:min(cols, cs+cstride+1):][::-1]
741+
zright = a[rs:min(rows-1, rs+rstride):, cs][::-1]
737742
z = np.concatenate((ztop, zleft, zbase, zright))
738743
ps.append(z)
739744

740745
# The construction leaves the array with duplicate points, which
741746
# are removed here.
742747
ps = zip(*ps)
743748
lastp = np.array([])
744-
ps2 = []
745-
avgzsum = 0.0
746-
for p in ps:
747-
if p != lastp:
748-
ps2.append(p)
749-
lastp = p
750-
avgzsum += p[2]
749+
ps2 = [ps[0]] + [ps[i] for i in xrange(1, len(ps)) if ps[i] != ps[i-1]]
750+
avgzsum = sum(p[2] for p in ps2)
751751
polys.append(ps2)
752752

753753
if fcolors is not None:
@@ -758,9 +758,13 @@ def plot_surface(self, X, Y, Z, *args, **kwargs):
758758
# Only need vectors to shade if no cmap
759759
if cmap is None and shade:
760760
i1, i2, i3 = 0, int(len(ps2)/3), int(2*len(ps2)/3)
761-
v1 = np.array(ps2[i1]) - np.array(ps2[i2])
762-
v2 = np.array(ps2[i2]) - np.array(ps2[i3])
763-
normals.append(np.cross(v1, v2))
761+
v1[which_pt] = np.array(ps2[i1]) - np.array(ps2[i2])
762+
v2[which_pt] = np.array(ps2[i2]) - np.array(ps2[i3])
763+
which_pt += 1
764+
if cmap is None and shade:
765+
normals = np.cross(v1, v2)
766+
else :
767+
normals = []
764768

765769
polyc = art3d.Poly3DCollection(polys, *args, **kwargs)
766770

@@ -808,24 +812,19 @@ def _shade_colors(self, color, normals):
808812
*color* can also be an array of the same length as *normals*.
809813
'''
810814

811-
shade = []
812-
for n in normals:
813-
n = n / proj3d.mod(n)
814-
shade.append(np.dot(n, [-1, -1, 0.5]))
815-
816-
shade = np.array(shade)
815+
shade = np.array([np.dot(n / proj3d.mod(n), [-1, -1, 0.5])
816+
for n in normals])
817817
mask = ~np.isnan(shade)
818818

819819
if len(shade[mask]) > 0:
820820
norm = Normalize(min(shade[mask]), max(shade[mask]))
821821
if art3d.iscolor(color):
822822
color = color.copy()
823823
color[3] = 1
824-
colors = [color * (0.5 + norm(v) * 0.5) for v in shade]
824+
colors = np.outer(0.5 + norm(shade) * 0.5, color)
825825
else:
826-
colors = [np.array(colorConverter.to_rgba(c)) * \
827-
(0.5 + norm(v) * 0.5) \
828-
for c, v in zip(color, shade)]
826+
colors = colorConverter.to_rgba_array(color) * \
827+
(0.5 + 0.5 * norm(shade)[:, np.newaxis])
829828
else:
830829
colors = color.copy()
831830

0 commit comments

Comments
 (0)
0