diff --git a/doc/users/next_whats_new/allow_changing_the_vertical_axis_in_3d_plots.rst b/doc/users/next_whats_new/allow_changing_the_vertical_axis_in_3d_plots.rst new file mode 100644 index 000000000000..fead9f359f56 --- /dev/null +++ b/doc/users/next_whats_new/allow_changing_the_vertical_axis_in_3d_plots.rst @@ -0,0 +1,5 @@ +Allow changing the vertical axis in 3d plots +---------------------------------------------- + +`~mpl_toolkits.mplot3d.axes3d.Axes3D.view_init` now has the parameter +*vertical_axis* which allows switching which axis is aligned vertically. diff --git a/lib/mpl_toolkits/mplot3d/axes3d.py b/lib/mpl_toolkits/mplot3d/axes3d.py index ed97b9658560..1f0f56fa780d 100644 --- a/lib/mpl_toolkits/mplot3d/axes3d.py +++ b/lib/mpl_toolkits/mplot3d/axes3d.py @@ -1045,17 +1045,24 @@ def clabel(self, *args, **kwargs): """Currently not implemented for 3D axes, and returns *None*.""" return None - def view_init(self, elev=None, azim=None): + def view_init(self, elev=None, azim=None, vertical_axis="z"): """ Set the elevation and azimuth of the axes in degrees (not radians). This can be used to rotate the axes programmatically. - 'elev' stores the elevation angle in the z plane (in degrees). - 'azim' stores the azimuth angle in the (x, y) plane (in degrees). - - if 'elev' or 'azim' are None (default), then the initial value - is used which was specified in the :class:`Axes3D` constructor. + Parameters + ---------- + elev : float, default: None + The elevation angle in the vertical plane in degrees. + If None then the initial value as specified in the `Axes3D` + constructor is used. + azim : float, default: None + The azimuth angle in the horizontal plane in degrees. + If None then the initial value as specified in the `Axes3D` + constructor is used. + vertical_axis : {"z", "x", "y"}, default: "z" + The axis to align vertically. *azim* rotates about this axis. """ self.dist = 10 @@ -1070,6 +1077,10 @@ def view_init(self, elev=None, azim=None): else: self.azim = azim + self._vertical_axis = _api.check_getitem( + dict(x=0, y=1, z=2), vertical_axis=vertical_axis + ) + def set_proj_type(self, proj_type): """ Set the projection type. @@ -1083,47 +1094,60 @@ def set_proj_type(self, proj_type): 'ortho': proj3d.ortho_transformation, }, proj_type=proj_type) + def _roll_to_vertical(self, arr): + """Roll arrays to match the different vertical axis.""" + return np.roll(arr, self._vertical_axis - 2) + def get_proj(self): """Create the projection matrix from the current viewing position.""" - # elev stores the elevation angle in the z plane - # azim stores the azimuth angle in the x,y plane - # - # dist is the distance of the eye viewing point from the object - # point. - relev, razim = np.pi * self.elev/180, np.pi * self.azim/180 - - xmin, xmax = self.get_xlim3d() - ymin, ymax = self.get_ylim3d() - zmin, zmax = self.get_zlim3d() - - # transform to uniform world coordinates 0-1, 0-1, 0-1 - worldM = proj3d.world_transformation(xmin, xmax, - ymin, ymax, - zmin, zmax, - pb_aspect=self._box_aspect) - - # look into the middle of the new coordinates - R = self._box_aspect / 2 + # Transform to uniform world coordinates 0-1, 0-1, 0-1 + box_aspect = self._roll_to_vertical(self._box_aspect) + worldM = proj3d.world_transformation( + *self.get_xlim3d(), + *self.get_ylim3d(), + *self.get_zlim3d(), + pb_aspect=box_aspect, + ) - xp = R[0] + np.cos(razim) * np.cos(relev) * self.dist - yp = R[1] + np.sin(razim) * np.cos(relev) * self.dist - zp = R[2] + np.sin(relev) * self.dist - E = np.array((xp, yp, zp)) + # Look into the middle of the new coordinates: + R = 0.5 * box_aspect - self.eye = E - self.vvec = R - E + # elev stores the elevation angle in the z plane + # azim stores the azimuth angle in the x,y plane + elev_rad = np.deg2rad(self.elev) + azim_rad = np.deg2rad(self.azim) + + # Coordinates for a point that rotates around the box of data. + # p0, p1 corresponds to rotating the box only around the + # vertical axis. + # p2 corresponds to rotating the box only around the horizontal + # axis. + p0 = np.cos(elev_rad) * np.cos(azim_rad) + p1 = np.cos(elev_rad) * np.sin(azim_rad) + p2 = np.sin(elev_rad) + + # When changing vertical axis the coordinates changes as well. + # Roll the values to get the same behaviour as the default: + ps = self._roll_to_vertical([p0, p1, p2]) + + # The coordinates for the eye viewing point. The eye is looking + # towards the middle of the box of data from a distance: + eye = R + self.dist * ps + + # TODO: Is this being used somewhere? Can it be removed? + self.eye = eye + self.vvec = R - eye self.vvec = self.vvec / np.linalg.norm(self.vvec) - if abs(relev) > np.pi/2: - # upside down - V = np.array((0, 0, -1)) - else: - V = np.array((0, 0, 1)) - zfront, zback = -self.dist, self.dist + # Define which axis should be vertical. A negative value + # indicates the plot is upside down and therefore the values + # have been reversed: + V = np.zeros(3) + V[self._vertical_axis] = -1 if abs(elev_rad) > 0.5 * np.pi else 1 - viewM = proj3d.view_transformation(E, R, V) - projM = self._projection(zfront, zback) + viewM = proj3d.view_transformation(eye, R, V) + projM = self._projection(-self.dist, self.dist) M0 = np.dot(viewM, worldM) M = np.dot(projM, M0) return M diff --git a/lib/mpl_toolkits/mplot3d/axis3d.py b/lib/mpl_toolkits/mplot3d/axis3d.py index c9d1f2801f3b..86d8e3b183a1 100644 --- a/lib/mpl_toolkits/mplot3d/axis3d.py +++ b/lib/mpl_toolkits/mplot3d/axis3d.py @@ -206,6 +206,45 @@ def _get_coord_info(self, renderer): return mins, maxs, centers, deltas, bounds_proj, highs + def _get_axis_line_edge_points(self, minmax, maxmin): + """Get the edge points for the black bolded axis line.""" + # When changing vertical axis some of the axes has to be + # moved to the other plane so it looks the same as if the z-axis + # was the vertical axis. + mb = [minmax, maxmin] + mb_rev = mb[::-1] + mm = [[mb, mb_rev, mb_rev], [mb_rev, mb_rev, mb], [mb, mb, mb]] + mm = mm[self.axes._vertical_axis][self._axinfo["i"]] + + juggled = self._axinfo["juggled"] + edge_point_0 = mm[0].copy() + edge_point_0[juggled[0]] = mm[1][juggled[0]] + + edge_point_1 = edge_point_0.copy() + edge_point_1[juggled[1]] = mm[1][juggled[1]] + + return edge_point_0, edge_point_1 + + def _get_tickdir(self): + """ + Get the direction of the tick. + + Returns + ------- + tickdir : int + Index which indicates which coordinate the tick line will + align with. + """ + # TODO: Move somewhere else where it's triggered less: + tickdirs_base = [v["tickdir"] for v in self._AXINFO.values()] + info_i = [v["i"] for v in self._AXINFO.values()] + + i = self._axinfo["i"] + j = self.axes._vertical_axis - 2 + # tickdir = [[1, 2, 1], [2, 2, 0], [1, 0, 0]][i] + tickdir = np.roll(info_i, -j)[np.roll(tickdirs_base, j)][i] + return tickdir + def draw_pane(self, renderer): renderer.open_group('pane3d', gid=self.get_gid()) @@ -226,29 +265,27 @@ def draw_pane(self, renderer): @artist.allow_rasterization def draw(self, renderer): self.label._transform = self.axes.transData - renderer.open_group('axis3d', gid=self.get_gid()) + renderer.open_group("axis3d", gid=self.get_gid()) ticks = self._update_ticks() + # Get general axis information: info = self._axinfo - index = info['i'] + index = info["i"] + juggled = info["juggled"] mins, maxs, centers, deltas, tc, highs = self._get_coord_info(renderer) - # Determine grid lines minmax = np.where(highs, maxs, mins) - maxmin = np.where(highs, mins, maxs) + maxmin = np.where(~highs, maxs, mins) - # Draw main axis line - juggled = info['juggled'] - edgep1 = minmax.copy() - edgep1[juggled[0]] = maxmin[juggled[0]] + # Create edge points for the black bolded axis line: + edgep1, edgep2 = self._get_axis_line_edge_points(minmax, maxmin) - edgep2 = edgep1.copy() - edgep2[juggled[1]] = maxmin[juggled[1]] - pep = np.asarray( - proj3d.proj_trans_points([edgep1, edgep2], self.axes.M)) - centpt = proj3d.proj_transform(*centers, self.axes.M) + # Project the edge points along the current position and + # create the line: + pep = proj3d.proj_trans_points([edgep1, edgep2], self.axes.M) + pep = np.asarray(pep) self.line.set_data(pep[0], pep[1]) self.line.draw(renderer) @@ -325,6 +362,7 @@ def draw(self, renderer): # Three-letters (e.g., TFT, FTT) are short-hand for the array of bools # from the variable 'highs'. # --------------------------------------------------------------------- + centpt = proj3d.proj_transform(*centers, self.axes.M) if centpt[info['tickdir']] > pep[info['tickdir'], outerindex]: # if FT and if highs has an even number of Trues if (centpt[index] <= pep[index, outerindex] @@ -370,8 +408,8 @@ def draw(self, renderer): self.gridlines.do_3d_projection() self.gridlines.draw(renderer) - # Draw ticks - tickdir = info['tickdir'] + # Draw ticks: + tickdir = self._get_tickdir() tickdelta = deltas[tickdir] if highs[tickdir]: ticksign = 1 diff --git a/lib/mpl_toolkits/tests/test_mplot3d.py b/lib/mpl_toolkits/tests/test_mplot3d.py index 9717d8bf1c58..edb5c978bb49 100644 --- a/lib/mpl_toolkits/tests/test_mplot3d.py +++ b/lib/mpl_toolkits/tests/test_mplot3d.py @@ -1523,3 +1523,93 @@ def test_scatter_spiral(): # force at least 1 draw! fig.canvas.draw() + + +@pytest.mark.parametrize( + "vertical_axis, proj_expected, axis_lines_expected, tickdirs_expected", + [ + ( + "z", + [ + [0.0, 1.142857, 0.0, -0.571429], + [0.0, 0.0, 0.857143, -0.428571], + [0.0, 0.0, 0.0, -10.0], + [-1.142857, 0.0, 0.0, 10.571429], + ], + [ + ([0.05617978, 0.06329114], [-0.04213483, -0.04746835]), + ([-0.06329114, 0.06329114], [-0.04746835, -0.04746835]), + ([-0.06329114, -0.06329114], [-0.04746835, 0.04746835]), + ], + [1, 0, 0], + ), + ( + "y", + [ + [1.142857, 0.0, 0.0, -0.571429], + [0.0, 0.857143, 0.0, -0.428571], + [0.0, 0.0, 0.0, -10.0], + [0.0, 0.0, -1.142857, 10.571429], + ], + [ + ([0.06329114, -0.06329114], [-0.04746835, -0.04746835]), + ([-0.06329114, -0.06329114], [0.04746835, -0.04746835]), + ([0.05617978, 0.06329114], [-0.04213483, -0.04746835]), + ], + [2, 2, 0], + ), + ( + "x", + [ + [0.0, 0.0, 1.142857, -0.571429], + [0.857143, 0.0, 0.0, -0.428571], + [0.0, 0.0, 0.0, -10.0], + [0.0, -1.142857, 0.0, 10.571429], + ], + [ + ([-0.06329114, -0.06329114], [-0.04746835, 0.04746835]), + ([0.06329114, 0.05617978], [-0.04746835, -0.04213483]), + ([0.06329114, -0.06329114], [-0.04746835, -0.04746835]), + ], + [1, 2, 1], + ), + ], +) +def test_view_init_vertical_axis( + vertical_axis, proj_expected, axis_lines_expected, tickdirs_expected +): + """ + Test the actual projection, axis lines and ticks matches expected values. + + Parameters + ---------- + vertical_axis : str + Axis to align vertically. + proj_expected : ndarray + Expected values from ax.get_proj(). + axis_lines_expected : tuple of arrays + Edgepoints of the axis line. Expected values retrieved according + to ``ax.get_[xyz]axis().line.get_data()``. + tickdirs_expected : list of int + indexes indicating which axis to create a tick line along. + """ + rtol = 2e-06 + ax = plt.subplot(1, 1, 1, projection="3d") + ax.view_init(azim=0, elev=0, vertical_axis=vertical_axis) + ax.figure.canvas.draw() + + # Assert the projection matrix: + proj_actual = ax.get_proj() + np.testing.assert_allclose(proj_expected, proj_actual, rtol=rtol) + + for i, axis in enumerate([ax.get_xaxis(), ax.get_yaxis(), ax.get_zaxis()]): + # Assert black lines are correctly aligned: + axis_line_expected = axis_lines_expected[i] + axis_line_actual = axis.line.get_data() + np.testing.assert_allclose(axis_line_expected, axis_line_actual, + rtol=rtol) + + # Assert ticks are correctly aligned: + tickdir_expected = tickdirs_expected[i] + tickdir_actual = axis._get_tickdir() + np.testing.assert_array_equal(tickdir_expected, tickdir_actual)