feat: Add `.close()` method to `Llama` class to explicitly free model from memory · mojowebs/llama-cpp-python@320a5d7 · GitHub

Commit 320a5d7

jkawamoto and abetlen authored
feat: Add .close() method to Llama class to explicitly free model from memory (abetlen#1513)
* feat: add explicit methods to free model

  This commit introduces a `close` method to both `Llama` and `_LlamaModel`, allowing users to explicitly free the model from RAM/VRAM.

  The previous implementation relied on the destructor of `_LlamaModel` to free the model. However, in Python the timing of destructor calls is unclear; for instance, the `del` statement does not guarantee immediate invocation of the destructor. This commit provides an explicit method to release the model, which works immediately and allows the user to load another model without memory issues.

  Additionally, this commit implements a context manager in the `Llama` class, enabling the automatic closure of the `Llama` object when used with the `with` statement.

* feat: Implement ContextManager in _LlamaModel, _LlamaContext, and _LlamaBatch

  This commit enables automatic resource management by implementing the `ContextManager` protocol in `_LlamaModel`, `_LlamaContext`, and `_LlamaBatch`. This ensures that resources are properly managed and released within a `with` statement, enhancing robustness and safety in resource handling.

* feat: add ExitStack for Llama's internal class closure

  This update implements ExitStack to manage and close internal classes in Llama, enhancing efficient and safe resource management.

* Use contextlib ExitStack and closing

* Explicitly free model when closing resources on server

---------

Co-authored-by: Andrei Betlen <abetlen@gmail.com>
1 parent dbcf64c commit 320a5d7
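
In practice the new API can be used either explicitly or, per the context-manager support described in the commit message, via `with`. A usage sketch; the model path and prompt are hypothetical:

```python
from llama_cpp import Llama

# Explicit release: frees the model from RAM/VRAM immediately,
# without waiting for garbage collection.
llm = Llama(model_path="./models/7B/llama-model.gguf")
llm.close()

# Context-manager form described in the commit message: close() is
# invoked automatically when the block exits.
with Llama(model_path="./models/7B/llama-model.gguf") as llm:
    llm("Q: Name the planets in the solar system. A: ", max_tokens=32)
```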

File tree

3 files changed: +50 −30 lines changed

llama_cpp/_internals.py

Lines changed: 32 additions & 24 deletions
```diff
@@ -9,6 +9,7 @@
     Sequence,
 )
 from dataclasses import dataclass, field
+from contextlib import ExitStack

 import numpy as np
 import numpy.typing as npt
@@ -27,9 +28,6 @@ class _LlamaModel:
     """Intermediate Python wrapper for a llama.cpp llama_model.
     NOTE: For stability it's recommended you use the Llama class instead."""

-    _llama_free_model = None
-    # NOTE: this must be "saved" here to avoid exceptions when calling __del__
-
     def __init__(
         self,
         *,
@@ -40,8 +38,7 @@ def __init__(
         self.path_model = path_model
         self.params = params
         self.verbose = verbose
-
-        self._llama_free_model = llama_cpp._lib.llama_free_model  # type: ignore
+        self._exit_stack = ExitStack()

         self.model = None

@@ -56,11 +53,17 @@ def __init__(
         if self.model is None:
             raise ValueError(f"Failed to load model from file: {path_model}")

-    def __del__(self):
-        if self.model is not None and self._llama_free_model is not None:
-            self._llama_free_model(self.model)
+        def free_model():
+            if self.model is None:
+                return
+            llama_cpp.llama_free_model(self.model)
             self.model = None

+        self._exit_stack.callback(free_model)
+
+    def close(self):
+        self._exit_stack.close()
+
     def vocab_type(self) -> int:
         assert self.model is not None
         return llama_cpp.llama_vocab_type(self.model)
@@ -257,8 +260,6 @@ class _LlamaContext:
     """Intermediate Python wrapper for a llama.cpp llama_context.
     NOTE: For stability it's recommended you use the Llama class instead."""

-    _llama_free = None
-
     def __init__(
         self,
         *,
@@ -269,24 +270,28 @@ def __init__(
         self.model = model
         self.params = params
         self.verbose = verbose
+        self._exit_stack = ExitStack()

-        self._llama_free = llama_cpp._lib.llama_free  # type: ignore
         self.ctx = None

         assert self.model.model is not None

-        self.ctx = llama_cpp.llama_new_context_with_model(
-            self.model.model, self.params
-        )
+        self.ctx = llama_cpp.llama_new_context_with_model(self.model.model, self.params)

         if self.ctx is None:
             raise ValueError("Failed to create llama_context")

-    def __del__(self):
-        if self.ctx is not None and self._llama_free is not None:
-            self._llama_free(self.ctx)
+        def free_ctx():
+            if self.ctx is None:
+                return
+            llama_cpp.llama_free(self.ctx)
             self.ctx = None

+        self._exit_stack.callback(free_ctx)
+
+    def close(self):
+        self._exit_stack.close()
+
     def n_ctx(self) -> int:
         assert self.ctx is not None
         return llama_cpp.llama_n_ctx(self.ctx)
@@ -501,28 +506,31 @@ def default_params():


 class _LlamaBatch:
-    _llama_batch_free = None
-
     def __init__(
         self, *, n_tokens: int, embd: int, n_seq_max: int, verbose: bool = True
     ):
         self._n_tokens = n_tokens
         self.embd = embd
         self.n_seq_max = n_seq_max
         self.verbose = verbose
-
-        self._llama_batch_free = llama_cpp._lib.llama_batch_free  # type: ignore
+        self._exit_stack = ExitStack()

         self.batch = None
         self.batch = llama_cpp.llama_batch_init(
             self._n_tokens, self.embd, self.n_seq_max
         )

-    def __del__(self):
-        if self.batch is not None and self._llama_batch_free is not None:
-            self._llama_batch_free(self.batch)
+        def free_batch():
+            if self.batch is None:
+                return
+            llama_cpp.llama_batch_free(self.batch)
             self.batch = None

+        self._exit_stack.callback(free_batch)
+
+    def close(self):
+        self._exit_stack.close()
+
     def n_tokens(self) -> int:
         assert self.batch is not None
         return self.batch.n_tokens
```
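
The pattern above replaces `__del__`-based cleanup with an `ExitStack` whose registered callback frees the native handle exactly once; unlike `__del__`, the callback does not depend on module globals still being alive at interpreter shutdown, which is why the `_llama_free_model = None` class-attribute workaround can be deleted. A minimal, self-contained sketch of the same idea, where `NativeHandle` and `fake_free` are hypothetical stand-ins rather than llama-cpp-python APIs:

```python
from contextlib import ExitStack


def fake_free(handle: int) -> None:
    # Hypothetical stand-in for a C-level free function such as llama_free_model.
    print(f"freed native handle {handle}")


class NativeHandle:
    """Illustrative resource wrapper using the commit's ExitStack pattern."""

    def __init__(self, handle: int):
        self._exit_stack = ExitStack()
        self.handle = handle

        def free_handle():
            if self.handle is None:
                return
            fake_free(self.handle)
            self.handle = None

        # Registered callbacks run when the stack is closed.
        self._exit_stack.callback(free_handle)

    def close(self):
        # Idempotent: ExitStack.close() runs each callback at most once.
        self._exit_stack.close()


res = NativeHandle(42)
res.close()  # prints "freed native handle 42"
res.close()  # no-op; the callback already ran
```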

llama_cpp/llama.py

Lines changed: 15 additions & 6 deletions
```diff
@@ -9,7 +9,9 @@
 import typing
 import fnmatch
 import warnings
+import contextlib
 import multiprocessing
+from types import TracebackType

 from typing import (
     List,
@@ -21,6 +23,7 @@
     Deque,
     Callable,
     Dict,
+    Type,
 )
 from collections import deque
 from pathlib import Path
@@ -350,9 +353,11 @@ def __init__(
         if not os.path.exists(model_path):
             raise ValueError(f"Model path does not exist: {model_path}")

-        self._model = _LlamaModel(
+        self._stack = contextlib.ExitStack()
+
+        self._model = self._stack.enter_context(contextlib.closing(_LlamaModel(
             path_model=self.model_path, params=self.model_params, verbose=self.verbose
-        )
+        )))

         # Override tokenizer
         self.tokenizer_ = tokenizer or LlamaTokenizer(self)
@@ -364,18 +369,18 @@ def __init__(
             self.context_params.n_ctx = self._model.n_ctx_train()
             self.context_params.n_batch = self.n_batch

-        self._ctx = _LlamaContext(
+        self._ctx = self._stack.enter_context(contextlib.closing(_LlamaContext(
             model=self._model,
             params=self.context_params,
             verbose=self.verbose,
-        )
+        )))

-        self._batch = _LlamaBatch(
+        self._batch = self._stack.enter_context(contextlib.closing(_LlamaBatch(
             n_tokens=self.n_batch,
             embd=0,
             n_seq_max=self.context_params.n_ctx,
             verbose=self.verbose,
-        )
+        )))

         if self.lora_path:
             if self._model.apply_lora_from_file(
@@ -1959,6 +1964,10 @@ def pooling_type(self) -> str:
         """Return the pooling type."""
         return self._ctx.pooling_type()

+    def close(self) -> None:
+        """Explicitly free the model from memory."""
+        self._stack.close()
+
     @staticmethod
     def logits_to_logprobs(
         logits: Union[npt.NDArray[np.single], List], axis: int = -1
```
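
`Llama.__init__` pushes each internal wrapper onto a single `contextlib.ExitStack` via `contextlib.closing`, so `Llama.close()` tears them down in reverse creation order: batch, then context, then model. The LIFO order matters because the batch and context reference the model, so they must be released before the model's memory is freed. A minimal sketch of that composition, using a hypothetical `Resource` class in place of the real wrappers:

```python
import contextlib


class Resource:
    # Hypothetical stand-in for _LlamaModel / _LlamaContext / _LlamaBatch.
    def __init__(self, name: str):
        self.name = name

    def close(self):
        print(f"closed {self.name}")


stack = contextlib.ExitStack()

# contextlib.closing() adapts any object with a close() method into a
# context manager; enter_context() registers it on the stack and
# returns the wrapped object.
model = stack.enter_context(contextlib.closing(Resource("model")))
ctx = stack.enter_context(contextlib.closing(Resource("context")))
batch = stack.enter_context(contextlib.closing(Resource("batch")))

# Closing the stack calls close() on each resource in LIFO order:
stack.close()  # closed batch, closed context, closed model
```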

llama_cpp/server/model.py

Lines changed: 3 additions & 0 deletions
```diff
@@ -44,6 +44,8 @@ def __call__(self, model: Optional[str] = None) -> llama_cpp.Llama:
             if self._current_model is not None:
                 return self._current_model

+        if self._current_model:
+            self._current_model.close()
         self._current_model = None

         settings = self._model_settings_dict[model]
@@ -65,6 +67,7 @@ def __iter__(self):

     def free(self):
         if self._current_model:
+            self._current_model.close()
             del self._current_model

     @staticmethod
```
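
The server now closes the current `Llama` instance before dropping its reference, both when switching models in `__call__` and in `free()`. Without the explicit `close()`, `del` merely removes a name binding and leaves deallocation to the garbage collector, so a large model could still occupy RAM/VRAM while its replacement loads. A sketch of the swap this enables; the model paths are hypothetical:

```python
from llama_cpp import Llama

llm = Llama(model_path="./models/model-a.gguf")
# ... serve requests with model A ...

llm.close()  # model A's RAM/VRAM is released immediately
llm = Llama(model_path="./models/model-b.gguf")  # safe to load model B now
```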

0 commit comments