8000 Fix tests · jithinraj/llama-cpp-python@e32ecb0 · GitHub
[go: up one dir, main page]

Skip to content

Commit e32ecb0

Browse files
committed
Fix tests
1 parent 6f0b0b1 commit e32ecb0

File tree

1 file changed

+16
-12
lines changed

1 file changed

+16
-12
lines changed

tests/test_llama.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1+
import ctypes
2+
13
import pytest
4+
25
import llama_cpp
36

47
MODEL = "./vendor/llama.cpp/models/ggml-vocab-llama.gguf"
@@ -36,19 +39,20 @@ def test_llama_cpp_tokenization():
3639

3740

3841
def test_llama_patch(monkeypatch):
39-
llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True)
42+
n_ctx = 128
43+
llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True, n_ctx=n_ctx)
4044
n_vocab = llama_cpp.llama_n_vocab(llama._model.model)
45+
assert n_vocab == 32000
4146

4247
## Set up mock function
43-
def mock_eval(*args, **kwargs):
48+
def mock_decode(*args, **kwargs):
4449
return 0
4550

4651
def mock_get_logits(*args, **kwargs):
47-
return (llama_cpp.c_float * n_vocab)(
48-
*[llama_cpp.c_float(0) for _ in range(n_vocab)]
49-
)
52+
size = n_vocab * n_ctx
53+
return (llama_cpp.c_float * size)()
5054

51-
monkeypatch.setattr("llama_cpp.llama_cpp.llama_decode", mock_eval)
55+
monkeypatch.setattr("llama_cpp.llama_cpp.llama_decode", mock_decode)
5256
monkeypatch.setattr("llama_cpp.llama_cpp.llama_get_logits", mock_get_logits)
5357

5458
output_text = " jumps over the lazy dog."
@@ -126,19 +130,19 @@ def test_llama_pickle():
126130

127131

128132
def test_utf8(monkeypatch):
129-
llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True)
133+
n_ctx = 512
134+
llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True, n_ctx=n_ctx, logits_all=True)
130135
n_vocab = llama.n_vocab()
131136

132137
## Set up mock function
133-
def mock_eval(*args, **kwargs):
138+
def mock_decode(*args, **kwargs):
134139
return 0
135140

136141
def mock_get_logits(*args, **kwargs):
137-
return (llama_cpp.c_float * n_vocab)(
138-
*[llama_cpp.c_float(0) for _ in range(n_vocab)]
139-
)
142+
size = n_vocab * n_ctx
143+
return (llama_cpp.c_float * size)()
140144

141-
monkeypatch.setattr("llama_cpp.llama_cpp.llama_decode", mock_eval)
145+
monkeypatch.setattr("llama_cpp.llama_cpp.llama_decode", mock_decode)
142146
monkeypatch.setattr("llama_cpp.llama_cpp.llama_get_logits", mock_get_logits)
143147

144148
output_text = "😀"

0 commit comments

Comments
 (0)
0