8000 Allow changing the vertical axis in 3d plots (#19873) · matplotlib/matplotlib@b12d983 · GitHub
[go: up one dir, main page]

Skip to content

Commit b12d983

Browse files
IllviljantimhoffmQuLogic
authored
Allow changing the vertical axis in 3d plots (#19873)
* allow changing the vertical axis * change aspect to follow vertical axis, * add projection test * Update axes3d.py * Update test_mplot3d.py * axis lines behaves the same as default * Update axis3d.py * Rotate ticks correctly. * undo comments for unchanged code * generalize slightly * add tickdir test * Update lib/mpl_toolkits/mplot3d/axes3d.py Co-authored-by: Tim Hoffmann <2836374+timhoffm@users.noreply.github.com> * Update axes3d.py * move func to method * Update lib/mpl_toolkits/mplot3d/axes3d.py Co-authored-by: Tim Hoffmann <2836374+timhoffm@users.noreply.github.com> * Update lib/mpl_toolkits/mplot3d/axes3d.py Co-authored-by: Tim Hoffmann <2836374+timhoffm@users.noreply.github.com> * Update lib/mpl_toolkits/tests/test_mplot3d.py Co-authored-by: Tim Hoffmann <2836374+timhoffm@users.noreply.github.com> * Update lib/mpl_toolkits/mplot3d/axis3d.py Co-authored-by: Tim Hoffmann <2836374+timhoffm@users.noreply.github.com> * Update lib/mpl_toolkits/mplot3d/axes3d.py Co-authored-by: Tim Hoffmann <2836374+timhoffm@users.noreply.github.com> * Update lib/mpl_toolkits/mplot3d/axes3d.py Co-authored-by: Elliott Sales de Andrade <quantum.analyst@gmail.com> * docstring styling * Add test for axis lines. * Update axis3d.py * Add whats new * filename typo * Update lib/mpl_toolkits/mplot3d/axis3d.py Co-authored-by: Elliott Sales de Andrade <quantum.analyst@gmail.com> * Update lib/mpl_toolkits/tests/test_mplot3d.py Co-authored-by: Elliott Sales de Andrade <quantum.analyst@gmail.com> * Update lib/mpl_toolkits/mplot3d/axes3d.py Co-authored-by: Elliott Sales de Andrade <quantum.analyst@gmail.com> * Update lib/mpl_toolkits/mplot3d/axis3d.py Co-authored-by: Elliott Sales de Andrade <quantum.analyst@gmail.com> * Use _api.check_getitem * Update lib/mpl_toolkits/tests/test_mplot3d.py Co-authored-by: Elliott Sales de Andrade <quantum.analyst@gmail.com> * Update lib/mpl_toolkits/tests/test_mplot3d.py Co-authored-by: Elliott Sales de Andrade <quantum.analyst@gmail.com> Co-authored-by: Tim Hoffmann <2836374+timhoffm@users.noreply.github.com> Co-authored-by: Elliott Sales de Andrade <quantum.analyst@gmail.com>
1 parent 6f9652b commit b12d983

File tree

4 files changed

+211
-54
lines changed

4 files changed

+211
-54
lines changed
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
Allow changing the vertical axis in 3d plots
2+
----------------------------------------------
3+
4+
`~mpl_toolkits.mplot3d.axes3d.Axes3D.view_init` now has the parameter
5+
*vertical_axis* which allows switching which axis is aligned vertically.

lib/mpl_toolkits/mplot3d/axes3d.py

Lines changed: 63 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1046,17 +1046,24 @@ def clabel(self, *args, **kwargs):
10461046
"""Currently not implemented for 3D axes, and returns *None*."""
10471047
return None
10481048

1049-
def view_init(self, elev=None, azim=None):
1049+
def view_init(self, elev=None, azim=None, vertical_axis="z"):
10501050
"""
10511051
Set the elevation and azimuth of the axes in degrees (not radians).
10521052
10531053
This can be used to rotate the axes programmatically.
10541054
1055-
'elev' stores the elevation angle in the z plane (in degrees).
1056-
'azim' stores the azimuth angle in the (x, y) plane (in degrees).
1057-
1058-
if 'elev' or 'azim' are None (default), then the initial value
1059-
is used which was specified in the :class:`Axes3D` constructor.
1055+
Parameters
1056+
----------
1057+
elev : float, default: None
1058+
The elevation angle in the vertical plane in degrees.
1059+
If None then the initial value as specified in the `Axes3D`
1060+
constructor is used.
1061+
azim : float, default: None
1062+
The azimuth angle in the horizontal plane in degrees.
1063+
If None then the initial value as specified in the `Axes3D`
1064+
constructor is used.
1065+
vertical_axis : {"z", "x", "y"}, default: "z"
1066+
The axis to align vertically. *azim* rotates about this axis.
10601067
"""
10611068

10621069
self.dist = 10
@@ -1071,6 +1078,10 @@ def view_init(self, elev=None, azim=None):
10711078
else:
10721079
self.azim = azim
10731080

1081+
self._vertical_axis = _api.check_getitem(
1082+
dict(x=0, y=1, z=2), vertical_axis=vertical_axis
1083+
)
1084+
10741085
def set_proj_type(self, proj_type):
10751086
"""
10761087
Set the projection type.
@@ -1084,47 +1095,60 @@ def set_proj_type(self, proj_type):
10841095
'ortho': proj3d.ortho_transformation,
10851096
}, proj_type=proj_type)
10861097

1098+
def _roll_to_vertical(self, arr):
1099+
"""Roll arrays to match the different vertical axis."""
1100+
return np.roll(arr, self._vertical_axis - 2)
1101+
10871102
def get_proj(self):
10881103
"""Create the projection matrix from the current viewing position."""
1089-
# elev stores the elevation angle in the z plane
1090-
# azim stores the azimuth angle in the x,y plane
1091-
#
1092-
# dist is the distance of the eye viewing point from the object
1093-
# point.
10941104

1095-
relev, razim = np.pi * self.elev/180, np.pi * self.azim/180
1096-
1097-
xmin, xmax = self.get_xlim3d()
1098-
ymin, ymax = self.get_ylim3d()
1099-
zmin, zmax = self.get_zlim3d()
1100-
1101-
# transform to uniform world coordinates 0-1, 0-1, 0-1
1102-
worldM = proj3d.world_transformation(xmin, xmax,
1103-
ymin, ymax,
1104-
zmin, zmax,
1105-
pb_aspect=self._box_aspect)
1106-
1107-
# look into the middle of the new coordinates
1108-
R = self._box_aspect / 2
1105+
# Transform to uniform world coordinates 0-1, 0-1, 0-1
1106+
box_aspect = self._roll_to_vertical(self._box_aspect)
1107+
worldM = proj3d.world_transformation(
1108+
*self.get_xlim3d(),
1109+
*self.get_ylim3d(),
1110+
*self.get_zlim3d(),
1111+
pb_aspect=box_aspect,
1112+
)
11091113

1110-
xp = R[0] + np.cos(razim) * np.cos(relev) * self.dist
1111-
yp = R[1] + np.sin(razim) * np.cos(relev) * self.dist
1112-
zp = R[2] + np.sin(relev) * self.dist
1113-
E = np.array((xp, yp, zp))
1114+
# Look into the middle of the new coordinates:
1115+
R = 0.5 * box_aspect
11141116

1115-
self.eye = E
1116-
self.vvec = R - E
1117+
# elev stores the elevation angle in the z plane
1118+
# azim stores the azimuth angle in the x,y plane
1119+
elev_rad = np.deg2rad(self.elev)
1120+
azim_rad = np.deg2rad(self.azim)
1121+
1122+
# Coordinates for a point that rotates around the box of data.
1123+
# p0, p1 corresponds to rotating the box only around the
1124+
# vertical axis.
1125+
# p2 corresponds to rotating the box only around the horizontal
1126+
# axis.
1127+
p0 = np.cos(elev_rad) * np.cos(azim_rad)
1128+
p1 = np.cos(elev_rad) * np.sin(azim_rad)
1129+
p2 = np.sin(elev_rad)
1130+
1131+
# When changing vertical axis the coordinates changes as well.
1132+
# Roll the values to get the same behaviour as the default:
1133+
ps = self._roll_to_vertical([p0, p1, p2])
1134+
1135+
# The coordinates for the eye viewing point. The eye is looking
1136+
# towards the middle of the box of data from a distance:
1137+
eye = R + self.dist * ps
1138+
1139+
# TODO: Is this being used somewhere? Can it be removed?
1140+
self.eye = eye
1141+
self.vvec = R - eye
11171142
self.vvec = self.vvec / np.linalg.norm(self.vvec)
11181143

1119-
if abs(relev) > np.pi/2:
1120-
# upside down
1121-
V = np.array((0, 0, -1))
1122-
else:
1123-
V = np.array((0, 0, 1))
1124-
zfront, zback = -self.dist, self.dist
1144+
# Define which axis should be vertical. A negative value
1145+
# indicates the plot is upside down and therefore the values
1146+
# have been reversed:
1147+
V = np.zeros(3)
1148+
V[self._vertical_axis] = -1 if abs(elev_rad) > 0.5 * np.pi else 1
11251149

1126-
viewM = proj3d.view_transformation(E, R, V)
1127-
projM = self._projection(zfront, zback)
1150+
viewM = proj3d.view_transformation(eye, R, V)
1151+
projM = self._projection(-self.dist, self.dist)
11281152
M0 = np.dot(viewM, worldM)
11291153
M = np.dot(projM, M0)
11301154
return M

lib/mpl_toolkits/mplot3d/axis3d.py

Lines changed: 53 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,45 @@ def _get_coord_info(self, renderer):
206206

207207
return mins, maxs, centers, deltas, bounds_proj, highs
208208

209+
def _get_axis_line_edge_points(self, minmax, maxmin):
210+
"""Get the edge points for the black bolded axis line."""
211+
# When changing vertical axis some of the axes has to be
212+
# moved to the other plane so it looks the same as if the z-axis
213+
# was the vertical axis.
214+
mb = [minmax, maxmin]
215+
mb_rev = mb[::-1]
216+
mm = [[mb, mb_rev, mb_rev], [mb_rev, mb_rev, mb], [mb, mb, mb]]
217+
mm = mm[self.axes._vertical_axis][self._axinfo["i"]]
218+
219+
juggled = self._axinfo["juggled"]
220+
edge_point_0 = mm[0].copy()
221+
edge_point_0[juggled[0]] = mm[1][juggled[0]]
222+
223+
edge_point_1 = edge_point_0.copy()
224+
edge_point_1[juggled[1]] = mm[1][juggled[1]]
225+
226+
return edge_point_0, edge_point_1
227+
228+
def _get_tickdir(self):
229+
"""
230+
Get the direction of the tick.
231+
232+
Returns
233+
-------
234+
tickdir : int
235+
Index which indicates which coordinate the tick line will
236+
align with.
237+
"""
238+
# TODO: Move somewhere else where it's triggered less:
239+
tickdirs_base = [v["tickdir"] for v in self._AXINFO.values()]
240+
info_i = [v["i"] for v in self._AXINFO.values()]
241+
242+
i = self._axinfo["i"]
243+
j = self.axes._vertical_axis - 2
244+
# tickdir = [[1, 2, 1], [2, 2, 0], [1, 0, 0]][i]
245+
tickdir = np.roll(info_i, -j)[np.roll(tickdirs_base, j)][i]
246+
return tickdir
247+
209248
def draw_pane(self, renderer):
210249
renderer.open_group('pane3d', gid=self.get_gid())
211250

@@ -226,29 +265,27 @@ def draw_pane(self, renderer):
226265
@artist.allow_rasterization
227266
def draw(self, renderer):
228267
self.label._transform = self.axes.transData
229-
renderer.open_group('axis3d', gid=self.get_gid())
268+
renderer.open_group("axis3d", gid=self.get_gid())
230269

231270
ticks = self._update_ticks()
232271

272+
# Get general axis information:
233273
info = self._axinfo
234-
index = info['i']
274+
index = info["i"]
275+
juggled = info["juggled"]
235276

236277
mins, maxs, centers, deltas, tc, highs = self._get_coord_info(renderer)
237278

238-
# Determine grid lines
239279
minmax = np.where(highs, maxs, mins)
240-
maxmin = np.where(highs, mins, maxs)
280+
maxmin = np.where(~highs, maxs, mins)
241281

242-
# Draw main axis line
243-
juggled = info['juggled']
244-
edgep1 = minmax.copy()
245-
edgep1[juggled[0]] = maxmin[juggled[0]]
282+
# Create edge points for the black bolded axis line:
283+
edgep1, edgep2 = self._get_axis_line_edge_points(minmax, maxmin)
246284

247-
edgep2 = edgep1.copy()
248-
edgep2[juggled[1]] = maxmin[juggled[1]]
249-
pep = np.asarray(
250-
proj3d.proj_trans_points([edgep1, edgep2], self.axes.M))
251-
centpt = proj3d.proj_transform(*centers, self.axes.M)
285+
# Project the edge points along the current position and
286+
# create the line:
287+
pep = proj3d.proj_trans_points([edgep1, edgep2], self.axes.M)
288+
pep = np.asarray(pep)
252289
self.line.set_data(pep[0], pep[1])
253290
self.line.draw(renderer)
254291

@@ -325,6 +362,7 @@ def draw(self, renderer):
325362
# Three-letters (e.g., TFT, FTT) are short-hand for the array of bools
326363
# from the variable 'highs'.
327364
# ---------------------------------------------------------------------
365+
centpt = proj3d.proj_transform(*centers, self.axes.M)
328366
if centpt[info['tickdir']] > pep[info['tickdir'], outerindex]:
329367
# if FT and if highs has an even number of Trues
330368
if (centpt[index] <= pep[index, outerindex]
@@ -370,8 +408,8 @@ def draw(self, renderer):
370408
self.gridlines.do_3d_projection()
371409
self.gridlines.draw(renderer)
372410

373-
# Draw ticks
374-
tickdir = info['tickdir']
411+
# Draw ticks:
412+
tickdir = self._get_tickdir()
375413
tickdelta = deltas[tickdir]
376414
if highs[tickdir]:
377415
ticksign = 1

lib/mpl_toolkits/tests/test_mplot3d.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1523,3 +1523,93 @@ def test_scatter_spiral():
15231523

15241524
# force at least 1 draw!
15251525
fig.canvas.draw()
1526+
1527+
1528+
@pytest.mark.parametrize(
1529+
"vertical_axis, proj_expected, axis_lines_expected, tickdirs_expected",
1530+
[
1531+
(
1532+
"z",
1533+
[
1534+
[0.0, 1.142857, 0.0, -0.571429],
1535+
[0.0, 0.0, 0.857143, -0.428571],
1536+
[0.0, 0.0, 0.0, -10.0],
1537+
[-1.142857, 0.0, 0.0, 10.571429],
1538+
],
1539+
[
1540+
([0.05617978, 0.06329114], [-0.04213483, -0.04746835]),
1541+
([-0.06329114, 0.06329114], [-0.04746835, -0.04746835]),
1542+
([-0.06329114, -0.06329114], [-0.04746835, 0.04746835]),
1543+
],
1544+
[1, 0, 0],
1545+
),
1546+
(
1547+
"y",
1548+
[
1549+
[1.142857, 0.0, 0.0, -0.571429],
1550+
[0.0, 0.857143, 0.0, -0.428571],
1551+
[0.0, 0.0, 0.0, -10.0],
1552+
[0.0, 0.0, -1.142857, 10.571429],
1553+
],
1554+
[
1555+
([0.06329114, -0.06329114], [-0.04746835, -0.04746835]),
1556+
([-0.06329114, -0.06329114], [0.04746835, -0.04746835]),
1557+
([0.05617978, 0.06329114], [-0.04213483, -0.04746835]),
1558+
],
1559+
[2, 2, 0],
1560+
),
1561+
(
1562+
"x",
1563+
[
1564+
[0.0, 0.0, 1.142857, -0.571429],
1565+
[0.857143, 0.0, 0.0, -0.428571],
1566+
[0.0, 0.0, 0.0, -10.0],
1567+
[0.0, -1.142857, 0.0, 10.571429],
1568+
],
1569+
[
1570+
([-0.06329114, -0.06329114], [-0.04746835, 0.04746835]),
1571+
([0.06329114, 0.05617978], [-0.04746835, -0.04213483]),
1572+
([0.06329114, -0.06329114], [-0.04746835, -0.04746835]),
1573+
],
1574+
[1, 2, 1],
1575+
),
1576+
],
1577+
)
1578+
def test_view_init_vertical_axis(
1579+
vertical_axis, proj_expected, axis_lines_expected, tickdirs_expected
1580+
):
1581+
"""
1582+
Test the actual projection, axis lines and ticks matches expected values.
1583+
1584+
Parameters
1585+
----------
1586+
vertical_axis : str
1587+
Axis to align vertically.
1588+
proj_expected : ndarray
1589+
Expected values from ax.get_proj().
1590+
axis_lines_expected : tuple of arrays
1591+
Edgepoints of the axis line. Expected values retrieved according
1592+
to ``ax.get_[xyz]axis().line.get_data()``.
1593+
tickdirs_expected : list of int
1594+
indexes indicating which axis to create a tick line along.
1595+
"""
1596+
rtol = 2e-06
1597+
ax = plt.subplot(1, 1, 1, projection="3d")
1598+
ax.view_init(azim=0, elev=0, vertical_axis=vertical_axis)
1599+
ax.figure.canvas.draw()
1600+
1601+
# Assert the projection matrix:
1602+
proj_actual = ax.get_proj()
1603+
np.testing.assert_allclose(proj_expected, proj_actual, rtol=rtol)
1604+
1605+
for i, axis in enumerate([ax.get_xaxis(), ax.get_yaxis(), ax.get_zaxis()]):
1606+
# Assert black lines are correctly aligned:
1607+
axis_line_expected = axis_lines_expected[i]
1608+
axis_line_actual = axis.line.get_data()
1609+
np.testing.assert_allclose(axis_line_expected, axis_line_actual,
1610+
rtol=rtol)
1611+
1612+
# Assert ticks are correctly aligned:
1613+
tickdir_expected = tickdirs_expected[i]
1614+
tickdir_actual = axis._get_tickdir()
1615+
np.testing.assert_array_equal(tickdir_expected, tickdir_actual)

0 commit comments

Comments
 (0)
0