@@ -253,9 +253,11 @@ def export(self, out_dir):
253
253
254
254
inp_stats = self .collect_inputs_stats ()
255
255
torch .save (inp_stats , out_dir / 'inputs_stats.pth' )
256
+ torch .cuda .empty_cache ()
256
257
257
258
out_stats = self .collect_outputs_stats ()
258
259
torch .save (out_stats , out_dir / 'outputs_stats.pth' )
260
+ torch .cuda .empty_cache ()
259
261
260
262
def calibrate (self , data ):
261
263
"""Forward pass through the model in inference mode with given data."""
@@ -267,6 +269,7 @@ def calibrate(self, data):
267
269
model = self .model .model
268
270
with torch .inference_mode ():
269
271
_ = model (data .to (self .device ))
272
+ torch .cuda .empty_cache ()
270
273
271
274
def __enter__ (self ):
272
275
"""Prepares the Calibration object for a 'with' statement by
@@ -440,6 +443,7 @@ def export(self, out_dir):
440
443
inputs_stats ['absmean' ][name ] = obs .absmean_val
441
444
inputs_stats ['ratios' ][name ] = obs .ratio
442
445
torch .save (inputs_stats , out_dir / 'inputs_stats.pth' )
446
+ torch .cuda .empty_cache ()
443
447
444
448
def _wrap_decoder_layers_for_search (self ):
445
449
"""Method to wrap the decoder layers' forward functions for observing
0 commit comments