18
18
from matplotlib import ft2font
19
19
from matplotlib import pyplot as plt
20
20
from matplotlib import ticker
21
-
22
- from .compare import comparable_formats , compare_images , make_test_filename
21
+ from .compare import compare_images , make_test_filename , _skip_if_uncomparable
23
22
from .exceptions import ImageComparisonFailure
24
23
25
24
@@ -137,54 +136,33 @@ def _raise_on_image_difference(expected, actual, tol):
137
136
% err )
138
137
139
138
140
- def _skip_if_format_is_uncomparable (extension ):
141
- import pytest
142
- return pytest .mark .skipif (
143
- extension not in comparable_formats (),
144
- reason = 'Cannot compare {} files on this system' .format (extension ))
145
-
146
-
147
- def _mark_skip_if_format_is_uncomparable (extension ):
148
- import pytest
149
- if isinstance (extension , str ):
150
- name = extension
151
- marks = []
152
- elif isinstance (extension , tuple ):
153
- # Extension might be a pytest ParameterSet instead of a plain string.
154
- # Unfortunately, this type is not exposed, so since it's a namedtuple,
155
- # check for a tuple instead.
156
- name , = extension .values
157
- marks = [* extension .marks ]
158
- else :
159
- # Extension might be a pytest marker instead of a plain string.
160
- name , = extension .args
161
- marks = [extension .mark ]
162
- return pytest .param (name ,
163
- marks = [* marks , _skip_if_format_is_uncomparable (name )])
164
-
165
-
166
- class _ImageComparisonBase :
139
+ def _make_image_comparator (func = None ,
140
+ baseline_images = None , * , extension = None , tol = 0 ,
141
+ remove_text = False , savefig_kwargs = None ):
167
142
"""
168
- Image comparison base class
143
+ Image comparison base helper.
169
144
170
- This class provides *just* the comparison-related functionality and avoids
145
+ This helper provides *just* the comparison-related functionality and avoids
171
146
any code that would be specific to any testing framework.
172
147
"""
148
+ if func is None :
149
+ return functools .partial (
150
+ _make_image_comparator ,
151
+ baseline_images = baseline_images , extension = extension , tol = tol ,
152
+ remove_text = remove_text , savefig_kwargs = savefig_kwargs )
173
153
174
- def __init__ (self , func , tol , remove_text , savefig_kwargs ):
175
- self .func = func
176
- self .baseline_dir , self .result_dir = _image_directories (func )
177
- self .tol = tol
178
- self .remove_text = remove_text
179
- self .savefig_kwargs = savefig_kwargs
154
+ if savefig_kwargs is None :
155
+ savefig_kwargs = {}
180
156
181
- def copy_baseline (self , baseline , extension ):
182
- baseline_path = self .baseline_dir / baseline
157
+ baseline_dir , result_dir = _image_directories (func )
158
+
159
+ def _copy_baseline (baseline ):
160
+ baseline_path = baseline_dir / baseline
183
161
orig_expected_path = baseline_path .with_suffix (f'.{ extension } ' )
184
162
if extension == 'eps' and not orig_expected_path .exists ():
185
163
orig_expected_path = orig_expected_path .with_suffix ('.pdf' )
186
164
expected_fname = make_test_filename (
187
- self . result_dir / orig_expected_path .name , 'expected' )
165
+ result_dir / orig_expected_path .name , 'expected' )
188
166
try :
189
167
# os.symlink errors if the target already exists.
190
168
with contextlib .suppress (OSError ):
@@ -200,24 +178,33 @@ def copy_baseline(self, baseline, extension):
200
178
f"{ orig_expected_path } " ) from err
201
179
return expected_fname
202
180
203
- def compare (self , idx , baseline , extension ):
181
+ @functools .wraps (func )
182
+ def wrapper (* args , ** kwargs ):
204
183
__tracebackhide__ = True
205
- fignum = plt .get_fignums ()[idx ]
206
- fig = plt .figure (fignum )
207
-
208
- if self .remove_text :
209
- remove_ticks_and_titles (fig )
210
-
211
- actual_path = (self .result_dir / baseline ).with_suffix (f'.{ extension } ' )
212
- kwargs = self .savefig_kwargs .copy ()
213
- if extension == 'pdf' :
214
- kwargs .setdefault ('metadata' ,
215
- {'Creator' : None , 'Producer' : None ,
216
- 'CreationDate' : None })
217
- fig .savefig (actual_path , ** kwargs )
218
-
219
- expected_path = self .copy_baseline (baseline , extension )
220
- _raise_on_image_difference (expected_path , actual_path , self .tol )
184
+ _skip_if_uncomparable (extension )
185
+
186
+ func (* args , ** kwargs )
187
+
188
+ fignums = plt .get_fignums ()
189
+ assert len (fignums ) == len (baseline_images ), (
190
+ "Test generated {} images but there are {} baseline images"
191
+ .format (len (fignums ), len (baseline_images )))
192
+ for baseline_image , fignum in zip (baseline_images , fignums ):
193
+ fig = plt .figure (fignum )
194
+ if remove_text :
195
+ remove_ticks_and_titles (fig )
196
+ actual_path = ((result_dir / baseline_image )
197
+ .with_suffix (f'.{ extension } ' ))
198
+ kwargs = savefig_kwargs .copy ()
199
+ if extension == 'pdf' :
200
+ kwargs .setdefault ('metadata' ,
201
+ {'Creator' : None , 'Producer' : None ,
202
+ 'CreationDate' : None })
203
+ fig .savefig (actual_path , ** kwargs )
204
+ expected_path = _copy_baseline (baseline_image )
205
+ _raise_on_image_difference (expected_path , actual_path , tol )
206
+
207
+ return wrapper
221
208
222
209
223
210
def _pytest_image_comparison (baseline_images , extensions , tol ,
@@ -233,8 +220,6 @@ def _pytest_image_comparison(baseline_images, extensions, tol,
233
220
"""
234
221
import pytest
235
222
236
- extensions = map (_mark_skip_if_format_is_uncomparable , extensions )
237
-
238
223
def decorator (func ):
239
224
@functools .wraps (func )
240
225
# Parameter indirection; see docstring above and comment below.
@@ -247,23 +232,19 @@ def decorator(func):
247
232
@functools .wraps (func )
248
233
def wrapper (* args , ** kwargs ):
249
234
__tracebackhide__ = True
250
- img = _ImageComparisonBase (func , tol = tol , remove_text = remove_text ,
251
- savefig_kwargs = savefig_kwargs )
252
- matplotlib .testing .set_font_settings_for_testing ()
253
- func (* args , ** kwargs )
254
-
255
235
# Parameter indirection:
256
236
# This is hacked on via the mpl_image_comparison_parameters fixture
257
237
# so that we don't need to modify the function's real signature for
258
238
# any parametrization. Modifying the signature is very very tricky
259
239
# and likely to confuse pytest.
260
240
baseline_images , extension = func .parameters
261
241
262
- assert len (plt .get_fignums ()) == len (baseline_images ), (
263
- "Test generated {} images but there are {} baseline images"
264
- .format (len (plt .get_fignums ()), len (baseline_images )))
265
- for idx , baseline in enumerate (baseline_images ):
266
- img .compare (idx , baseline , extension )
242
+ matplotlib .testing .set_font_settings_for_testing ()
243
+ comparator = _make_image_comparator (
244
+ func ,
245
+ baseline_images = baseline_images , extension = extension , tol = tol ,
246
+ remove_text = remove_text , savefig_kwargs = savefig_kwargs )
247
+ comparator (* args , ** kwargs )
267
248
268
249
return wrapper
269
250
@@ -347,8 +328,6 @@ def image_comparison(baseline_images, extensions=None, tol=0,
347
328
if extensions is None :
348
329
# Default extensions to test, if not set via baseline_images.
349
330
extensions = ['png' , 'pdf' , 'svg' ]
350
- if savefig_kwarg is None :
351
- savefig_kwarg = dict () # default no kwargs to savefig
352
331
return _pytest_image_comparison (
353
332
baseline_images = baseline_images , extensions = extensions , tol = tol ,
354
333
freetype_version = freetype_version , remove_text = remove_text ,
@@ -400,6 +379,7 @@ def decorator(func):
400
379
@pytest .mark .parametrize ("ext" , extensions )
401
380
def wrapper (* args , ** kwargs ):
402
381
ext = kwargs ['ext' ]
382
+ _skip_if_uncomparable (ext )
403
383
if 'ext' not in old_sig .parameters :
404
384
kwargs .pop ('ext' )
405
385
request = kwargs ['request' ]
0 commit comments