@@ -250,6 +250,12 @@ def __init__(self, baseline_images, extensions, tol,
250
250
self .savefig_kwargs = savefig_kwargs
251
251
self .style = style
252
252
253
+ def delayed_init (self , func ):
254
+ assert callable (func ), "func must be callable"
255
+ assert self .func is None , "it looks like same decorator used twice"
256
+ self .func = func
257
+ self .baseline_dir , self .result_dir = _image_directories (func )
258
+
253
259
def setup (self ):
254
260
func = self .func
255
261
self .setup_class ()
@@ -275,25 +281,26 @@ def copy_baseline(self, baseline, extension):
275
281
orig_expected_fname = baseline_path + '.pdf'
276
282
expected_fname = make_test_filename (os .path .join (
277
283
self .result_dir , os .path .basename (orig_expected_fname )), 'expected' )
278
- actual_fname = os .path .join (self .result_dir , baseline ) + '.' + extension
279
284
if os .path .exists (orig_expected_fname ):
280
285
shutil .copyfile (orig_expected_fname , expected_fname )
281
286
else :
282
287
xfail ("Do not have baseline image {0} because this "
283
288
"file does not exist: {1}" .format (expected_fname ,
284
289
orig_expected_fname ))
285
- return expected_fname , actual_fname
290
+ return expected_fname
286
291
287
292
def compare (self , idx , baseline , extension ):
288
293
__tracebackhide__ = True
289
- if self .baseline_dir is None :
290
- self .baseline_dir , self .result_dir = _image_directories (self .func )
291
- expected_fname , actual_fname = self .copy_baseline (baseline , extension )
292
294
fignum = plt .get_fignums ()[idx ]
293
295
fig = plt .figure (fignum )
296
+
294
297
if self .remove_text :
295
298
remove_ticks_and_titles (fig )
299
+
300
+ actual_fname = os .path .join (self .result_dir , baseline ) + '.' + extension
296
301
fig .savefig (actual_fname , ** self .savefig_kwargs )
302
+
303
+ expected_fname = self .copy_baseline (baseline , extension )
297
304
raise_on_image_difference (expected_fname , actual_fname , self .tol )
298
305
299
306
def nose_runner (self ):
@@ -324,7 +331,7 @@ def wrapper(idx, baseline, extension):
324
331
return wrapper
325
332
326
333
def __call__ (self , func ):
327
- self .func = func
334
+ self .delayed_init ( func )
328
335
if is_called_from_pytest ():
329
336
return copy_metadata (func , self .pytest_runner ())
330
337
else :
0 commit comments