8000 Use mock_llama for all tests · Nagibiku/llama-cpp-python@d7388f1 · GitHub
[go: up one dir, main page]

Skip to content

Commit d7388f1

Browse files
committed
Use mock_llama for all tests
1 parent dbfaf53 commit d7388f1

File tree

1 file changed

+3
-40
lines changed

1 file changed

+3
-40
lines changed

tests/test_llama.py

Lines changed: 3 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -160,55 +160,18 @@ def test_llama_pickle():
160160
assert llama.detokenize(llama.tokenize(text)) == text
161161

162162

163-
def test_utf8(mock_llama, monkeypatch):
163+
def test_utf8(mock_llama):
164164
llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True, logits_all=True)
165-
n_ctx = llama.n_ctx()
166-
n_vocab = llama.n_vocab()
167165

168166
output_text = "😀"
169-
output_tokens = llama.tokenize(
170-
output_text.encode("utf-8"), add_bos=True, special=True
171-
)
172-
token_eos = llama.token_eos()
173-
n = 0
174-
175-
def reset():
176-
nonlocal n
177-
llama.reset()
178-
n = 0
179-
180-
## Set up mock function
181-
def mock_decode(ctx: llama_cpp.llama_context_p, batch: llama_cpp.llama_batch):
182-
nonlocal n
183-
assert batch.n_tokens > 0
184-
assert llama.n_tokens == n
185-
n += batch.n_tokens
186-
return 0
187-
188-
def mock_get_logits(*args, **kwargs):
189-
size = n_vocab * n_ctx
190-
return (llama_cpp.c_float * size)()
191-
192-
def mock_sample(*args, **kwargs):
193-
nonlocal n
194-
if n <= len(output_tokens):
195-
return output_tokens[n - 1]
196-
else:
197-
return token_eos
198-
199-
monkeypatch.setattr("llama_cpp.llama_cpp.llama_decode", mock_decode)
200-
monkeypatch.setattr("llama_cpp.llama_cpp.llama_get_logits", mock_get_logits)
201-
monkeypatch.setattr("llama_cpp.llama_cpp.llama_sample_token", mock_sample)
202167

203168
## Test basic completion with utf8 multibyte
204-
# mock_llama(llama, output_text)
205-
reset()
169+
mock_llama(llama, output_text)
206170
completion = llama.create_completion("", max_tokens=4)
207171
assert completion["choices"][0]["text"] == output_text
208172

209173
## Test basic completion with incomplete utf8 multibyte
210-
# mock_llama(llama, output_text)
211-
reset()
174+
mock_llama(llama, output_text)
212175
completion = llama.create_completion("", max_tokens=1)
213176
assert completion["choices"][0]["text"] == ""
214177

0 commit comments

Comments
 (0)
0