Make Llama instance pickleable. Closes #27 · hrubanj/llama-cpp-python@e96a5c5 · GitHub

Commit e96a5c5

Make Llama instance pickleable. Closes abetlen#27
1 parent 152e469 commit e96a5c5

File tree

2 files changed: 56 additions & 0 deletions


llama_cpp/llama.py

Lines changed: 39 additions & 0 deletions
@@ -651,6 +651,45 @@ def __del__(self):
             llama_cpp.llama_free(self.ctx)
             self.ctx = None
 
+    def __getstate__(self):
+        return dict(
+            verbose=self.verbose,
+            model_path=self.model_path,
+            n_ctx=self.params.n_ctx,
+            n_parts=self.params.n_parts,
+            seed=self.params.seed,
+            f16_kv=self.params.f16_kv,
+            logits_all=self.params.logits_all,
+            vocab_only=self.params.vocab_only,
+            use_mlock=self.params.use_mlock,
+            embedding=self.params.embedding,
+            last_n_tokens_size=self.last_n_tokens_size,
+            last_n_tokens_data=self.last_n_tokens_data,
+            tokens_consumed=self.tokens_consumed,
+            n_batch=self.n_batch,
+            n_threads=self.n_threads,
+        )
+
+    def __setstate__(self, state):
+        self.__init__(
+            model_path=state["model_path"],
+            n_ctx=state["n_ctx"],
+            n_parts=state["n_parts"],
+            seed=state["seed"],
+            f16_kv=state["f16_kv"],
+            logits_all=state["logits_all"],
+            vocab_only=state["vocab_only"],
+            use_mlock=state["use_mlock"],
+            embedding=state["embedding"],
+            n_threads=state["n_threads"],
+            n_batch=state["n_batch"],
+            last_n_tokens_size=state["last_n_tokens_size"],
+            verbose=state["verbose"],
+        )
+        self.last_n_tokens_data=state["last_n_tokens_data"]
+        self.tokens_consumed=state["tokens_consumed"]
+
+
     @staticmethod
     def token_eos() -> llama_cpp.llama_token:
         """Return the end-of-sequence token."""

tests/test_llama.py

Lines changed: 17 additions & 0 deletions
@@ -77,3 +77,20 @@ def mock_sample(*args, **kwargs):
     chunks = llama.create_completion(text, max_tokens=2, stream=True)
     assert "".join(chunk["choices"][0]["text"] for chunk in chunks) == " j"
     assert completion["choices"][0]["finish_reason"] == "length"
+
+
+def test_llama_pickle():
+    import pickle
+    import tempfile
+    fp = tempfile.TemporaryFile()
+    llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True)
+    pickle.dump(llama, fp)
+    fp.seek(0)
+    llama = pickle.load(fp)
+
+    assert llama
+    assert llama.ctx is not None
+
+    text = b"Hello World"
+
+    assert llama.detokenize(llama.tokenize(text)) == text
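
The test exercises a file round trip. A practical consequence of pickleability (a sketch, not from the commit) is that a Llama instance can be handed to worker processes: multiprocessing's spawn start method pickles the arguments, and __setstate__ reloads the model from model_path inside each worker, so that path must also be valid there. The model path is again a placeholder.

# Sketch only: pass a Llama instance to a worker process via pickling.
# Assumes "models/ggml-model.bin" is readable from the worker as well.
import multiprocessing as mp

from llama_cpp import Llama


def count_tokens(llama: Llama, text: bytes) -> int:
    # The worker gets an unpickled copy; __setstate__ has already re-run
    # __init__ and reloaded the model from model_path.
    return len(llama.tokenize(text))


if __name__ == "__main__":
    llama = Llama(model_path="models/ggml-model.bin", vocab_only=True)
    ctx = mp.get_context("spawn")  # spawn pickles function arguments
    with ctx.Pool(1) as pool:
        print(pool.apply(count_tokens, (llama, b"Hello World")))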

0 commit comments
