8000 Use `pytest_pyfunc_call` hook to intercept figure · matplotlib/pytest-mpl@8aadc6f · GitHub
[go: up one dir, main page]

Skip to content

Commit 8aadc6f

Browse files
committed
Use pytest_pyfunc_call hook to intercept figure
Instead modifying the test function itself by wrapping it in a function which runs the tests, use pytest hooks to intercept the generated figure and then run the tests. This should be a more robust approach that doesn't need as many special cases to be hardcoded. I have also refactored the get_marker and get_compare functions/methods to simplify these.
1 parent cdb37c9 commit 8aadc6f

File tree

1 file changed

+103
-130
lines changed

1 file changed

+103
-130
lines changed

pytest_mpl/plugin.py

Lines changed: 103 additions & 130 deletions
Original file line numberDiff line numberDiff line change
@@ -33,13 +33,11 @@
3333
import json
3434
import shutil
3535
import hashlib
36-
import inspect
3736
import logging
3837
import tempfile
3938
import warnings
4039
import contextlib
4140
from pathlib import Path
42-
from functools import wraps
4341
from urllib.request import urlopen
4442

4543
import pytest
@@ -83,6 +81,14 @@ def pathify(path):
8381
return Path(path + ext)
8482

8583

84+
def _pytest_pyfunc_call(obj, pyfuncitem):
85+
testfunction = pyfuncitem.obj
86+
funcargs = pyfuncitem.funcargs
87+
testargs = {arg: funcargs[arg] for arg in pyfuncitem._fixtureinfo.argnames}
88+
obj.result = testfunction(**testargs)
89+
return True
90+
91+
8692
def pytest_report_header(config, startdir):
8793
import matplotlib
8894
import matplotlib.ft2font
@@ -211,13 +217,11 @@ def close_mpl_figure(fig):
211217
plt.close(fig)
212218

213219

214-
def get_marker(item, marker_name):
215-
if hasattr(item, 'get_closest_marker'):
216-
return item.get_closest_marker(marker_name)
217-
else:
218-
# "item.keywords.get" was deprecated in pytest 3.6
219-
# See https://docs.pytest.org/en/latest/mark.html#updating-code
220-
return item.keywords.get(marker_name)
220+
def get_compare(item):
221+
"""
222+
Return the mpl_image_compare marker for the given item.
223+
"""
224+
return item.get_closest_marker("mpl_image_compare")
221225

222226

223227
def path_is_not_none(apath):
@@ -278,20 +282,14 @@ def __init__(self,
278282
logging.basicConfig(level=level)
279283
self.logger = logging.getLogger('pytest-mpl')
280284

281-
def get_compare(self, item):
282-
"""
283-
Return the mpl_image_compare marker for the given item.
284-
"""
285-
return get_marker(item, 'mpl_image_compare')
286-
287285
def generate_filename(self, item):
288286
"""
289287
Given a pytest item, generate the figure filename.
290288
"""
291289
if self.config.getini('mpl-use-full-test-name'):
292290
filename = self.generate_test_name(item) + '.png'
293291
else:
294-
compare = self.get_compare(item)
292+
compare = get_compare(item)
295293
# Find test name to use as plot name
296294
filename = compare.kwargs.get('filename', None)
297295
if filename is None:
@@ -319,7 +317,7 @@ def baseline_directory_specified(self, item):
319317
"""
320318
Returns `True` if a non-default baseline directory is specified.
321319
"""
322-
compare = self.get_compare(item)
320+
compare = get_compare(item)
323321
item_baseline_dir = compare.kwargs.get('baseline_dir', None)
324322
return item_baseline_dir or self.baseline_dir or self.baseline_relative_dir
325323

@@ -330,7 +328,7 @@ def get_baseline_directory(self, item):
330328
Using the global and per-test configuration return the absolute
331329
baseline dir, if the baseline file is local else return base URL.
332330
"""
333-
compare = self.get_compare(item)
331+
compare = get_compare(item)
334332
baseline_dir = compare.kwargs.get('baseline_dir', None)
335333
if baseline_dir is None:
336334
if self.baseline_dir is None:
@@ -394,7 +392,7 @@ def generate_baseline_image(self, item, fig):
394392
"""
395393
Generate reference figures.
396394
"""
397-
compare = self.get_compare(item)
395+
compare = get_compare(item)
398396
savefig_kwargs = compare.kwargs.get('savefig_kwargs', {})
399397

400398
if not os.path.exists(self.generate_dir):
@@ -413,7 +411,7 @@ def generate_image_hash(self, item, fig):
413411
For a `matplotlib.figure.Figure`, returns the SHA256 hash as a hexadecimal
414412
string.
415413
"""
416-
compare = self.get_compare(item)
414+
compare = get_compare(item)
417415
savefig_kwargs = compare.kwargs.get('savefig_kwargs', {})
418416

419417
imgdata = io.BytesIO()
@@ -436,7 +434,7 @@ def compare_image_to_baseline(self, item, fig, result_dir, summary=None):
436434
if summary is None:
437435
summary = {}
438436

439-
compare = self.get_compare(item)
437+
compare = get_compare(item)
440438
tolerance = compare.kwargs.get('tolerance', 2)
441439
savefig_kwargs = compare.kwargs.get('savefig_kwargs', {})
442440

@@ -510,7 +508,7 @@ def compare_image_to_hash_library(self, item, fig, result_dir, summary=None):
510508
if summary is None:
511509
summary = {}
512510

513-
compare = self.get_compare(item)
511+
compare = get_compare(item)
514512
savefig_kwargs = compare.kwargs.get('savefig_kwargs', {})
515513

516514
if not self.results_hash_library_name:
@@ -582,11 +580,13 @@ def compare_image_to_hash_library(self, item, fig, result_dir, summary=None):
582580
return
583581
return summary['status_msg']
584582

585-
def pytest_runtest_setup(self, item): # noqa
583+
@pytest.hookimpl(hookwrapper=True)
584+
def pytest_runtest_call(self, item): # noqa
586585

587-
compare = self.get_compare(item)
586+
compare = get_compare(item)
588587

589588
if compare is None:
589+
yield
590590
return
591591

592592
import matplotlib.pyplot as plt
@@ -600,95 +600,82 @@ def pytest_runtest_setup(self, item): # noqa
600600
remove_text = compare.kwargs.get('remove_text', False)
601601
backend = compare.kwargs.get('backend', 'agg')
602602

603-
original = item.function
604-
605-
@wraps(item.function)
606-
def item_function_wrapper(*args, **kwargs):
607-
608-
with plt.style.context(style, after_reset=True), switch_backend(backend):
609-
610-
# Run test and get figure object
611-
if inspect.ismethod(original): # method
612-
# In some cases, for example if setup_method is used,
613-
# original appears to belong to an instance of the test
614-
# class that is not the same as args[0], and args[0] is the
615-
# one that has the correct attributes set up from setup_method
616-
# so we ignore original.__self__ and use args[0] instead.
617-
fig = original.__func__(*args, **kwargs)
618-
else: # function
619-
fig = original(*args, **kwargs)
620-
621-
if remove_text:
622-
remove_ticks_and_titles(fig)
623-
624-
test_name = self.generate_test_name(item)
625-
result_dir = self.make_test_results_dir(item)
626-
627-
summary = {
628-
'status': None,
629-
'image_status': None,
630-
'hash_status': None,
631-
'status_msg': None,
632-
'baseline_image': None,
633-
'diff_image': None,
634-
'rms': None,
635-
'tolerance': None,
636-
'result_image': None,
637-
'baseline_hash': None,
638-
'result_hash': None,
639-
}
640-
641-
# What we do now depends on whether we are generating the
642-
# reference images or simply running the test.
643-
if self.generate_dir is not None:
644-
summary['status'] = 'skipped'
645-
summary['image_status'] = 'generated'
646-
summary['status_msg'] = 'Skipped test, since generating image.'
647-
generate_image = self.generate_baseline_image(item, fig)
648-
if self.results_always: # Make baseline image available in HTML
649-
result_image = (result_dir / "baseline.png").absolute()
650-
shutil.copy(generate_image, result_image)
651-
summary['baseline_image'] = \
652-
result_image.relative_to(self.results_dir).as_posix()
653-
654-
if self.generate_hash_library is not None:
655-
summary['hash_status'] = 'generated'
656-
image_hash = self.generate_image_hash(item, fig)
657-
self._generated_hash_library[test_name] = image_hash
658-
summary['baseline_hash'] = image_hash
659-
660-
# Only test figures if not generating images
661-
if self.generate_dir is None:
662-
# Compare to hash library
663-
if self.hash_library or compare.kwargs.get('hash_library', None):
664-
msg = self.compare_image_to_hash_library(item, fig, result_dir, summary=summary)
665-
666-
# Compare against a baseline if specified
667-
else:
668-
msg = self.compare_image_to_baseline(item, fig, result_dir, summary=summary)
669-
670-
close_mpl_figure(fig)
671-
672-
if msg is None:
673-
if not self.results_always:
674-
shutil.rmtree(result_dir)
675-
for image_type in ['baseline_image', 'diff_image', 'result_image']:
676-
summary[image_type] = None # image no longer exists
677-
else:
678-
self._test_results[test_name] = summary
679-
pytest.fail(msg, pytrace=False)
603+
with plt.style.context(style, after_reset=True), switch_backend(backend):
604+
605+
# Run test and get figure object
606+
yield
607+
fig = self.result
608+
609+
if remove_text:
610+
remove_ticks_and_titles(fig)
611+
612+
test_name = self.generate_test_name(item)
613+
result_dir = self.make_test_results_dir(item)
614+
615+
summary = {
616+
'status': None,
617+
'image_status': None,
618+
'hash_status': None,
619+
'status_msg': None,
620+
'baseline_image': None,
621+
'diff_image': None,
622+
'rms': None,
623+
'tolerance': None,
624+
'result_image': None,
625+
'baseline_hash': None,
626+
'result_hash': None,
627+
}
628+
629+
# What we do now depends on whether we are generating the
630+
# reference images or simply running the test.
631+
if self.generate_dir is not None:
632+
summary['status'] = 'skipped'
633+
summary['image_status'] = 'generated'
634+
summary['status_msg'] = 'Skipped test, since generating image.'
635+
generate_image = self.generate_baseline_image(item, fig)
636+
if self.results_always: # Make baseline image available in HTML
637+
result_image = (result_dir / "baseline.png").absolute()
638+
shutil.copy(generate_image, result_image)
639+
summary['baseline_image'] = \
640+
result_image.relative_to(self.results_dir).as_posix()
641+
642+
if self.generate_hash_library is not None:
643+
summary['hash_status'] = 'generated'
644+
image_hash = self.generate_image_hash(item, fig)
645+
self._generated_hash_library[test_name] = image_hash
646+
summary['baseline_hash'] = image_hash
647+
648+
# Only test figures if not generating images
649+
if self.generate_dir is None:
650+
# Compare to hash library
651+
if self.hash_library or compare.kwargs.get('hash_library', None):
652+
msg = self.compare_image_to_hash_library(item, fig, result_dir, summary=summary)
653+
654+
# Compare against a baseline if specified
655+
else:
656+
msg = self.compare_image_to_baseline(item, fig, result_dir, summary=summary)
680657

681658
close_mpl_figure(fig)
682659

683-
self._test_results[test_name] = summary
660+
if msg is None:
661+
if not self.results_always:
662+
shutil.rmtree(result_dir)
663+
for image_type in ['baseline_image', 'diff_image', 'result_image']:
664+
summary[image_type] = None # image no longer exists
665+
else:
666+
self._test_results[test_name] = summary
667+
pytest.fail(msg, pytrace=False)
668+
669+
close_mpl_figure(fig)
684670

685-
if summary['status'] == 'skipped':
686-
pytest.skip(summary['status_msg'])
671+
self._test_results[test_name] = summary
687672

688-
if item.cls is not None:
689-
setattr(item.cls, item.function.__name__, item_function_wrapper)
690-
else:
691-
item.obj = item_function_wrapper
673+
if summary['status'] == 'skipped':
674+
pytest.skip(summary['status_msg'])
675+
676+
@pytest.hookimpl(tryfirst=True)
677+
def pytest_pyfunc_call(self, pyfuncitem):
678+
return _pytest_pyfunc_call(self, pyfuncitem)
692679

693680
def generate_summary_json(self):
694681
json_file = self.results_dir / 'results.json'
@@ -742,26 +729,12 @@ class FigureCloser:
742729
def __init__(self, config):
743730
self.config = config
744731

745-
def pytest_runtest_setup(self, item):
746-
747-
compare = get_marker(item, 'mpl_image_compare')
748-
749-
if compare is None:
750-
return
751-
752-
original = item.function
753-
754-
@wraps(item.function)
755-
def item_function_wrapper(*args, **kwargs):
756-
757-
if inspect.ismethod(original): # method
758-
fig = original.__func__(*args, **kwargs)
759-
else: # function
760-
fig = original(*args, **kwargs)
761-
762-
close_mpl_figure(fig)
732+
@pytest.hookimpl(hookwrapper=True)
733+
def pytest_runtest_call(self, item):
734+
yield
735+
if get_compare(item) is not None:
736+
close_mpl_figure(self.result)
763737

764-
if item.cls is not None:
765-
setattr(item.cls, item.function.__name__, item_function_wrapper)
766-
else:
767-
item.obj = item_function_wrapper
738+
@pytest.hookimpl(tryfirst=True)
739+
def pytest_pyfunc_call(self, pyfuncitem):
740+
return _pytest_pyfunc_call(self, pyfuncitem)

0 commit comments

Comments
 (0)
0