From 887eb62760c53624fff8d8f03355a86613ffbfb4 Mon Sep 17 00:00:00 2001 From: Greg Lucas Date: Sun, 1 May 2022 21:10:51 -0600 Subject: [PATCH] TST: Add some tests for QuadMesh contains function * Update QuadMesh.get_cursor_data to handle multiple contains hits * Test that an empty array doesn't return any cursor_data * Test a few points in a standard QuadMesh * Test points within and around a concave QuadMesh --- lib/matplotlib/collections.py | 12 +--- lib/matplotlib/tests/test_collections.py | 80 +++++++++++++++++++++++- 2 files changed, 80 insertions(+), 12 deletions(-) diff --git a/lib/matplotlib/collections.py b/lib/matplotlib/collections.py index 2fca1d84e65a..fdf560885b91 100644 --- a/lib/matplotlib/collections.py +++ b/lib/matplotlib/collections.py @@ -2197,12 +2197,6 @@ def draw(self, renderer): def get_cursor_data(self, event): contained, info = self.contains(event) - if len(info["ind"]) == 1: - ind, = info["ind"] - array = self.get_array() - if array is not None: - return array[ind] - else: - return None - else: - return None + if contained and self.get_array() is not None: + return self.get_array().ravel()[info["ind"]] + return None diff --git a/lib/matplotlib/tests/test_collections.py b/lib/matplotlib/tests/test_collections.py index e4b0c234e4b5..879e0f22b24b 100644 --- a/lib/matplotlib/tests/test_collections.py +++ b/lib/matplotlib/tests/test_collections.py @@ -12,7 +12,8 @@ import matplotlib.path as mpath import matplotlib.transforms as mtransforms from matplotlib.collections import (Collection, LineCollection, - EventCollection, PolyCollection) + EventCollection, PolyCollection, + QuadMesh) from matplotlib.testing.decorators import check_figures_equal, image_comparison from matplotlib._api.deprecation import MatplotlibDeprecationWarning @@ -483,6 +484,81 @@ def test_picking(): assert_array_equal(indices['ind'], [0]) +def test_quadmesh_contains(): + x = np.arange(4) + X = x[:, None] * x[None, :] + + fig, ax = plt.subplots() + mesh = ax.pcolormesh(X) + fig.draw_without_rendering() + xdata, ydata = 0.5, 0.5 + x, y = mesh.get_transform().transform((xdata, ydata)) + mouse_event = SimpleNamespace(xdata=xdata, ydata=ydata, x=x, y=y) + found, indices = mesh.contains(mouse_event) + assert found + assert_array_equal(indices['ind'], [0]) + + xdata, ydata = 1.5, 1.5 + x, y = mesh.get_transform().transform((xdata, ydata)) + mouse_event = SimpleNamespace(xdata=xdata, ydata=ydata, x=x, y=y) + found, indices = mesh.contains(mouse_event) + assert found + assert_array_equal(indices['ind'], [5]) + + +def test_quadmesh_contains_concave(): + # Test a concave polygon, V-like shape + x = [[0, -1], [1, 0]] + y = [[0, 1], [1, -1]] + fig, ax = plt.subplots() + mesh = ax.pcolormesh(x, y, [[0]]) + fig.draw_without_rendering() + # xdata, ydata, expected + points = [(-0.5, 0.25, True), # left wing + (0, 0.25, False), # between the two wings + (0.5, 0.25, True), # right wing + (0, -0.25, True), # main body + ] + for point in points: + xdata, ydata, expected = point + x, y = mesh.get_transform().transform((xdata, ydata)) + mouse_event = SimpleNamespace(xdata=xdata, ydata=ydata, x=x, y=y) + found, indices = mesh.contains(mouse_event) + assert found is expected + + +def test_quadmesh_cursor_data(): + x = np.arange(4) + X = x[:, None] * x[None, :] + + fig, ax = plt.subplots() + mesh = ax.pcolormesh(X) + # Empty array data + mesh._A = None + fig.draw_without_rendering() + xdata, ydata = 0.5, 0.5 + x, y = mesh.get_transform().transform((xdata, ydata)) + mouse_event = SimpleNamespace(xdata=xdata, ydata=ydata, x=x, y=y) + # Empty collection should return None + assert mesh.get_cursor_data(mouse_event) is None + + # Now test adding the array data, to make sure we do get a value + mesh.set_array(np.ones((X.shape))) + assert_array_equal(mesh.get_cursor_data(mouse_event), [1]) + + +def test_quadmesh_cursor_data_multiple_points(): + x = [1, 2, 1, 2] + fig, ax = plt.subplots() + mesh = ax.pcolormesh(x, x, np.ones((3, 3))) + fig.draw_without_rendering() + xdata, ydata = 1.5, 1.5 + x, y = mesh.get_transform().transform((xdata, ydata)) + mouse_event = SimpleNamespace(xdata=xdata, ydata=ydata, x=x, y=y) + # All quads are covering the same square + assert_array_equal(mesh.get_cursor_data(mouse_event), np.ones(9)) + + def test_linestyle_single_dashes(): plt.scatter([0, 1, 2], [0, 1, 2], linestyle=(0., [2., 2.])) plt.draw() @@ -749,8 +825,6 @@ def test_quadmesh_deprecated_signature( fig_test, fig_ref, flat_ref, kwargs): # test that the new and old quadmesh signature produce the same results # remove when the old QuadMesh.__init__ signature expires (v3.5+2) - from matplotlib.collections import QuadMesh - x = [0, 1, 2, 3.] y = [1, 2, 3.] X, Y = np.meshgrid(x, y)