diff --git a/doc/users/next_whats_new/3d_axis_positions.rst b/doc/users/next_whats_new/3d_axis_positions.rst new file mode 100644 index 000000000000..e4e09eb3afad --- /dev/null +++ b/doc/users/next_whats_new/3d_axis_positions.rst @@ -0,0 +1,20 @@ +Specify ticks and axis label positions for 3D plots +--------------------------------------------------- + +You can now specify the positions of ticks and axis labels for 3D plots. + +.. plot:: + :include-source: + + import matplotlib.pyplot as plt + + positions = ['lower', 'upper', 'default', 'both', 'none'] + fig, axs = plt.subplots(2, 3, figsize=(12, 8), + subplot_kw={'projection': '3d'}) + for ax, pos in zip(axs.flatten(), positions): + for axis in ax.xaxis, ax.yaxis, ax.zaxis: + axis.set_label_position(pos) + axis.set_ticks_position(pos) + title = f'position="{pos}"' + ax.set(xlabel='x', ylabel='y', zlabel='z', title=title) + axs[1, 2].axis('off') diff --git a/lib/mpl_toolkits/mplot3d/axis3d.py b/lib/mpl_toolkits/mplot3d/axis3d.py index 30f56c70f9f5..58792deae963 100644 --- a/lib/mpl_toolkits/mplot3d/axis3d.py +++ b/lib/mpl_toolkits/mplot3d/axis3d.py @@ -76,6 +76,9 @@ def __init__(self, *args, **kwargs): name = self.axis_name + self._label_position = 'default' + self._tick_position = 'default' + # This is a temporary member variable. # Do not depend on this existing in future releases! self._axinfo = self._AXINFO[name].copy() @@ -183,6 +186,64 @@ def get_minor_ticks(self, numticks=None): obj.set_transform(self.axes.transData) return ticks + def set_ticks_position(self, position): + """ + Set the ticks position. + + Parameters + ---------- + str : {'lower', 'upper', 'both', 'default', 'none'} + The position of the bolded axis lines, ticks, and tick labels. + """ + if position in ['top', 'bottom']: + _api.warn_deprecated('3.8', name=f'{position=}', + obj_type='argument value', + alternative="'upper' or 'lower'") + return + _api.check_in_list(['lower', 'upper', 'both', 'default', 'none'], + position=position) + self._tick_position = position + + def get_ticks_position(self): + """ + Get the ticks position. + + Returns + ------- + str : {'lower', 'upper', 'both', 'default', 'none'} + The position of the bolded axis lines, ticks, and tick labels. + """ + return self._tick_position + + def set_label_position(self, position): + """ + Set the label position. + + Parameters + ---------- + str : {'lower', 'upper', 'both', 'default', 'none'} + The position of the axis label. + """ + if position in ['top', 'bottom']: + _api.warn_deprecated('3.8', name=f'{position=}', + obj_type='argument value', + alternative="'upper' or 'lower'") + return + _api.check_in_list(['lower', 'upper', 'both', 'default', 'none'], + position=position) + self._label_position = position + + def get_label_position(self): + """ + Get the label position. + + Returns + ------- + str : {'lower', 'upper', 'both', 'default', 'none'} + The position of the axis label. + """ + return self._label_position + def set_pane_color(self, color, alpha=None): """ Set pane color. @@ -225,8 +286,7 @@ def _get_coord_info(self, renderer): # Get the mean value for each bound: centers = 0.5 * (maxs + mins) - # Add a small offset between min/max point and the edge of the - # plot: + # Add a small offset between min/max point and the edge of the plot: deltas = (maxs - mins) / 12 mins -= 0.25 * deltas maxs += 0.25 * deltas @@ -256,42 +316,96 @@ def _get_coord_info(self, renderer): return mins, maxs, centers, deltas, bounds_proj, highs - def _get_axis_line_edge_points(self, minmax, maxmin): + def _get_axis_line_edge_points(self, minmax, maxmin, position=None): """Get the edge points for the black bolded axis line.""" # When changing vertical axis some of the axes has to be # moved to the other plane so it looks the same as if the z-axis # was the vertical axis. - mb = [minmax, maxmin] + mb = [minmax, maxmin] # line from origin to nearest corner to camera mb_rev = mb[::-1] mm = [[mb, mb_rev, mb_rev], [mb_rev, mb_rev, mb], [mb, mb, mb]] mm = mm[self.axes._vertical_axis][self._axinfo["i"]] juggled = self._axinfo["juggled"] - edge_point_0 = mm[0].copy() - edge_point_0[juggled[0]] = mm[1][juggled[0]] + edge_point_0 = mm[0].copy() # origin point + + if ((position == 'lower' and mm[1][juggled[-1]] < mm[0][juggled[-1]]) or + (position == 'upper' and mm[1][juggled[-1]] > mm[0][juggled[-1]])): + edge_point_0[juggled[-1]] = mm[1][juggled[-1]] + else: + edge_point_0[juggled[0]] = mm[1][juggled[0]] edge_point_1 = edge_point_0.copy() edge_point_1[juggled[1]] = mm[1][juggled[1]] return edge_point_0, edge_point_1 - def _get_tickdir(self): + def _get_all_axis_line_edge_points(self, minmax, maxmin, axis_position=None): + # Determine edge points for the axis lines + edgep1s = [] + edgep2s = [] + position = [] + if axis_position in (None, 'default'): + edgep1, edgep2 = self._get_axis_line_edge_points(minmax, maxmin) + edgep1s = [edgep1] + edgep2s = [edgep2] + position = ['default'] + else: + edgep1_l, edgep2_l = self._get_axis_line_edge_points(minmax, maxmin, + position='lower') + edgep1_u, edgep2_u = self._get_axis_line_edge_points(minmax, maxmin, + position='upper') + if axis_position in ('lower', 'both'): + edgep1s.append(edgep1_l) + edgep2s.append(edgep2_l) + position.append('lower') + if axis_position in ('upper', 'both'): + edgep1s.append(edgep1_u) + edgep2s.append(edgep2_u) + position.append('upper') + return edgep1s, edgep2s, position + + def _get_tickdir(self, position): """ Get the direction of the tick. + Parameters + ---------- + position : str, optional : {'upper', 'lower', 'default'} + The position of the axis. + Returns ------- tickdir : int Index which indicates which coordinate the tick line will align with. """ + _api.check_in_list(('upper', 'lower', 'default'), position=position) + # TODO: Move somewhere else where it's triggered less: - tickdirs_base = [v["tickdir"] for v in self._AXINFO.values()] + tickdirs_base = [v["tickdir"] for v in self._AXINFO.values()] # default + elev_mod = np.mod(self.axes.elev + 180, 360) - 180 + azim_mod = np.mod(self.axes.azim, 360) + if position == 'upper': + if elev_mod >= 0: + tickdirs_base = [2, 2, 0] + else: + tickdirs_base = [1, 0, 0] + if 0 <= azim_mod < 180: + tickdirs_base[2] = 1 + elif position == 'lower': + if elev_mod >= 0: + tickdirs_base = [1, 0, 1] + else: + tickdirs_base = [2, 2, 1] + if 0 <= azim_mod < 180: + tickdirs_base[2] = 0 info_i = [v["i"] for v in self._AXINFO.values()] i = self._axinfo["i"] - j = self.axes._vertical_axis - 2 - # tickdir = [[1, 2, 1], [2, 2, 0], [1, 0, 0]][i] + vert_ax = self.axes._vertical_axis + j = vert_ax - 2 + # default: tickdir = [[1, 2, 1], [2, 2, 0], [1, 0, 0]][vert_ax][i] tickdir = np.roll(info_i, -j)[np.roll(tickdirs_base, j)][i] return tickdir @@ -322,69 +436,58 @@ def draw_pane(self, renderer): self.pane.draw(renderer) renderer.close_group('pane3d') - @artist.allow_rasterization - def draw(self, renderer): - self.label._transform = self.axes.transData - self.offsetText._transform = self.axes.transData - renderer.open_group("axis3d", gid=self.get_gid()) + def _axmask(self): + axmask = [True, True, True] + axmask[self._axinfo["i"]] = False + return axmask + def _draw_ticks(self, renderer, edgep1, centers, deltas, highs, + deltas_per_point, pos): ticks = self._update_ticks() - - # Get general axis information: info = self._axinfo index = info["i"] - juggled = info["juggled"] - - mins, maxs, centers, deltas, tc, highs = self._get_coord_info(renderer) - minmax = np.where(highs, maxs, mins) - maxmin = np.where(~highs, maxs, mins) + # Draw ticks: + tickdir = self._get_tickdir(pos) + tickdelta = deltas[tickdir] if highs[tickdir] else -deltas[tickdir] - # Create edge points for the black bolded axis line: - edgep1, edgep2 = self._get_axis_line_edge_points(minmax, maxmin) + tick_info = info['tick'] + tick_out = tick_info['outward_factor'] * tickdelta + tick_in = tick_info['inward_factor'] * tickdelta + tick_lw = tick_info['linewidth'] + edgep1_tickdir = edgep1[tickdir] + out_tickdir = edgep1_tickdir + tick_out + in_tickdir = edgep1_tickdir - tick_in - # Project the edge points along the current position and - # create the line: - pep = proj3d._proj_trans_points([edgep1, edgep2], self.axes.M) - pep = np.asarray(pep) - self.line.set_data(pep[0], pep[1]) - self.line.draw(renderer) + default_label_offset = 8. # A rough estimate + points = deltas_per_point * deltas + for tick in ticks: + # Get tick line positions + pos = edgep1.copy() + pos[index] = tick.get_loc() + pos[tickdir] = out_tickdir + x1, y1, z1 = proj3d.proj_transform(*pos, self.axes.M) + pos[tickdir] = in_tickdir + x2, y2, z2 = proj3d.proj_transform(*pos, self.axes.M) - # Draw labels - # The transAxes transform is used because the Text object - # rotates the text relative to the display coordinate system. - # Therefore, if we want the labels to remain parallel to the - # axis regardless of the aspect ratio, we need to convert the - # edge points of the plane to display coordinates and calculate - # an angle from that. - # TODO: Maybe Text objects should handle this themselves? - dx, dy = (self.axes.transAxes.transform([pep[0:2, 1]]) - - self.axes.transAxes.transform([pep[0:2, 0]]))[0] + # Get position of label + labeldeltas = (tick.get_pad() + default_label_offset) * points - lxyz = 0.5 * (edgep1 + edgep2) + pos[tickdir] = edgep1_tickdir + pos = _move_from_center(pos, centers, labeldeltas, self._axmask()) + lx, ly, lz = proj3d.proj_transform(*pos, self.axes.M) - # A rough estimate; points are ambiguous since 3D plots rotate - reltoinches = self.figure.dpi_scale_trans.inverted() - ax_inches = reltoinches.transform(self.axes.bbox.size) - ax_points_estimate = sum(72. * ax_inches) - deltas_per_point = 48 / ax_points_estimate - default_offset = 21. - labeldeltas = ( - (self.labelpad + default_offset) * deltas_per_point * deltas) - axmask = [True, True, True] - axmask[index] = False - lxyz = _move_from_center(lxyz, centers, labeldeltas, axmask) - tlx, tly, tlz = proj3d.proj_transform(*lxyz, self.axes.M) - self.label.set_position((tlx, tly)) - if self.get_rotate_label(self.label.get_text()): - angle = art3d._norm_text_angle(np.rad2deg(np.arctan2(dy, dx))) - self.label.set_rotation(angle) - self.label.set_va(info['label']['va']) - self.label.set_ha(info['label']['ha']) - self.label.set_rotation_mode(info['label']['rotation_mode']) - self.label.draw(renderer) + _tick_update_position(tick, (x1, x2), (y1, y2), (lx, ly)) + tick.tick1line.set_linewidth(tick_lw[tick._major]) + tick.draw(renderer) - # Draw Offset text + def _draw_offset_text(self, renderer, edgep1, edgep2, labeldeltas, centers, + highs, pep, dx, dy): + # Get general axis information: + info = self._axinfo + index = info["i"] + juggled = info["juggled"] + tickdir = info["tickdir"] # Which of the two edge points do we want to # use for locating the offset text? @@ -395,7 +498,8 @@ def draw(self, renderer): outeredgep = edgep2 outerindex = 1 - pos = _move_from_center(outeredgep, centers, labeldeltas, axmask) + pos = _move_from_center(outeredgep, centers, labeldeltas, + self._axmask()) olx, oly, olz = proj3d.proj_transform(*pos, self.axes.M) self.offsetText.set_text(self.major.formatter.get_offset()) self.offsetText.set_position((olx, oly)) @@ -414,14 +518,14 @@ def draw(self, renderer): # using the wrong reference points). # # (TT, FF, TF, FT) are the shorthand for the tuple of - # (centpt[info['tickdir']] <= pep[info['tickdir'], outerindex], + # (centpt[tickdir] <= pep[tickdir, outerindex], # centpt[index] <= pep[index, outerindex]) # # Three-letters (e.g., TFT, FTT) are short-hand for the array of bools # from the variable 'highs'. # --------------------------------------------------------------------- centpt = proj3d.proj_transform(*centers, self.axes.M) - if centpt[info['tickdir']] > pep[info['tickdir'], outerindex]: + if centpt[tickdir] > pep[tickdir, outerindex]: # if FT and if highs has an even number of Trues if (centpt[index] <= pep[index, outerindex] and np.count_nonzero(highs) % 2 == 0): @@ -448,39 +552,82 @@ def draw(self, renderer): self.offsetText.set_ha(align) self.offsetText.draw(renderer) - # Draw ticks: - tickdir = self._get_tickdir() - tickdelta = deltas[tickdir] if highs[tickdir] else -deltas[tickdir] + def _draw_labels(self, renderer, edgep1, edgep2, labeldeltas, centers, dx, dy): + label = self._axinfo["label"] - tick_info = info['tick'] - tick_out = tick_info['outward_factor'] * tickdelta - tick_in = tick_info['inward_factor'] * tickdelta - tick_lw = tick_info['linewidth'] - edgep1_tickdir = edgep1[tickdir] - out_tickdir = edgep1_tickdir + tick_out - in_tickdir = edgep1_tickdir - tick_in - - default_label_offset = 8. # A rough estimate - points = deltas_per_point * deltas - for tick in ticks: - # Get tick line positions - pos = edgep1.copy() - pos[index] = tick.get_loc() - pos[tickdir] = out_tickdir - x1, y1, z1 = proj3d.proj_transform(*pos, self.axes.M) - pos[tickdir] = in_tickdir - x2, y2, z2 = proj3d.proj_transform(*pos, self.axes.M) + # Draw labels + lxyz = 0.5 * (edgep1 + edgep2) + lxyz = _move_from_center(lxyz, centers, labeldeltas, self._axmask()) + tlx, tly, tlz = proj3d.proj_transform(*lxyz, self.axes.M) + self.label.set_position((tlx, tly)) + if self.get_rotate_label(self.label.get_text()): + angle = art3d._norm_text_angle(np.rad2deg(np.arctan2(dy, dx))) + self.label.set_rotation(angle) + self.label.set_va(label['va']) + self.label.set_ha(label['ha']) + self.label.set_rotation_mode(label['rotation_mode']) + self.label.draw(renderer) - # Get position of label - labeldeltas = (tick.get_pad() + default_label_offset) * points + @artist.allow_rasterization + def draw(self, renderer): + self.label._transform = self.axes.transData + self.offsetText._transform = self.axes.transData + renderer.open_group("axis3d", gid=self.get_gid()) - pos[tickdir] = edgep1_tickdir - pos = _move_from_center(pos, centers, labeldeltas, axmask) - lx, ly, lz = proj3d.proj_transform(*pos, self.axes.M) + # Get general axis information: + mins, maxs, centers, deltas, tc, highs = self._get_coord_info(renderer) - _tick_update_position(tick, (x1, x2), (y1, y2), (lx, ly)) - tick.tick1line.set_linewidth(tick_lw[tick._major]) - tick.draw(renderer) + # Calculate offset distances + # A rough estimate; points are ambiguous since 3D plots rotate + reltoinches = self.figure.dpi_scale_trans.inverted() + ax_inches = reltoinches.transform(self.axes.bbox.size) + ax_points_estimate = sum(72. * ax_inches) + deltas_per_point = 48 / ax_points_estimate + default_offset = 21. + labeldeltas = (self.labelpad + default_offset) * deltas_per_point * deltas + + # Determine edge points for the axis lines + minmax = np.where(highs, maxs, mins) # "origin" point + maxmin = np.where(~highs, maxs, mins) # "opposite" corner near camera + + for edgep1, edgep2, pos in zip(*self._get_all_axis_line_edge_points( + minmax, maxmin, self._tick_position)): + # Project the edge points along the current position + pep = proj3d._proj_trans_points([edgep1, edgep2], self.axes.M) + pep = np.asarray(pep) + + # The transAxes transform is used because the Text object + # rotates the text relative to the display coordinate system. + # Therefore, if we want the labels to remain parallel to the + # axis regardless of the aspect ratio, we need to convert the + # edge points of the plane to display coordinates and calculate + # an angle from that. + # TODO: Maybe Text objects should handle this themselves? + dx, dy = (self.axes.transAxes.transform([pep[0:2, 1]]) - + self.axes.transAxes.transform([pep[0:2, 0]]))[0] + + # Draw the lines + self.line.set_data(pep[0], pep[1]) + self.line.draw(renderer) + + # Draw ticks + self._draw_ticks(renderer, edgep1, centers, deltas, highs, + deltas_per_point, pos) + + # Draw Offset text + self._draw_offset_text(renderer, edgep1, edgep2, labeldeltas, + centers, highs, pep, dx, dy) + + for edgep1, edgep2, pos in zip(*self._get_all_axis_line_edge_points( + minmax, maxmin, self._label_position)): + # See comments above + pep = proj3d._proj_trans_points([edgep1, edgep2], self.axes.M) + pep = np.asarray(pep) + dx, dy = (self.axes.transAxes.transform([pep[0:2, 1]]) - + self.axes.transAxes.transform([pep[0:2, 0]]))[0] + + # Draw labels + self._draw_labels(renderer, edgep1, edgep2, labeldeltas, centers, dx, dy) renderer.close_group('axis3d') self.stale = False @@ -534,7 +681,7 @@ def get_tightbbox(self, renderer=None, *, for_layout_only=False): # (and hope they are up to date) because at draw time we # shift the ticks and their labels around in (x, y) space # based on the projection, the current view port, and their - # position in 3D space. If we extend the transforms framework + # position in 3D space. If we extend the transforms framework # into 3D we would not need to do this different book keeping # than we do in the normal axis major_locs = self.get_majorticklocs() diff --git a/lib/mpl_toolkits/mplot3d/tests/baseline_images/test_axes3d/axis_positions.png b/lib/mpl_toolkits/mplot3d/tests/baseline_images/test_axes3d/axis_positions.png new file mode 100644 index 000000000000..d4155479d213 Binary files /dev/null and b/lib/mpl_toolkits/mplot3d/tests/baseline_images/test_axes3d/axis_positions.png differ diff --git a/lib/mpl_toolkits/mplot3d/tests/test_axes3d.py b/lib/mpl_toolkits/mplot3d/tests/test_axes3d.py index 671ce68b1e93..c00a96b421c7 100644 --- a/lib/mpl_toolkits/mplot3d/tests/test_axes3d.py +++ b/lib/mpl_toolkits/mplot3d/tests/test_axes3d.py @@ -57,6 +57,18 @@ def test_invisible_ticks_axis(): axis.line.set_visible(False) +@mpl3d_image_comparison(['axis_positions.png'], remove_text=False, style='mpl20') +def test_axis_positions(): + positions = ['upper', 'lower', 'both', 'none'] + fig, axs = plt.subplots(2, 2, subplot_kw={'projection': '3d'}) + for ax, pos in zip(axs.flatten(), positions): + for axis in ax.xaxis, ax.yaxis, ax.zaxis: + axis.set_label_position(pos) + axis.set_ticks_position(pos) + title = f'{pos}' + ax.set(xlabel='x', ylabel='y', zlabel='z', title=title) + + @mpl3d_image_comparison(['aspects.png'], remove_text=False, style='mpl20') def test_aspects(): aspects = ('auto', 'equal', 'equalxy', 'equalyz', 'equalxz', 'equal') @@ -2156,7 +2168,7 @@ def test_view_init_vertical_axis( # Assert ticks are correctly aligned: tickdir_expected = tickdirs_expected[i] - tickdir_actual = axis._get_tickdir() + tickdir_actual = axis._get_tickdir('default') np.testing.assert_array_equal(tickdir_expected, tickdir_actual)