400 |
| - fig_test.savefig(test_image_path) | 401 |
| - fig_ref.savefig(ref_image_path) |
402 |
| - _raise_on_image_difference( |
403 |
| - ref_image_path, test_image_path, tol=tol) |
404 |
| - |
405 |
| - elif len(inspect.signature(func).parameters) == 3: |
406 |
| - # Method. |
407 |
| - @pytest.mark.parametrize("ext", extensions) |
408 |
| - def wrapper(self, ext): |
409 |
| - fig_test = plt.figure("test") |
410 |
| - fig_ref = plt.figure("reference") |
411 |
| - func(self, fig_test, fig_ref) |
412 |
| - test_image_path = result_dir / (func.__name__ + "." + ext) |
413 |
| - ref_image_path = ( |
414 |
| - result_dir / (func.__name__ + "-expected." + ext)) |
415 |
| - fig_test.savefig(test_image_path) |
416 |
| - fig_ref.savefig(ref_image_path) |
417 |
| - _raise_on_image_difference( |
418 |
| - ref_image_path, test_image_path, tol=tol) |
| 390 | + @pytest.mark.parametrize("ext", extensions) |
| 391 | + def wrapper(*args, ext, **kwargs): |
| 392 | + fig_test = plt.figure("test") |
| 393 | + fig_ref = plt.figure("reference") |
| 394 | + func(*args, fig_test=fig_test, fig_ref=fig_ref, **kwargs) |
| 395 | + test_image_path = result_dir / (func.__name__ + "." + ext) |
| 396 | + ref_image_path = result_dir / ( |
| 397 | + func.__name__ + "-expected." + ext |
| 398 | + ) |
| 399 | + fig_test.savefig(test_image_path) |
| 400 | + fig_ref.savefig(ref_image_path) |
| 401 | + _raise_on_image_difference( |
| 402 | + ref_image_path, test_image_path, tol=tol |
| 403 | + ) |
| 404 | + |
| 405 | + sig = inspect.signature(func) |
| 406 | + new_sig = sig.replace( |
| 407 | + parameters=([param |
| 408 | + for param in sig.parameters.values() |
| 409 | + if param.name not in {"fig_test", "fig_ref"}] |
| 410 | + + [inspect.Parameter("ext", POSITIONAL_OR_KEYWORD)]) |
| 411 | + ) |
| 412 | + wrapper.__signature__ = new_sig |
| 413 | + |
| 414 | + # reach a bit into pytest internals to hoist the marks from |
| 415 | + # our wrapped function |
| 416 | + new_marks = getattr(func, "pytestmark", []) + wrapper.pytestmark |
| 417 | + wrapper.pytestmark = new_marks |
419 | 418 |
|
420 | 419 | return wrapper
|
421 | 420 |
|
|
0 commit comments