Commit 7f51b60

feat(low-level-api): Improve API static type-safety and performance (abetlen#1205)
1 parent 0f8aa4a commit 7f51b60

File tree

5 files changed: +858 −743 lines


llama_cpp/_internals.py

Lines changed: 8 additions & 7 deletions
@@ -108,7 +108,7 @@ def apply_lora_from_file(
             scale,
             path_base_model.encode("utf-8")
             if path_base_model is not None
-            else llama_cpp.c_char_p(0),
+            else ctypes.c_char_p(0),
             n_threads,
         )
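
The only change in this hunk is swapping a re-exported alias for the standard-library ctypes.c_char_p; ctypes.c_char_p(0) is the usual way to pass a NULL const char * when an optional string argument is omitted. A minimal, self-contained sketch of the same pattern (the encode_optional_path helper below is hypothetical, not part of llama-cpp-python):

import ctypes
from typing import Optional

def encode_optional_path(path: Optional[str]) -> ctypes.c_char_p:
    # Encode a Python string for a `const char *` parameter,
    # or pass a NULL pointer when no path is given.
    if path is not None:
        return ctypes.c_char_p(path.encode("utf-8"))
    return ctypes.c_char_p(0)

print(encode_optional_path("base.gguf").value)  # b'base.gguf'
print(encode_optional_path(None).value)         # None (NULL pointer)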
@@ -303,8 +303,8 @@ def decode(self, batch: "_LlamaBatch"):
         assert self.ctx is not None
         assert batch.batch is not None
         return_code = llama_cpp.llama_decode(
-            ctx=self.ctx,
-            batch=batch.batch,
+            self.ctx,
+            batch.batch,
         )
         if return_code != 0:
             raise RuntimeError(f"llama_decode returned {return_code}")
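
This hunk switches the llama_decode call from keyword to positional arguments. Raw ctypes bindings are normally called positionally once argtypes/restype are declared (named parameters are a feature of prototypes built with paramflags), and positional calls avoid keyword-processing overhead, which plausibly matches the commit's type-safety and performance goal. A small illustration against the C runtime (the libc lookup and the abs example are assumptions, unrelated to llama.cpp):

import ctypes
import ctypes.util

# Load the C runtime; the lookup is platform dependent and may need adjusting.
libc = ctypes.CDLL(ctypes.util.find_library("c"))

# Declaring argtypes/restype lets ctypes check each positional argument at call time.
libc.abs.argtypes = [ctypes.c_int]
libc.abs.restype = ctypes.c_int

print(libc.abs(-7))  # 7 -- positional call, analogous to llama_decode(self.ctx, batch.batch)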
@@ -493,7 +493,7 @@ class _LlamaBatch:
     def __init__(
         self, *, n_tokens: int, embd: int, n_seq_max: int, verbose: bool = True
     ):
-        self.n_tokens = n_tokens
+        self._n_tokens = n_tokens
         self.embd = embd
         self.n_seq_max = n_seq_max
         self.verbose = verbose
@@ -502,7 +502,7 @@ def __init__(

         self.batch = None
         self.batch = llama_cpp.llama_batch_init(
-            self.n_tokens, self.embd, self.n_seq_max
+            self._n_tokens, self.embd, self.n_seq_max
         )

     def __del__(self):
@@ -570,12 +570,13 @@ def copy_logits(self, logits: npt.NDArray[np.single]):
         self.candidates.data = self.candidates_data.ctypes.data_as(
             llama_cpp.llama_token_data_p
         )
-        self.candidates.sorted = llama_cpp.c_bool(False)
-        self.candidates.size = llama_cpp.c_size_t(self.n_vocab)
+        self.candidates.sorted = ctypes.c_bool(False)
+        self.candidates.size = ctypes.c_size_t(self.n_vocab)


 # Python wrappers over common/common
 def _tokenize(model: _LlamaModel, text: str, add_bos: bool, special: bool) -> list[int]:
+    assert model.model is not None
     n_tokens = len(text) + 1 if add_bos else len(text)
     result = (llama_cpp.llama_token * n_tokens)()
     n_tokens = llama_cpp.llama_tokenize(
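
For reference, the copy_logits part of this hunk combines two standard techniques: taking a raw C pointer into a NumPy buffer with ndarray.ctypes.data_as(...), and filling scalar struct fields with ctypes.c_bool / ctypes.c_size_t (the added assert in _tokenize narrows an Optional attribute for static type checkers). The sketch below reproduces the pattern with made-up TokenData / TokenDataArray structures standing in for llama.cpp's llama_token_data types:

import ctypes
import numpy as np

class TokenData(ctypes.Structure):
    # Hypothetical stand-in for llama_token_data: (id, logit, p).
    _fields_ = [("id", ctypes.c_int32), ("logit", ctypes.c_float), ("p", ctypes.c_float)]

class TokenDataArray(ctypes.Structure):
    # Hypothetical stand-in for llama_token_data_array.
    _fields_ = [
        ("data", ctypes.POINTER(TokenData)),
        ("size", ctypes.c_size_t),
        ("sorted", ctypes.c_bool),
    ]

n_vocab = 4
# NumPy owns the backing memory; the structured dtype mirrors TokenData's layout.
buf = np.zeros(n_vocab, dtype=np.dtype([("id", np.int32), ("logit", np.single), ("p", np.single)]))

candidates = TokenDataArray()
candidates.data = buf.ctypes.data_as(ctypes.POINTER(TokenData))  # pointer into the NumPy buffer
candidates.size = ctypes.c_size_t(n_vocab)
candidates.sorted = ctypes.c_bool(False)

print(candidates.size, candidates.sorted)  # 4 False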

llama_cpp/llama.py

Lines changed: 1 addition & 1 deletion
@@ -1818,7 +1818,7 @@ def load_state(self, state: LlamaState) -> None:
         self.input_ids = state.input_ids.copy()
         self.n_tokens = state.n_tokens
         state_size = state.llama_state_size
-        LLamaStateArrayType = llama_cpp.c_uint8 * state_size
+        LLamaStateArrayType = ctypes.c_uint8 * state_size
         llama_state = LLamaStateArrayType.from_buffer_copy(state.llama_state)

         if llama_cpp.llama_set_state_data(self._ctx.ctx, llama_state) != state_size:
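
In the + line above, multiplying a ctypes type by a runtime length produces an array type, and from_buffer_copy copies a bytes-like object into a fresh instance of it, giving llama_set_state_data a C-compatible buffer. A minimal, self-contained sketch (the four-byte blob is made up):

import ctypes

state_bytes = b"\x01\x02\x03\x04"                    # stand-in for a saved llama state blob
StateArrayType = ctypes.c_uint8 * len(state_bytes)   # array type sized at runtime
llama_state = StateArrayType.from_buffer_copy(state_bytes)  # copy the bytes into ctypes-owned memory

print(list(llama_state))           # [1, 2, 3, 4]
print(ctypes.sizeof(llama_state))  # 4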

0 commit comments
