@@ -387,35 +387,45 @@ def decorator(func):
387
387
388
388
_ , result_dir = _image_directories (func )
389
389
390
- if len (inspect .signature (func ).parameters ) == 2 :
391
- # Free-standing function.
392
- @pytest .mark .parametrize ("ext" , extensions )
393
- def wrapper (ext ):
394
- fig_test = plt .figure ("test" )
395
- fig_ref = plt .figure ("reference" )
396
- func (fig_test , fig_ref )
397
- test_image_path = result_dir / (func .__name__ + "." + ext )
398
- ref_image_path = (
399
- result_dir / (func .__name__ + "-expected." + ext ))
400
- fig_test .savefig (test_image_path )
401
- fig_ref .savefig (ref_image_path )
402
- _raise_on_image_difference (
403
- ref_image_path , test_image_path , tol = tol )
404
-
405
- elif len (inspect .signature (func ).parameters ) == 3 :
406
- # Method.
407
- @pytest .mark .parametrize ("ext" , extensions )
408
- def wrapper (self , ext ):
409
- fig_test = plt .figure ("test" )
410
- fig_ref = plt .figure ("reference" )
411
- func (self , fig_test , fig_ref )
412
- test_image_path = result_dir / (func .__name__ + "." + ext )
413
- ref_image_path = (
414
- result_dir / (func .__name__ + "-expected." + ext ))
415
- fig_test .savefig (test_image_path )
416
- fig_ref .savefig (ref_image_path )
417
- _raise_on_image_difference (
418
- ref_image_path , test_image_path , tol = tol )
390
+ @pytest .mark .parametrize ("ext" , extensions )
391
+ def wrapper (* args , ext , ** kwargs ):
392
+ fig_test = plt .figure ("test" )
393
+ fig_ref = plt .figure ("reference" )
394
+ func (* args , fig_test = fig_test , fig_ref = fig_ref , ** kwargs )
395
+ test_image_path = result_dir / (func .__name__ + "." + ext )
396
+ ref_image_path = result_dir / (
397
+ func .__name__ + "-expected." + ext
398
+ )
399
+ fig_test .savefig (test_image_path )
400
+ fig_ref .savefig (ref_image_path )
401
+ _raise_on_image_difference (
402
+ ref_image_path , test_image_path , tol = tol
403
+ )
404
+
405
+ sig = inspect .signature (func )
406
+ new_sig = sig .replace (
407
+ parameters = (
408
+ [
409
+ param
410
+ for param in sig .parameters .values ()
411
+ if param .name not in {"fig_test" , "fig_ref" }
412
+ ]
413
+ + [
414
+ inspect .Parameter (
415
+ "ext" , inspect .Parameter .POSITIONAL_OR_KEYWORD
416
+ )
417
+ ]
418
+ )
419
+ )
420
+ wrapper .__signature__ = new_sig
421
+
422
+ # reach a bit into pytest internals to hoist the marks from
423
+ # our wrapped function
424
+ from _pytest .mark .structures import get_unpacked_marks
425
+
426
+ wrapper .pytestmark = get_unpacked_marks (
427
+ wrapper
428
+ ) + get_unpacked_marks (func )
419
429
420
430
return wrapper
421
431
0 commit comments