8000 Fix argument checking in `Axes3D.quiver` by oscargus · Pull Request #24862 · matplotlib/matplotlib · GitHub
[go: up one dir, main page]

Skip to content

Fix argument checking in Axes3D.quiver #24862

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 3, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 18 additions & 27 deletions lib/mpl_toolkits/mplot3d/axes3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -1561,7 +1561,7 @@ def plot_surface(self, X, Y, Z, *, norm=None, vmin=None,
The lightsource to use when *shade* is True.

**kwargs
Other arguments are forwarded to `.Poly3DCollection`.
Other keyword arguments are forwarded to `.Poly3DCollection`.
"""

had_data = self.has_data()
Expand Down Expand Up @@ -1724,7 +1724,7 @@ def plot_wireframe(self, X, Y, Z, **kwargs):
of the new default of ``rcount = ccount = 50``.

**kwargs
Other arguments are forwarded to `.Line3DCollection`.
Other keyword arguments are forwarded to `.Line3DCollection`.
"""

had_data = self.has_data()
Expand Down Expand Up @@ -1851,7 +1851,7 @@ def plot_trisurf(self, *args, color=None, norm=None, vmin=None, vmax=None,
lightsource : `~matplotlib.colors.LightSource`
The lightsource to use when *shade* is True.
**kwargs
All other arguments are passed on to
All other keyword arguments are passed on to
:class:`~mpl_toolkits.mplot3d.art3d.Poly3DCollection`

Examples
Expand Down Expand Up @@ -2252,7 +2252,7 @@ def scatter(self, xs, ys, zs=0, zdir='z', s=20, c=None, depthshade=True,
data : indexable object, optional
DATA_PARAMETER_PLACEHOLDER
**kwargs
All other arguments are passed on to `~.axes.Axes.scatter`.
All other keyword arguments are passed on to `~.axes.Axes.scatter`.

Returns
-------
Expand Down Expand Up @@ -2304,7 +2304,8 @@ def bar(self, left, height, zs=0, zdir='z', *args, **kwargs):
data : indexable object, optional
DATA_PARAMETER_PLACEHOLDER
**kwargs
Other arguments are forwarded to `matplotlib.axes.Axes.bar`.
Other keyword arguments are forwarded to
`matplotlib.axes.Axes.bar`.

Returns
-------
Expand Down Expand Up @@ -2508,19 +2509,16 @@ def set_title(self, label, fontdict=None, loc='center', **kwargs):
return ret

@_preprocess_data()
def quiver(self, *args,
def quiver(self, X, Y, Z, U, V, W, *,
length=1, arrow_length_ratio=.3, pivot='tail', normalize=False,
**kwargs):
"""
ax.quiver(X, Y, Z, U, V, W, /, length=1, arrow_length_ratio=.3, \
pivot='tail', normalize=False, **kwargs)

Plot a 3D field of arrows.

The arguments could be array-like or scalars, so long as they
they can be broadcast together. The arguments can also be
masked arrays. If an element in any of argument is masked, then
that corresponding quiver element will not be plotted.
The arguments can be array-like or scalars, so long as they can be
broadcast together. The arguments can also be masked arrays. If an
element in any of argument is masked, then that corresponding quiver
element will not be plotted.

Parameters
----------
Expand Down Expand Up @@ -2550,7 +2548,7 @@ def quiver(self, *args,

**kwargs
Any additional keyword arguments are delegated to
:class:`~matplotlib.collections.LineCollection`
:class:`.Line3DCollection`
"""

def calc_arrows(UVW, angle=15):
Expand Down Expand Up @@ -2581,22 +2579,15 @@ def calc_arrows(UVW, angle=15):

had_data = self.has_data()

# handle args
argi = 6
if len(args) < argi:
raise ValueError('Wrong number of arguments. Expected %d got %d' %
(argi, len(args)))

# first 6 arguments are X, Y, Z, U, V, W
input_args = args[:argi]
input_args = [X, Y, Z, U, V, W]

# extract the masks, if any
masks = [k.mask for k in input_args
if isinstance(k, np.ma.MaskedArray)]
# broadcast to match the shape
bcast = np.broadcast_arrays(*input_args, *masks)
input_args = bcast[:argi]
masks = bcast[argi:]
input_args = bcast[:6]
masks = bcast[6:]
if masks:
# combine the masks into one
mask = functools.reduce(np.logical_or, masks)
Expand All @@ -2608,7 +2599,7 @@ def calc_arrows(UVW, angle=15):

if any(len(v) == 0 for v in input_args):
# No quivers, so just make an empty collection and return early
linec = art3d.Line3DCollection([], *args[argi:], **kwargs)
linec = art3d.Line3DCollection([], **kwargs)
self.add_collection(linec)
return linec

Expand All @@ -2622,7 +2613,7 @@ def calc_arrows(UVW, angle=15):
shaft_dt -= length / 2

XYZ = np.column_stack(input_args[:3])
UVW = np.column_stack(input_args[3:argi]).astype(float)
UVW = np.column_stack(input_args[3:]).astype(float)
69DD
# Normalize rows of UVW
norm = np.linalg.norm(UVW, axis=1)
Expand Down Expand Up @@ -2651,7 +2642,7 @@ def calc_arrows(UVW, angle=15):
else:
lines = []

linec = art3d.Line3DCollection(lines, *args[argi:], **kwargs)
linec = art3d.Line3DCollection(lines, **kwargs)
self.add_collection(linec)

self.auto_scale_xyz(XYZ[:, 0], XYZ[:, 1], XYZ[:, 2], had_data)
Expand Down
0