diff --git a/lib/matplotlib/collections.py b/lib/matplotlib/collections.py index 2e1d75c6d411..83e1e582abfd 100644 --- a/lib/matplotlib/collections.py +++ b/lib/matplotlib/collections.py @@ -2195,11 +2195,45 @@ def draw(self, renderer): renderer.close_group(self.__class__.__name__) self.stale = False + def contains(self, event): + # docstring inherited + x, y = event.xdata, event.ydata + + p = self._coordinates + p_a = p[:-1, :-1, :] + p_b = p[:-1, 1:, :] + p_c = p[1:, 1:, :] + p_d = p[1:, :-1, :] + # (y - y0) (x1 - x0) - (x - x0) (y1 - y0) + def side_of_line(x, y, p0, p1): + """ + Return the side of the line the point (x, y) is on + left: >0 + on: 0 + right: <0 + """ + return ((y - p0[..., 1]) * (p1[..., 0] - p0[..., 0]) + - (x - p0[..., 0]) * (p1[..., 1] - p0[..., 1])) + + # Winding number, can handle concave polys + # Algorithm from Dan Sunday + # https://web.archive.org/web/20130126163405/ + # http://geomalgorithms.com/a03-_inclusion.html + winding_number = np.zeros(p_a.shape[:-1]) + for (p0, p1) in zip([p_a, p_b, p_c, p_d], [p_b, p_c, p_d, p_a]): + winding_number += ((p0[..., 1] <= y) + & (p1[..., 1] > y) # upward crossing + & (side_of_line(x, y, p0, p1) > 0)) + + winding_number -= ((p0[..., 1] > y) + & (p1[..., 1] <= y) # downward crossing + & (side_of_line(x, y, p0, p1) < 0)) + + ind = np.nonzero((winding_number != 0).ravel())[0] + return len(ind) > 0, dict(ind=ind) + def get_cursor_data(self, event): contained, info = self.contains(event) - if len(info["ind"]) == 1: - ind, = info["ind"] - array = self.get_array() - return array[ind] if array else None - else: - return None + if contained and len(self.get_array()): + return self.get_array()[info["ind"]] + return None diff --git a/lib/matplotlib/tests/test_collections.py b/lib/matplotlib/tests/test_collections.py index e4b0c234e4b5..c8a9aeb4f5df 100644 --- a/lib/matplotlib/tests/test_collections.py +++ b/lib/matplotlib/tests/test_collections.py @@ -12,7 +12,8 @@ import matplotlib.path as mpath import matplotlib.transforms as mtransforms from matplotlib.collections import (Collection, LineCollection, - EventCollection, PolyCollection) + EventCollection, PolyCollection, + QuadMesh) from matplotlib.testing.decorators import check_figures_equal, image_comparison from matplotlib._api.deprecation import MatplotlibDeprecationWarning @@ -483,6 +484,44 @@ def test_picking(): assert_array_equal(indices['ind'], [0]) +def test_quadmesh_contains(): + n = 4 + x = np.arange(n) + X = x[:, None] * x[None, :] + + fig, ax = plt.subplots() + mesh = ax.pcolormesh(X) + mouse_event = SimpleNamespace(xdata=0, ydata=0) + found, indices = mesh.contains(mouse_event) + assert found + assert_array_equal(indices['ind'], [0]) + + mouse_event = SimpleNamespace(xdata=1.5, ydata=1.5) + found, indices = mesh.contains(mouse_event) + assert found + assert_array_equal(indices['ind'], [5]) + + # Test a concave polygon too, V-like shape + x = [[0, -1], [1, 0]] + y = [[0, 1], [1, -1]] + mesh = ax.pcolormesh(x, y, [[0]]) + points = [(-0.5, 0.25, True), # left wing + (0, 0.25, False), # between the two wings + (0.5, 0.25, True), # right wing + (0, -0.25, True), # main body + ] + for point in points: + x, y, expected = point + mouse_event = SimpleNamespace(xdata=x, ydata=y) + found, indices = mesh.contains(mouse_event) + assert found == expected + + # Smoke test an empty array, get_array() == None + coll = QuadMesh(np.ones((3, 3, 2))) + found, indices = coll.contains(mouse_event) + assert not found + + def test_linestyle_single_dashes(): plt.scatter([0, 1, 2], [0, 1, 2], linestyle=(0., [2., 2.])) plt.draw() @@ -749,8 +788,6 @@ def test_quadmesh_deprecated_signature( fig_test, fig_ref, flat_ref, kwargs): # test that the new and old quadmesh signature produce the same results # remove when the old QuadMesh.__init__ signature expires (v3.5+2) - from matplotlib.collections import QuadMesh - x = [0, 1, 2, 3.] y = [1, 2, 3.] X, Y = np.meshgrid(x, y)