8000 MNT/TST: generalize check_figures_equal to work with pytest.marks · matplotlib/matplotlib@c0462a5 · GitHub
[go: up one dir, main page]

Skip to content

Commit c0462a5

Browse files
committed
MNT/TST: generalize check_figures_equal to work with pytest.marks
Generalize (at the cost of reaching a little bit into pytest internals) the check_figures_equal decorator so it works as expected with `pytest.mark.paramatarize`
1 parent 74e9dc7 commit c0462a5

File tree

2 files changed

+53
-31
lines changed

2 files changed

+53
-31
lines changed

lib/matplotlib/testing/decorators.py

Lines changed: 39 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -387,35 +387,45 @@ def decorator(func):
387387

388388
_, result_dir = _image_directories(func)
389389

390-
if len(inspect.signature(func).parameters) == 2:
391-
# Free-standing function.
392-
@pytest.mark.parametrize("ext", extensions)
393-
def wrapper(ext):
394-
fig_test = plt.figure("test")
395-
fig_ref = plt.figure("reference")
396-
func(fig_test, fig_ref)
397-
test_image_path = result_dir / (func.__name__ + "." + ext)
398-
ref_image_path = (
399-
result_dir / (func.__name__ + "-expected." + ext))
400-
fig_test.savefig(test_image_path)
401-
fig_ref.savefig(ref_image_path)
402-
_raise_on_image_difference(
403-
ref_image_path, test_image_path, tol=tol)
404-
405-
elif len(inspect.signature(func).parameters) == 3:
406-
# Method.
407-
@pytest.mark.parametrize("ext", extensions)
408-
def wrapper(self, ext):
409-
fig_test = plt.figure("test")
410-
fig_ref = plt.figure("reference")
411-
func(self, fig_test, fig_ref)
412-
test_image_path = result_dir / (func.__name__ + "." + ext)
413-
ref_image_path = (
414-
result_dir / (func.__name__ + "-expected." + ext))
415-
fig_test.savefig(test_image_path)
416-
fig_ref.savefig(ref_image_path)
417-
_raise_on_image_difference(
418-
ref_image_path, test_image_path, tol=tol)
390+
@pytest.mark.parametrize("ext", extensions)
391+
def wrapper(*args, ext, **kwargs):
392+
fig_test = plt.figure("test")
393+
fig_ref = plt.figure("reference")
394+
func(*args, fig_test=fig_test, fig_ref=fig_ref, **kwargs)
395+
test_image_path = result_dir / (func.__name__ + "." + ext)
396+
ref_image_path = result_dir / (
397+
func.__name__ + "-expected." + ext
398+
)
399+
fig_test.savefig(test_image_path)
400+
fig_ref.savefig(ref_image_path)
401+
_raise_on_image_difference(
402+
ref_image_path, test_image_path, tol=tol
403+
)
404+
405+
sig = inspect.signature(func)
406+
new_sig = sig.replace(
407+
parameters=(
408+
[
409+
param
410+
for param in sig.parameters.values()
411+
if param.name not in {"fig_test", "fig_ref"}
412+
]
413+
+ [
414+
inspect.Parameter(
415+
"ext", inspect.Parameter.POSITIONAL_OR_KEYWORD
416+
)
417+
]
418+
)
419+
)
420+
wrapper.__signature__ = new_sig
421+
422+
# reach a bit into pytest internals to hoist the marks from
423+
# our wrapped function
424+
from _pytest.mark.structures import get_unpacked_marks
425+
426+
wrapper.pytestmark = get_unpacked_marks(
427+
wrapper
428+
) + get_unpacked_marks(func)
419429

420430
return wrapper
421431

lib/matplotlib/tests/test_testing.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,19 @@
11
import warnings
22
import pytest
3+
from matplotlib.testing.decorators import check_figures_equal
34

4-
@pytest.mark.xfail(strict=True,
5-
reason="testing that warnings fail tests")
5+
6+
@pytest.mark.xfail(
7+
strict=True, reason="testing that warnings fail tests"
8+
)
69
def test_warn_to_fail():
710
warnings.warn("This should fail the test")
11+
12+
13+
@pytest.mark.parametrize("a", [1])
14+
@check_figures_equal(extensions=["png"])
15+
@pytest.mark.parametrize("b", [1])
16+
def test_paramatirize_with_check_figure_equal(
17+
a, fig_ref, b, fig_test
18+
):
19+
assert a == b

0 commit comments

Comments
 (0)
0