8000 misc: use typesafe byref for internal classes · coderonion/llama-cpp-python@b9aca61 · GitHub
[go: up one dir, main page]

Skip to content

Commit b9aca61

Browse files
committed
misc: use typesafe byref for internal classes
1 parent a0ce429 commit b9aca61

File tree

1 file changed

+20
-20
lines changed

1 file changed

+20
-20
lines changed

llama_cpp/_internals.py

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def rope_freq_scale_train(self) -> float:
8282
def desc(self) -> str:
8383
assert self.model is not None
8484
buf = ctypes.create_string_buffer(1024)
85-
llama_cpp.llama_model_desc(self.model, buf, 1024) # type: ignore
85+
llama_cpp.llama_model_desc(self.model, buf, 1024)
8686
return buf.value.decode("utf-8")
8787

8888
def size(self) -> int:
@@ -184,7 +184,7 @@ def tokenize(self, text: bytes, add_bos: bool, special: bool):
184184
def token_to_piece(self, token: int) -> bytes:
185185
assert self.model is not None
186186
buf = ctypes.create_string_buffer(32)
187-
llama_cpp.llama_token_to_piece(self.model, token, buf, 32) # type: ignore
187+
llama_cpp.llama_token_to_piece(self.model, token, buf, 32)
188188
return bytes(buf)
189189

190190
def detokenize(self, tokens: List[int]) -> bytes:
@@ -349,7 +349,7 @@ def sample_repetition_penalties(
349349
assert self.ctx is not None
350350
llama_cpp.llama_sample_repetition_penalties(
351351
self.ctx,
352-
ctypes.byref(candidates.candidates), # type: ignore
352+
llama_cpp.byref(candidates.candidates),
353353
last_tokens_data,
354354
penalty_last_n,
355355
penalty_repeat,
@@ -367,7 +367,7 @@ def sample_classifier_free_guidance(
367367
assert guidance_ctx.ctx is not None
368368
llama_cpp.llama_sample_classifier_free_guidance(
369369
self.ctx,
370-
ctypes.byref(candidates.candidates), # type: ignore
370+
llama_cpp.byref(candidates.candidates),
371371
guidance_ctx.ctx,
372372
scale,
373373
)
@@ -376,55 +376,55 @@ def sample_softmax(self, candidates: "_LlamaTokenDataArray"):
376376
assert self.ctx is not None
377377
llama_cpp.llama_sample_softmax(
378378
self.ctx,
379-
ctypes.byref(candidates.candidates), # type: ignore
379+
llama_cpp.byref(candidates.candidates),
380380
)
381381

382382
def sample_top_k(self, candidates: "_LlamaTokenDataArray", k: int, min_keep: int):
383383
assert self.ctx is not None
384384
llama_cpp.llama_sample_top_k(
385-
self.ctx, ctypes.byref(candidates.candidates), k, min_keep # type: ignore
385+
self.ctx, llama_cpp.byref(candidates.candidates), k, min_keep
386386
)
387387

388388
def sample_top_p(self, candidates: "_LlamaTokenDataArray", p: float, min_keep: int):
389389
assert self.ctx is not None
390390
llama_cpp.llama_sample_top_p(
391-
self.ctx, ctypes.byref(candidates.candidates), p, min_keep # type: ignore
391+
self.ctx, llama_cpp.byref(candidates.candidates), p, min_keep
392392
)
393393

394394
def sample_min_p(self, candidates: "_LlamaTokenDataArray", p: float, min_keep: int):
395395
assert self.ctx is not None
396396
llama_cpp.llama_sample_min_p(
397-
self.ctx, ctypes.byref(candidates.candidates), p, min_keep # type: ignore
397+
self.ctx, llama_cpp.byref(candidates.candidates), p, min_keep
398398
)
399399

400400
def sample_tail_free(
401401
self, candidates: "_LlamaTokenDataArray", z: float, min_keep: int
402402
):
403403
assert self.ctx is not None
404404
llama_cpp.llama_sample_tail_free(
405-
self.ctx, ctypes.byref(candidates.candidates), z, min_keep # type: ignore
405+
self.ctx, llama_cpp.byref(candidates.candidates), z, min_keep
406406
)
407407

408408
def sample_typical(
409409
self, candidates: "_LlamaTokenDataArray", p: float, min_keep: int
410410
):
411411
assert self.ctx is not None
412412
llama_cpp.llama_sample_typical(
413-
self.ctx, ctypes.byref(candidates.candidates), p, min_keep # type: ignore
413+
self.ctx, llama_cpp.byref(candidates.candidates), p, min_keep
414414
)
415415

416416
def sample_temp(self, candidates: "_LlamaTokenDataArray", temp: float):
417417
assert self.ctx is not None
418418
llama_cpp.llama_sample_temp(
419-
self.ctx, ctypes.byref(candidates.candidates), temp # type: ignore
419+
self.ctx, llama_cpp.byref(candidates.candidates), temp
420420
)
421421

422422
def sample_grammar(self, candidates: "_LlamaTokenDataArray", grammar: LlamaGrammar):
423423
assert self.ctx is not None
424424
assert grammar.grammar is not None
425425
llama_cpp.llama_sample_grammar(
426426
self.ctx,
427-
ctypes.byref(candidates.candidates), # type: ignore
427+
llama_cpp.byref(candidates.candidates),
428428
grammar.grammar,
429429
)
430430

@@ -434,25 +434,25 @@ def sample_token_mirostat(
434434
tau: float,
435435
eta: float,
436436
m: int,
437-
mu: ctypes._Pointer[ctypes.c_float], # type: ignore
437+
mu: llama_cpp.CtypesPointerOrRef[ctypes.c_float],
438438
) -> int:
439439
assert self.ctx is not None
440440
return llama_cpp.llama_sample_token_mirostat(
441441
self.ctx,
442-
ctypes.byref(candidates.candidates), # type: ignore
442+
llama_cpp.byref(candidates.candidates),
443443
tau,
444444
eta,
445445
m,
446446
mu,
447447
)
448448

449449
def sample_token_mirostat_v2(
450-
self, candidates: "_LlamaTokenDataArray", tau: float, eta: float, mu: ctypes._Pointer[ctypes.c_float] # type: ignore
450+
self, candidates: "_LlamaTokenDataArray", tau: float, eta: float, mu: llama_cpp.CtypesPointerOrRef[ctypes.c_float]
451451
) -> int:
452452
assert self.ctx is not None
453453
return llama_cpp.llama_sample_token_mirostat_v2(
454454
self.ctx,
455-
ctypes.byref(candidates.candidates), # type: ignore
455+
llama_cpp.byref(candidates.candidates),
456456
tau,
457457
eta,
458458
mu,
@@ -462,14 +462,14 @@ def sample_token_greedy(self, candidates: "_LlamaTokenDataArray") -> int:
462462
assert self.ctx is not None
463463
return llama_cpp.llama_sample_token_greedy(
464464
self.ctx,
465-
ctypes.byref(candidates.candidates), # type: ignore
465+
llama_cpp.byref(candidates.candidates),
466466
)
467467

468468
def sample_token(self, candidates: "_LlamaTokenDataArray") -> int:
469469
assert self.ctx is not None
470470
return llama_cpp.llama_sample_token(
471471
self.ctx,
472-
ctypes.byref(candidates.candidates), # type: ignore
472+
llama_cpp.byref(candidates.candidates),
473473
)
474474

475475
# Grammar
@@ -566,7 +566,7 @@ def __init__(self, *, n_vocab: int):
566566
size=self.n_vocab,
567567
sorted=False,
568568
)
569-
self.default_candidates_data_id = np.arange(self.n_vocab, dtype=np.intc)
569+
self.default_candidates_data_id = np.arange(self.n_vocab, dtype=np.intc) # type: ignore
570570
self.default_candidates_data_p = np.zeros(self.n_vocab, dtype=np.single)
571571

572572
def copy_logits(self, logits: npt.NDArray[np.single]):
@@ -754,7 +754,7 @@ def sample(
754754
ctx_main.sample_repetition_penalties(
755755
token_data_array,
756756
# TODO: Only create this once
757-
(llama_cpp.llama_token * len(self.prev))(*self.prev), # type: ignore
757+
(llama_cpp.llama_token * len(self.prev))(*self.prev),
758758
self.params.penalty_last_n,
759759
self.params.penalty_repeat,
760760
self.params.penalty_freq,

0 commit comments

Comments (0)