Call cuda empty_cache to prevent OOM when quantizing model (#2671) · InternLM/lmdeploy@28c8b79 · GitHub
[go: up one dir, main page]

Skip to content

Commit 28c8b79

Browse files
AllentDan and lvhan028
authored and committed
Call cuda empty_cache to prevent OOM when quantizing model (#2671)
* Call cuda empty_cache to prevent OOM when quantizing model

* Empty cache during export and after forward
1 parent 5ea819f commit 28c8b79

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

lmdeploy/lite/quantization/calibration.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,9 +253,11 @@ def export(self, out_dir):
253253

254254
inp_stats = self.collect_inputs_stats()
255255
torch.save(inp_stats, out_dir / 'inputs_stats.pth')
256+
torch.cuda.empty_cache()
256257

257258
out_stats = self.collect_outputs_stats()
258259
torch.save(out_stats, out_dir / 'outputs_stats.pth')
260+
torch.cuda.empty_cache()
259261

260262
def calibrate(self, data):
261263
"""Forward pass through the model in inference mode with given data."""
@@ -267,6 +269,7 @@ def calibrate(self, data):
267269
model = self.model.model
268270
with torch.inference_mode():
269271
_ = model(data.to(self.device))
272+
torch.cuda.empty_cache()
270273

271274
def __enter__(self):
272275
"""Prepares the Calibration object for a 'with' statement by
@@ -440,6 +443,7 @@ def export(self, out_dir):
440443
inputs_stats['absmean'][name] = obs.absmean_val
441444
inputs_stats['ratios'][name] = obs.ratio
442445
torch.save(inputs_stats, out_dir / 'inputs_stats.pth')
446+
torch.cuda.empty_cache()
443447

444448
def _wrap_decoder_layers_for_search(self):
445449
"""Method to wrap the decoder layers' forward functions for observing

0 commit comments

Comments (0)