8000 tests: don't mock sampling functions · Nagibiku/llama-cpp-python@0a7e05b · GitHub
[go: up one dir, main page]

Skip to content

Commit 0a7e05b

Browse files
committed
tests: don't mock sampling functions
1 parent d7388f1 commit 0a7e05b

File tree

1 file changed

+27
-15
lines changed

1 file changed

+27
-15
lines changed

tests/test_llama.py

Lines changed: 27 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,6 @@ def test_llama_cpp_tokenization():
4747
@pytest.fixture
4848
def mock_llama(monkeypatch):
4949
def setup_mock(llama: llama_cpp.Llama, output_text: str):
50-
llama.reset()
5150
n_vocab = llama.n_vocab()
5251
output_tokens = llama.tokenize(
5352
output_text.encode("utf-8"), add_bos=True, special=True
@@ -59,28 +58,41 @@ def mock_decode(ctx: llama_cpp.llama_context_p, batch: llama_cpp.llama_batch):
5958
nonlocal n
6059
nonlocal last_n_tokens
6160
# Test some basic invariants of this mocking technique
62-
assert ctx == llama._ctx.ctx
63-
assert llama.n_tokens == n
64-
assert batch.n_tokens > 0
65-
n += batch.n_tokens
61+
assert ctx == llama._ctx.ctx, "context does not match mock_llama"
62+
assert batch.n_tokens > 0, "no tokens in batch"
63+
assert all(
64+
batch.n_seq_id[i] == 1 for i in range(batch.n_tokens)
65+
), "n_seq >1 not supported by mock_llama"
66+
assert all(
67+
batch.seq_id[i][0] == 0 for i in range(batch.n_tokens)
68+
), "n_seq >1 not supported by mock_llama"
69+
assert batch.logits[
70+
batch.n_tokens - 1
71+
], "logits not allocated for last token"
72+
# Update the mock context state
73+
n = max(batch.pos[i] for i in range(batch.n_tokens)) + 1
6674
last_n_tokens = batch.n_tokens
6775
return 0
6876

6977
def mock_get_logits(*args, **kwargs):
70-
nonlocal last_n_tokens
71-
size = n_vocab * last_n_tokens
72-
return (llama_cpp.c_float * size)()
73-
74-
def mock_sample(*args, **kwargs):
7578
nonlocal n
76-
if n < len(output_tokens):
77-
return output_tokens[n]
78-
else:
79-
return llama.token_eos()
79+
nonlocal last_n_tokens
80+
assert n > 0, "mock_llama_decode not called"
81+
assert last_n_tokens > 0, "mock_llama_decode not called"
82+
logits = (llama_cpp.c_float * (last_n_tokens * n_vocab))(-100.0)
83+
for logits_idx, output_idx in enumerate(
84+
range(n - last_n_tokens + 1, n + 1)
85+
):
86+
if output_idx < len(output_tokens):
87+
logits[
88+
logits_idx * last_n_tokens + output_tokens[output_idx]
89+
] = 100.0
90+
else:
91+
logits[logits_idx * last_n_tokens + llama.token_eos()] = 100.0
92+
return logits
8093

8194
monkeypatch.setattr("llama_cpp.llama_cpp.llama_decode", mock_decode)
8295
monkeypatch.setattr("llama_cpp.llama_cpp.llama_get_logits", mock_get_logits)
83-
monkeypatch.setattr("llama_cpp.llama_cpp.llama_sample_token", mock_sample)
8496

8597
return setup_mock
8698

0 commit comments

Comments
 (0)
0