Add actual logit processor test · surtweig/llama-cpp-python@334ca35
Commit 334ca35 (1 parent: 2102eb9)
Add actual logit processor test

1 file changed: +10 -6 lines

tests/test_llama.py (10 additions, 6 deletions)
@@ -15,6 +15,10 @@
 MODEL = "./vendor/llama.cpp/models/ggml-vocab-llama-spm.gguf"
 
 
+def test_llama_cpp_version():
+    assert llama_cpp.__version__
+
+
 def test_llama_cpp_tokenization():
     llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True, verbose=False)
 
@@ -52,10 +56,6 @@ def test_llama_cpp_tokenization():
     assert text == llama.detokenize(tokens)
 
 
-def test_llama_cpp_version():
-    assert llama_cpp.__version__
-
-
 @pytest.fixture
 def llama_cpp_model_path():
     repo_id = "Qwen/Qwen2-0.5B-Instruct-GGUF"
@@ -150,8 +150,12 @@ def test_real_llama(llama_cpp_model_path):
     )
     assert output["choices"][0]["text"] == "true"
 
+    suffix = b"rot"
+    tokens = model.tokenize(suffix, add_bos=True, special=True)
     def logit_processor_func(input_ids, logits):
-        return logits * 1
+        for token in tokens:
+            logits[token] *= 1000
+        return logits
 
     logit_processors = llama_cpp.LogitsProcessorList(
         [logit_processor_func]
@@ -166,4 +170,4 @@ def logit_processor_func(input_ids, logits):
         seed=1337,
         logits_processor=logit_processors
     )
-    assert output["choices"][0]["text"].lower().startswith("is")
+    assert output["choices"][0]["text"].lower().startswith("rot")
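
The change replaces a no-op processor (return logits * 1) with one that actually steers sampling: it tokenizes the target string "rot" and multiplies those tokens' logits by 1000, so seeded sampling is all but forced to emit "rot" first, which the updated assertion then verifies. As a standalone illustration of the same pattern, here is a minimal sketch. The model path ./model.gguf and the prompt are hypothetical stand-ins; it uses llama-cpp-python's LogitsProcessorList and the logits_processor parameter shown in the diff above.

import llama_cpp

model = llama_cpp.Llama(model_path="./model.gguf", verbose=False)  # hypothetical path

# Tokenize the string we want the model to emit next.
target_tokens = model.tokenize(b"rot", add_bos=False, special=True)

def boost_target(input_ids, logits):
    # Scale the target tokens' logits so sampling overwhelmingly prefers
    # them. Note: multiplying only boosts logits that are already positive
    # (as in the test); an additive bonus would be more robust in general.
    for token in target_tokens:
        logits[token] *= 1000
    return logits

output = model(
    "Carrots are ",  # hypothetical prompt
    max_tokens=4,
    seed=1337,
    logits_processor=llama_cpp.LogitsProcessorList([boost_target]),
)
assert output["choices"][0]["text"].lower().startswith("rot")

A logits processor here is any callable taking (input_ids, logits) and returning the (possibly modified) logits array, applied on every sampling step before a token is drawn.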
