8000 Added rcount/ccount to plot_surface(), providing an alternative to rs… · matplotlib/matplotlib@58d3de2 · GitHub
[go: up one dir, main page]

Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 58d3de2

Browse files
committed
Added rcount/ccount to plot_surface(), providing an alternative to rstride/cstride
1 parent 984e9b0 commit 58d3de2

File tree

7 files changed

+117
-17
lines changed

7 files changed

+117
-17
lines changed

doc/users/whats_new/plot_surface.rst

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
`rcount` and `ccount` for `plot_surface()`
2+
------------------------------------------
3+
4+
As of v2.0, mplot3d's :func:`~mpl_toolkits.mplot3d.axes3d.plot_surface` now
5+
accepts `rcount` and `ccount` arguments for controlling the sampling of the
6+
input data for plotting. These arguments specify the maximum number of
7+
evenly spaced samples to take from the input data. These arguments are
8+
also the new default sampling method for the function, and is
9+
considered a style change.
10+
11+
The old `rstride` and `cstride` arguments, which specified the size of the
12+
evenly spaced samples, become the default when 'classic' mode is invoked,
13+
and are still available for use. There are no plans for deprecating these
14+
arguments.
15+

examples/mplot3d/surface3d_demo.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
Z = np.sin(R)
2929

3030
# Plot the surface.
31-
surf = ax.plot_surface(X, Y, Z, rstride=1, cstride=1, cmap=cm.coolwarm,
31+
surf = ax.plot_surface(X, Y, Z, cmap=cm.coolwarm,
3232
linewidth=0, antialiased=False)
3333

3434
# Customize the z axis.

examples/mplot3d/surface3d_demo2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,6 @@
2222
z = 10 * np.outer(np.ones(np.size(u)), np.cos(v))
2323

2424
# Plot the surface
25-
ax.plot_surface(x, y, z, rstride=4, cstride=4, color='b')
25+
ax.plot_surface(x, y, z, color='b')
2626

2727
plt.show()

examples/mplot3d/surface3d_demo3.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,7 @@
3434
colors[x, y] = colortuple[(x + y) % len(colortuple)]
3535

3636
# Plot the surface with face colors taken from the array we made.
37-
surf = ax.plot_surface(X, Y, Z, rstride=1, cstride=1, facecolors=colors,
38-
linewidth=0)
37+
surf = ax.plot_surface(X, Y, Z, facecolors=colors, linewidth=0)
3938

4039
# Customize the z axis.
4140
ax.set_zlim(-1, 1)

examples/mplot3d/surface3d_radial_demo.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
X, Y = R*np.cos(P), R*np.sin(P)
2929

3030
# Plot the surface.
31-
ax.plot_surface(X, Y, Z, rstride=1, cstride=1, cmap=plt.cm.YlGnBu_r)
31+
ax.plot_surface(X, Y, Z, cmap=plt.cm.YlGnBu_r)
3232

3333
# Tweak the limits and add latex math labels.
3434
ax.set_zlim(0, 1)

lib/mpl_toolkits/mplot3d/axes3d.py

Lines changed: 81 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1553,15 +1553,28 @@ def plot_surface(self, X, Y, Z, *args, **kwargs):
15531553
15541554
The `rstride` and `cstride` kwargs set the stride used to
15551555
sample the input data to generate the graph. If 1k by 1k
1556-
arrays are passed in the default values for the strides will
1557-
result in a 100x100 grid being plotted.
1556+
arrays are passed in, the default values for the strides will
1557+
result in a 100x100 grid being plotted. Defaults to 10.
1558+
Raises a ValueError if both stride and count kwargs
1559+
(see next section) are provided.
1560+
1561+
The `rcount` and `ccount` kwargs supersedes `rstride` and
1562+
`cstride` for default sampling method for surface plotting.
1563+
These arguments will determine at most how many evenly spaced
1564+
samples will be taken from the input data to generate the graph.
1565+
This is the default sampling method unless using the 'classic'
1566+
style. Will raise ValueError if both stride and count are
1567+
specified.
1568+
Added in v2.0.0.
15581569
15591570
============= ================================================
15601571
Argument Description
15611572
============= ================================================
15621573
*X*, *Y*, *Z* Data values as 2D arrays
1563-
*rstride* Array row stride (step size), defaults to 10
1564-
*cstride* Array column stride (step size), defaults to 10
1574+
*rstride* Array row stride (step size)
1575+
*cstride* Array column stride (step size)
1576+
*rcount* Use at most this many rows, defaults to 50
1577+
*ccount* Use at most this many columns, defaults to 50
15651578
*color* Color of the surface patches
15661579
*cmap* A colormap for the surface patches.
15671580
*facecolors* Face colors for the individual patches
@@ -1582,8 +1595,30 @@ def plot_surface(self, X, Y, Z, *args, **kwargs):
15821595
X, Y, Z = np.broadcast_arrays(X, Y, Z)
15831596
rows, cols = Z.shape
15841597

1598+
has_stride = 'rstride' in kwargs or 'cstride' in kwargs
1599+
has_count = 'rcount' in kwargs or 'ccount' in kwargs
1600+
1601+
if has_stride and has_count:
1602+
raise ValueError("Cannot specify both stride and count arguments")
1603+
15851604
rstride = kwargs.pop('rstride', 10)
15861605
cstride = kwargs.pop('cstride', 10)
1606+
rcount = kwargs.pop('rcount', 50)
1607+
ccount = kwargs.pop('ccount', 50)
1608+
1609+
if rcParams['_internal.classic_mode']:
1610+
# Strides have priority over counts in classic mode.
1611+
# So, only compute strides from counts
1612+
# if counts were explicitly given
1613+
if has_count:
1614+
rstride = int(np.ceil(rows / rcount))
1615+
cstride = int(np.ceil(cols / ccount))
1616+
else:
1617+
# If the strides are provided then it has priority.
1618+
# Otherwise, compute the strides from the counts.
1619+
if not has_stride:
1620+
rstride = int(np.ceil(rows / rcount))
1621+
cstride = int(np.ceil(cols / ccount))
15871622

15881623
if 'facecolors' in kwargs:
15891624
fcolors = kwargs.pop('facecolors')
@@ -1733,7 +1768,21 @@ def plot_wireframe(self, X, Y, Z, *args, **kwargs):
17331768
The `rstride` and `cstride` kwargs set the stride used to
17341769
sample the input data to generate the graph. If either is 0
17351770
the input data in not sampled along this direction producing a
1736-
3D line plot rather than a wireframe plot.
1771+
3D line plot rather than a wireframe plot. The stride arguments
1772+
are only used by default if in the 'classic' mode. They are
1773+
now superseded by `rcount` and `ccount`. Will raise ValueError
1774+
if both stride and count are used.
1775+
1776+
` The `rcount` and `ccount` kwargs supersedes `rstride` and
1777+
`cstride` for default sampling method for wireframe plotting.
1778+
These arguments will determine at most how many evenly spaced
1779+
samples will be taken from the input data to generate the graph.
1780+
This is the default sampling method unless using the 'classic'
1781+
style. Will raise ValueError if both stride and count are
1782+
specified. If either is zero, then the input data is not sampled
1783+
along this direction, producing a 3D line plot rather than a
1784+
wireframe plot.
1785+
Added in v2.0.0.
17371786
17381787
========== ================================================
17391788
Argument Description
@@ -1742,6 +1791,8 @@ def plot_wireframe(self, X, Y, Z, *args, **kwargs):
17421791
*Z*
17431792
*rstride* Array row stride (step size), defaults to 1
17441793
*cstride* Array column stride (step size), defaults to 1
1794+
*rcount* Use at most this many rows, defaults to 50
1795+
*ccount* Use at most this many columns, defaults to 50
17451796
========== ================================================
17461797
17471798
Keyword arguments are passed on to
@@ -1750,15 +1801,37 @@ def plot_wireframe(self, X, Y, Z, *args, **kwargs):
17501801
Returns a :class:`~mpl_toolkits.mplot3d.art3d.Line3DCollection`
17511802
'''
17521803

1753-
rstride = kwargs.pop("rstride", 1)
1754-
cstride = kwargs.pop("cstride", 1)
1755-
17561804
had_data = self.has_data()
17571805
Z = np.atleast_2d(Z)
17581806
# FIXME: Support masked arrays
17591807
X, Y, Z = np.broadcast_arrays(X, Y, Z)
17601808
rows, cols = Z.shape
17611809

1810+
has_stride = 'rstride' in kwargs or 'cstride' in kwargs
1811+
has_count = 'rcount' in kwargs or 'ccount' in kwargs
1812+
1813+
if has_stride and has_count:
1814+
raise ValueError("Cannot specify both stride and count arguments")
1815+
1816+
rstride = kwargs.pop('rstride', 1)
1817+
cstride = kwargs.pop('cstride', 1)
1818+
rcount = kwargs.pop('rcount', 50)
1819+
ccount = kwargs.pop('ccount', 50)
1820+
1821+
if rcParams['_internal.classic_mode']:
1822+
# Strides have priority over counts in classic mode.
1823+
# So, only compute strides from counts
1824+
# if counts were explicitly given
1825+
if has_count:
1826+
rstride = int(np.ceil(rows / rcount)) if rcount else 0
1827+
cstride = int(np.ceil(cols / ccount)) if ccount else 0
1828+
else:
1829+
# If the strides are provided then it has priority.
1830+
# Otherwise, compute the strides from the counts.
1831+
if not has_stride:
1832+
rstride = int(np.ceil(rows / rcount)) if rcount else 0
1833+
cstride = int(np.ceil(cols / ccount)) if ccount else 0
1834+
17621835
# We want two sets of lines, one running along the "rows" of
17631836
# Z and another set of lines running along the "columns" of Z.
17641837
# This transpose will make it easy to obtain the columns.

lib/mpl_toolkits/tests/test_mplot3d.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def f(t):
105105
R = np.sqrt(X ** 2 + Y ** 2)
106106
Z = np.sin(R)
107107

108-
surf = ax.plot_surface(X, Y, Z, rstride=1, cstride=1,
108+
surf = ax.plot_surface(X, Y, Z, rcount=40, ccount=40,
109109
linewidth=0, antialiased=False)
110110

111111
ax.set_zlim3d(-1, 1)
@@ -141,7 +141,7 @@ def test_surface3d():
141141
X, Y = np.meshgrid(X, Y)
142142
R = np.sqrt(X ** 2 + Y ** 2)
143143
Z = np.sin(R)
144-
surf = ax.plot_surface(X, Y, Z, rstride=1, cstride=1, cmap=cm.coolwarm,
144+
surf = ax.plot_surface(X, Y, Z, rcount=40, ccount=40, cmap=cm.coolwarm,
145145
lw=0, antialiased=False)
146146
ax.set_zlim(-1.01, 1.01)
147147
fig.colorbar(surf, shrink=0.5, aspect=5)
@@ -194,7 +194,7 @@ def test_wireframe3d():
194194
fig = plt.figure()
195195
ax = fig.add_subplot(111, projection='3d')
196196
X, Y, Z = axes3d.get_test_data(0.05)
197-
ax.plot_wireframe(X, Y, Z, rstride=10, cstride=10)
197+
ax.plot_wireframe(X, Y, Z, rcount=13, ccount=13)
198198

199199

200200
@image_comparison(baseline_images=['wireframe3dzerocstride'], remove_text=True,
@@ -203,7 +203,7 @@ def test_wireframe3dzerocstride():
203203
fig = plt.figure()
204204
ax = fig.add_subplot(111, projection='3d')
205205
X, Y, Z = axes3d.get_test_data(0.05)
206-
ax.plot_wireframe(X, Y, Z, rstride=10, cstride=0)
206+
ax.plot_wireframe(X, Y, Z, rcount=13, ccount=0)
207207

208208

209209
@image_comparison(baseline_images=['wireframe3dzerorstride'], remove_text=True,
@@ -214,6 +214,7 @@ def test_wireframe3dzerorstride():
214214
X, Y, Z = axes3d.get_test_data(0.05)
215215
ax.plot_wireframe(X, Y, Z, rstride=0, cstride=10)
216216

217+
217218
@cleanup
218219
def test_wireframe3dzerostrideraises():
219220
fig = plt.figure()
@@ -222,6 +223,18 @@ def test_wireframe3dzerostrideraises():
222223
with assert_raises(ValueError):
223224
ax.plot_wireframe(X, Y, Z, rstride=0, cstride=0)
224225

226+
227+
@cleanup
228+
def test_mixedsamplesraises():
229+
fig = plt.figure()
230+
ax = fig.add_subplot(111, projection='3d')
231+
X, Y, Z = axes3d.get_test_data(0.05)
232+
with assert_raises(ValueError):
233+
ax.plot_wireframe(X, Y, Z, rstride=10, ccount=50)
234+
with assert_raises(ValueError):
235+
ax.plot_surface(X, Y, Z, cstride=50, rcount=10)
236+
237+
225238
@image_comparison(baseline_images=['quiver3d'], remove_text=True)
226239
def test_quiver3d():
227240
fig = plt.figure()

0 commit comments

Comments
 (0)
0