diff --git a/galleries/examples/lines_bars_and_markers/multicolored_line.py b/galleries/examples/lines_bars_and_markers/multicolored_line.py index 8c72d28e9e67..3a71225d0112 100644 --- a/galleries/examples/lines_bars_and_markers/multicolored_line.py +++ b/galleries/examples/lines_bars_and_markers/multicolored_line.py @@ -21,7 +21,7 @@ from matplotlib.collections import LineCollection -def colored_line(x, y, c, ax, **lc_kwargs): +def colored_line(x, y, c, ax=None, **lc_kwargs): """ Plot a line with a color specified along the line by a third value. @@ -36,8 +36,8 @@ def colored_line(x, y, c, ax, **lc_kwargs): The horizontal and vertical coordinates of the data points. c : array-like The color values, which should be the same size as x and y. - ax : Axes - Axis object on which to plot the colored line. + ax : matplotlib.axes.Axes, optional + The axes to plot on. If not provided, the current axes will be used. **lc_kwargs Any additional arguments to pass to matplotlib.collections.LineCollection constructor. This should not include the array keyword argument because @@ -49,36 +49,32 @@ def colored_line(x, y, c, ax, **lc_kwargs): The generated line collection representing the colored line. """ if "array" in lc_kwargs: - warnings.warn('The provided "array" keyword argument will be overridden') + warnings.warn( + 'The provided "array" keyword argument will be overridden', + UserWarning, + stacklevel=2, + ) + + xy = np.stack((x, y), axis=-1) + xy_mid = np.concat( + (xy[0, :][None, :], (xy[:-1, :] + xy[1:, :]) / 2, xy[-1, :][None, :]), axis=0 + ) + segments = np.stack((xy_mid[:-1, :], xy, xy_mid[1:, :]), axis=-2) + # Note that + # segments[0, :, :] is [xy[0, :], xy[0, :], (xy[0, :] + xy[1, :]) / 2] + # segments[i, :, :] is [(xy[i - 1, :] + xy[i, :]) / 2, xy[i, :], + # (xy[i, :] + xy[i + 1, :]) / 2] if i not in {0, len(x) - 1} + # segments[-1, :, :] is [(xy[-2, :] + xy[-1, :]) / 2, xy[-1, :], xy[-1, :]] + + lc_kwargs["array"] = c + lc = LineCollection(segments, **lc_kwargs) - # Default the capstyle to butt so that the line segments smoothly line up - default_kwargs = {"capstyle": "butt"} - default_kwargs.update(lc_kwargs) - - # Compute the midpoints of the line segments. Include the first and last points - # twice so we don't need any special syntax later to handle them. - x = np.asarray(x) - y = np.asarray(y) - x_midpts = np.hstack((x[0], 0.5 * (x[1:] + x[:-1]), x[-1])) - y_midpts = np.hstack((y[0], 0.5 * (y[1:] + y[:-1]), y[-1])) - - # Determine the start, middle, and end coordinate pair of each line segment. - # Use the reshape to add an extra dimension so each pair of points is in its - # own list. Then concatenate them to create: - # [ - # [(x1_start, y1_start), (x1_mid, y1_mid), (x1_end, y1_end)], - # [(x2_start, y2_start), (x2_mid, y2_mid), (x2_end, y2_end)], - # ... - # ] - coord_start = np.column_stack((x_midpts[:-1], y_midpts[:-1]))[:, np.newaxis, :] - coord_mid = np.column_stack((x, y))[:, np.newaxis, :] - coord_end = np.column_stack((x_midpts[1:], y_midpts[1:]))[:, np.newaxis, :] - segments = np.concatenate((coord_start, coord_mid, coord_end), axis=1) - - lc = LineCollection(segments, **default_kwargs) - lc.set_array(c) # set the colors of each segment + # Plot the line collection to the axes + ax = ax or plt.gca() + ax.add_collection(lc) + ax.autoscale_view() - return ax.add_collection(lc) + return lc # -------------- Create and show plot -------------- @@ -93,11 +89,6 @@ def colored_line(x, y, c, ax, **lc_kwargs): lines = colored_line(x, y, color, ax1, linewidth=10, cmap="plasma") fig1.colorbar(lines) # add a color legend -# Set the axis limits and tick positions -ax1.set_xlim(-1, 1) -ax1.set_ylim(-1, 1) -ax1.set_xticks((-1, 0, 1)) -ax1.set_yticks((-1, 0, 1)) ax1.set_title("Color at each point") plt.show()