diff --git a/doc/users/next_whats_new/allow_changing_the_vertical_axis_in_3d_plots.rst b/doc/users/next_whats_new/allow_changing_the_vertical_axis_in_3d_plots.rst
new file mode 100644
index 000000000000..fead9f359f56
--- /dev/null
+++ b/doc/users/next_whats_new/allow_changing_the_vertical_axis_in_3d_plots.rst
@@ -0,0 +1,5 @@
+Allow changing the vertical axis in 3d plots
+----------------------------------------------
+
+`~mpl_toolkits.mplot3d.axes3d.Axes3D.view_init` now has the parameter
+*vertical_axis* which allows switching which axis is aligned vertically.
diff --git a/lib/mpl_toolkits/mplot3d/axes3d.py b/lib/mpl_toolkits/mplot3d/axes3d.py
index ed97b9658560..1f0f56fa780d 100644
--- a/lib/mpl_toolkits/mplot3d/axes3d.py
+++ b/lib/mpl_toolkits/mplot3d/axes3d.py
@@ -1045,17 +1045,24 @@ def clabel(self, *args, **kwargs):
         """Currently not implemented for 3D axes, and returns *None*."""
         return None
 
-    def view_init(self, elev=None, azim=None):
+    def view_init(self, elev=None, azim=None, vertical_axis="z"):
         """
         Set the elevation and azimuth of the axes in degrees (not radians).
 
         This can be used to rotate the axes programmatically.
 
-        'elev' stores the elevation angle in the z plane (in degrees).
-        'azim' stores the azimuth angle in the (x, y) plane (in degrees).
-
-        if 'elev' or 'azim' are None (default), then the initial value
-        is used which was specified in the :class:`Axes3D` constructor.
+        Parameters
+        ----------
+        elev : float, default: None
+            The elevation angle in the vertical plane in degrees.
+            If None then the initial value as specified in the `Axes3D`
+            constructor is used.
+        azim : float, default: None
+            The azimuth angle in the horizontal plane in degrees.
+            If None then the initial value as specified in the `Axes3D`
+            constructor is used.
+        vertical_axis : {"z", "x", "y"}, default: "z"
+            The axis to align vertically. *azim* rotates about this axis.
         """
 
         self.dist = 10
@@ -1070,6 +1077,10 @@ def view_init(self, elev=None, azim=None):
         else:
             self.azim = azim
 
+        self._vertical_axis = _api.check_getitem(
+            dict(x=0, y=1, z=2), vertical_axis=vertical_axis
+        )
+
     def set_proj_type(self, proj_type):
         """
         Set the projection type.
@@ -1083,47 +1094,60 @@ def set_proj_type(self, proj_type):
             'ortho': proj3d.ortho_transformation,
         }, proj_type=proj_type)
 
+    def _roll_to_vertical(self, arr):
+        """Roll arrays to match the different vertical axis."""
+        return np.roll(arr, self._vertical_axis - 2)
+
     def get_proj(self):
         """Create the projection matrix from the current viewing position."""
-        # elev stores the elevation angle in the z plane
-        # azim stores the azimuth angle in the x,y plane
-        #
-        # dist is the distance of the eye viewing point from the object
-        # point.
 
-        relev, razim = np.pi * self.elev/180, np.pi * self.azim/180
-
-        xmin, xmax = self.get_xlim3d()
-        ymin, ymax = self.get_ylim3d()
-        zmin, zmax = self.get_zlim3d()
-
-        # transform to uniform world coordinates 0-1, 0-1, 0-1
-        worldM = proj3d.world_transformation(xmin, xmax,
-                                             ymin, ymax,
-                                             zmin, zmax,
-                                             pb_aspect=self._box_aspect)
-
-        # look into the middle of the new coordinates
-        R = self._box_aspect / 2
+        # Transform to uniform world coordinates 0-1, 0-1, 0-1
+        box_aspect = self._roll_to_vertical(self._box_aspect)
+        worldM = proj3d.world_transformation(
+            *self.get_xlim3d(),
+            *self.get_ylim3d(),
+            *self.get_zlim3d(),
+            pb_aspect=box_aspect,
+        )
 
-        xp = R[0] + np.cos(razim) * np.cos(relev) * self.dist
-        yp = R[1] + np.sin(razim) * np.cos(relev) * self.dist
-        zp = R[2] + np.sin(relev) * self.dist
-        E = np.array((xp, yp, zp))
+        # Look into the middle of the new coordinates:
+        R = 0.5 * box_aspect
 
-        self.eye = E
-        self.vvec = R - E
+        # elev stores the elevation angle in the z plane
+        # azim stores the azimuth angle in the x,y plane
+        elev_rad = np.deg2rad(self.elev)
+        azim_rad = np.deg2rad(self.azim)
+
+        # Coordinates for a point that rotates around the box of data.
+        # p0, p1 corresponds to rotating the box only around the
+        # vertical axis.
+        # p2 corresponds to rotating the box only around the horizontal
+        # axis.
+        p0 = np.cos(elev_rad) * np.cos(azim_rad)
+        p1 = np.cos(elev_rad) * np.sin(azim_rad)
+        p2 = np.sin(elev_rad)
+
+        # When changing vertical axis the coordinates changes as well.
+        # Roll the values to get the same behaviour as the default:
+        ps = self._roll_to_vertical([p0, p1, p2])
+
+        # The coordinates for the eye viewing point. The eye is looking
+        # towards the middle of the box of data from a distance:
+        eye = R + self.dist * ps
+
+        # TODO: Is this being used somewhere? Can it be removed?
+        self.eye = eye
+        self.vvec = R - eye
         self.vvec = self.vvec / np.linalg.norm(self.vvec)
 
-        if abs(relev) > np.pi/2:
-            # upside down
-            V = np.array((0, 0, -1))
-        else:
-            V = np.array((0, 0, 1))
-        zfront, zback = -self.dist, self.dist
+        # Define which axis should be vertical. A negative value
+        # indicates the plot is upside down and therefore the values
+        # have been reversed:
+        V = np.zeros(3)
+        V[self._vertical_axis] = -1 if abs(elev_rad) > 0.5 * np.pi else 1
 
-        viewM = proj3d.view_transformation(E, R, V)
-        projM = self._projection(zfront, zback)
+        viewM = proj3d.view_transformation(eye, R, V)
+        projM = self._projection(-self.dist, self.dist)
         M0 = np.dot(viewM, worldM)
         M = np.dot(projM, M0)
         return M
diff --git a/lib/mpl_toolkits/mplot3d/axis3d.py b/lib/mpl_toolkits/mplot3d/axis3d.py
index c9d1f2801f3b..86d8e3b183a1 100644
--- a/lib/mpl_toolkits/mplot3d/axis3d.py
+++ b/lib/mpl_toolkits/mplot3d/axis3d.py
@@ -206,6 +206,45 @@ def _get_coord_info(self, renderer):
 
         return mins, maxs, centers, deltas, bounds_proj, highs
 
+    def _get_axis_line_edge_points(self, minmax, maxmin):
+        """Get the edge points for the black bolded axis line."""
+        # When changing vertical axis some of the axes has to be
+        # moved to the other plane so it looks the same as if the z-axis
+        # was the vertical axis.
+        mb = [minmax, maxmin]
+        mb_rev = mb[::-1]
+        mm = [[mb, mb_rev, mb_rev], [mb_rev, mb_rev, mb], [mb, mb, mb]]
+        mm = mm[self.axes._vertical_axis][self._axinfo["i"]]
+
+        juggled = self._axinfo["juggled"]
+        edge_point_0 = mm[0].copy()
+        edge_point_0[juggled[0]] = mm[1][juggled[0]]
+
+        edge_point_1 = edge_point_0.copy()
+        edge_point_1[juggled[1]] = mm[1][juggled[1]]
+
+        return edge_point_0, edge_point_1
+
+    def _get_tickdir(self):
+        """
+        Get the direction of the tick.
+
+        Returns
+        -------
+        tickdir : int
+            Index which indicates which coordinate the tick line will
+            align with.
+        """
+        # TODO: Move somewhere else where it's triggered less:
+        tickdirs_base = [v["tickdir"] for v in self._AXINFO.values()]
+        info_i = [v["i"] for v in self._AXINFO.values()]
+
+        i = self._axinfo["i"]
+        j = self.axes._vertical_axis - 2
+        # tickdir = [[1, 2, 1], [2, 2, 0], [1, 0, 0]][i]
+        tickdir = np.roll(info_i, -j)[np.roll(tickdirs_base, j)][i]
+        return tickdir
+
     def draw_pane(self, renderer):
         renderer.open_group('pane3d', gid=self.get_gid())
 
@@ -226,29 +265,27 @@ def draw_pane(self, renderer):
     @artist.allow_rasterization
     def draw(self, renderer):
         self.label._transform = self.axes.transData
-        renderer.open_group('axis3d', gid=self.get_gid())
+        renderer.open_group("axis3d", gid=self.get_gid())
 
         ticks = self._update_ticks()
 
+        # Get general axis information:
         info = self._axinfo
-        index = info['i']
+        index = info["i"]
+        juggled = info["juggled"]
 
         mins, maxs, centers, deltas, tc, highs = self._get_coord_info(renderer)
 
-        # Determine grid lines
         minmax = np.where(highs, maxs, mins)
-        maxmin = np.where(highs, mins, maxs)
+        maxmin = np.where(~highs, maxs, mins)
 
-        # Draw main axis line
-        juggled = info['juggled']
-        edgep1 = minmax.copy()
-        edgep1[juggled[0]] = maxmin[juggled[0]]
+        # Create edge points for the black bolded axis line:
+        edgep1, edgep2 = self._get_axis_line_edge_points(minmax, maxmin)
 
-        edgep2 = edgep1.copy()
-        edgep2[juggled[1]] = maxmin[juggled[1]]
-        pep = np.asarray(
-            proj3d.proj_trans_points([edgep1, edgep2], self.axes.M))
-        centpt = proj3d.proj_transform(*centers, self.axes.M)
+        # Project the edge points along the current position and
+        # create the line:
+        pep = proj3d.proj_trans_points([edgep1, edgep2], self.axes.M)
+        pep = np.asarray(pep)
         self.line.set_data(pep[0], pep[1])
         self.line.draw(renderer)
 
@@ -325,6 +362,7 @@ def draw(self, renderer):
         # Three-letters (e.g., TFT, FTT) are short-hand for the array of bools
         # from the variable 'highs'.
         # ---------------------------------------------------------------------
+        centpt = proj3d.proj_transform(*centers, self.axes.M)
         if centpt[info['tickdir']] > pep[info['tickdir'], outerindex]:
             # if FT and if highs has an even number of Trues
             if (centpt[index] <= pep[index, outerindex]
@@ -370,8 +408,8 @@ def draw(self, renderer):
             self.gridlines.do_3d_projection()
             self.gridlines.draw(renderer)
 
-        # Draw ticks
-        tickdir = info['tickdir']
+        # Draw ticks:
+        tickdir = self._get_tickdir()
         tickdelta = deltas[tickdir]
         if highs[tickdir]:
             ticksign = 1
diff --git a/lib/mpl_toolkits/tests/test_mplot3d.py b/lib/mpl_toolkits/tests/test_mplot3d.py
index 9717d8bf1c58..edb5c978bb49 100644
--- a/lib/mpl_toolkits/tests/test_mplot3d.py
+++ b/lib/mpl_toolkits/tests/test_mplot3d.py
@@ -1523,3 +1523,93 @@ def test_scatter_spiral():
 
     # force at least 1 draw!
     fig.canvas.draw()
+
+
+@pytest.mark.parametrize(
+    "vertical_axis, proj_expected, axis_lines_expected, tickdirs_expected",
+    [
+        (
+            "z",
+            [
+                [0.0, 1.142857, 0.0, -0.571429],
+                [0.0, 0.0, 0.857143, -0.428571],
+                [0.0, 0.0, 0.0, -10.0],
+                [-1.142857, 0.0, 0.0, 10.571429],
+            ],
+            [
+                ([0.05617978, 0.06329114], [-0.04213483, -0.04746835]),
+                ([-0.06329114, 0.06329114], [-0.04746835, -0.04746835]),
+                ([-0.06329114, -0.06329114], [-0.04746835, 0.04746835]),
+            ],
+            [1, 0, 0],
+        ),
+        (
+            "y",
+            [
+                [1.142857, 0.0, 0.0, -0.571429],
+                [0.0, 0.857143, 0.0, -0.428571],
+                [0.0, 0.0, 0.0, -10.0],
+                [0.0, 0.0, -1.142857, 10.571429],
+            ],
+            [
+                ([0.06329114, -0.06329114], [-0.04746835, -0.04746835]),
+                ([-0.06329114, -0.06329114], [0.04746835, -0.04746835]),
+                ([0.05617978, 0.06329114], [-0.04213483, -0.04746835]),
+            ],
+            [2, 2, 0],
+        ),
+        (
+            "x",
+            [
+                [0.0, 0.0, 1.142857, -0.571429],
+                [0.857143, 0.0, 0.0, -0.428571],
+                [0.0, 0.0, 0.0, -10.0],
+                [0.0, -1.142857, 0.0, 10.571429],
+            ],
+            [
+                ([-0.06329114, -0.06329114], [-0.04746835, 0.04746835]),
+                ([0.06329114, 0.05617978], [-0.04746835, -0.04213483]),
+                ([0.06329114, -0.06329114], [-0.04746835, -0.04746835]),
+            ],
+            [1, 2, 1],
+        ),
+    ],
+)
+def test_view_init_vertical_axis(
+    vertical_axis, proj_expected, axis_lines_expected, tickdirs_expected
+):
+    """
+    Test the actual projection, axis lines and ticks matches expected values.
+
+    Parameters
+    ----------
+    vertical_axis : str
+        Axis to align vertically.
+    proj_expected : ndarray
+        Expected values from ax.get_proj().
+    axis_lines_expected : tuple of arrays
+        Edgepoints of the axis line. Expected values retrieved according
+        to ``ax.get_[xyz]axis().line.get_data()``.
+    tickdirs_expected : list of int
+        indexes indicating which axis to create a tick line along.
+    """
+    rtol = 2e-06
+    ax = plt.subplot(1, 1, 1, projection="3d")
+    ax.view_init(azim=0, elev=0, vertical_axis=vertical_axis)
+    ax.figure.canvas.draw()
+
+    # Assert the projection matrix:
+    proj_actual = ax.get_proj()
+    np.testing.assert_allclose(proj_expected, proj_actual, rtol=rtol)
+
+    for i, axis in enumerate([ax.get_xaxis(), ax.get_yaxis(), ax.get_zaxis()]):
+        # Assert black lines are correctly aligned:
+        axis_line_expected = axis_lines_expected[i]
+        axis_line_actual = axis.line.get_data()
+        np.testing.assert_allclose(axis_line_expected, axis_line_actual,
+                                   rtol=rtol)
+
+        # Assert ticks are correctly aligned:
+        tickdir_expected = tickdirs_expected[i]
+        tickdir_actual = axis._get_tickdir()
+        np.testing.assert_array_equal(tickdir_expected, tickdir_actual)