Use real model for tests · surtweig/llama-cpp-python@2897fc2 · GitHub

Commit 2897fc2

Use real model for tests
1 parent 56c1f45 commit 2897fc2
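
The updated tests replace the monkeypatched llama.cpp bindings with a real GGUF model (Qwen2-0.5B-Instruct, Q8_0) fetched from the Hugging Face Hub. A minimal sketch for running just this module locally, assuming pytest, huggingface_hub, and llama-cpp-python are installed and the Hub is reachable; the runner snippet itself is not part of the commit:

# Hypothetical local runner for the updated module (not part of this commit).
# The llama_cpp_model_path fixture downloads and caches the Qwen2-0.5B-Instruct
# Q8_0 GGUF from the Hugging Face Hub on first use.
import sys

import pytest

# pytest.main takes the usual CLI arguments and returns an exit code.
sys.exit(pytest.main(["-q", "tests/test_llama.py"]))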

File tree

1 file changed: +115 -239 lines changed

tests/test_llama.py

Lines changed: 115 additions & 239 deletions
@@ -1,10 +1,16 @@
 import ctypes
+import multiprocessing
 
 import numpy as np
-import pytest
 from scipy.special import log_softmax
 
+from huggingface_hub import hf_hub_download
+
+import pytest
+
 import llama_cpp
+import llama_cpp._internals as internals
+
 
 MODEL = "./vendor/llama.cpp/models/ggml-vocab-llama-spm.gguf"
 
@@ -46,248 +52,118 @@ def test_llama_cpp_tokenization():
     assert text == llama.detokenize(tokens)
 
 
-@pytest.fixture
-def mock_llama(monkeypatch):
-    def setup_mock(llama: llama_cpp.Llama, output_text: str):
-        n_ctx = llama.n_ctx()
-        n_vocab = llama.n_vocab()
-        output_tokens = llama.tokenize(
-            output_text.encode("utf-8"), add_bos=True, special=True
-        )
-        logits = (ctypes.c_float * (n_vocab * n_ctx))()
-        for i in range(n_ctx):
-            output_idx = i + 1  # logits for first tokens predict second token
-            if output_idx < len(output_tokens):
-                logits[i * n_vocab + output_tokens[output_idx]] = 100.0
-            else:
-                logits[i * n_vocab + llama.token_eos()] = 100.0
-        n = 0
-        last_n_tokens = 0
-
-        def mock_decode(ctx: llama_cpp.llama_context_p, batch: llama_cpp.llama_batch):
-            # Test some basic invariants of this mocking technique
-            assert ctx == llama._ctx.ctx, "context does not match mock_llama"
-            assert batch.n_tokens > 0, "no tokens in batch"
-            assert all(
-                batch.n_seq_id[i] == 1 for i in range(batch.n_tokens)
-            ), "n_seq >1 not supported by mock_llama"
-            assert all(
-                batch.seq_id[i][0] == 0 for i in range(batch.n_tokens)
-            ), "n_seq >1 not supported by mock_llama"
-            assert batch.logits[
-                batch.n_tokens - 1
-            ], "logits not allocated for last token"
-            # Update the mock context state
-            nonlocal n
-            nonlocal last_n_tokens
-            n = max(batch.pos[i] for i in range(batch.n_tokens)) + 1
-            last_n_tokens = batch.n_tokens
-            return 0
-
-        def mock_get_logits(ctx: llama_cpp.llama_context_p):
-            # Test some basic invariants of this mocking technique
-            assert ctx == llama._ctx.ctx, "context does not match mock_llama"
-            assert n > 0, "mock_llama_decode not called"
-            assert last_n_tokens > 0, "mock_llama_decode not called"
-            # Return view of logits for last_n_tokens
-            return (ctypes.c_float * (last_n_tokens * n_vocab)).from_address(
-                ctypes.addressof(logits)
-                + (n - last_n_tokens) * n_vocab * ctypes.sizeof(ctypes.c_float)
-            )
-
-        monkeypatch.setattr("llama_cpp.llama_cpp.llama_decode", mock_decode)
-        monkeypatch.setattr("llama_cpp.llama_cpp.llama_get_logits", mock_get_logits)
-
-        def mock_kv_cache_clear(ctx: llama_cpp.llama_context_p):
-            # Test some basic invariants of this mocking technique
-            assert ctx == llama._ctx.ctx, "context does not match mock_llama"
-            return
-
-        def mock_kv_cache_seq_rm(
-            ctx: llama_cpp.llama_context_p,
-            seq_id: llama_cpp.llama_seq_id,
-            pos0: llama_cpp.llama_pos,
-            pos1: llama_cpp.llama_pos,
-        ):
-            # Test some basic invariants of this mocking technique
-            assert ctx == llama._ctx.ctx, "context does not match mock_llama"
-            return
-
-        def mock_kv_cache_seq_cp(
-            ctx: llama_cpp.llama_context_p,
-            seq_id_src: llama_cpp.llama_seq_id,
-            seq_id_dst: llama_cpp.llama_seq_id,
-            pos0: llama_cpp.llama_pos,
-            pos1: llama_cpp.llama_pos,
-        ):
-            # Test some basic invariants of this mocking technique
-            assert ctx == llama._ctx.ctx, "context does not match mock_llama"
-            return
-
-        def mock_kv_cache_seq_keep(
-            ctx: llama_cpp.llama_context_p,
-            seq_id: llama_cpp.llama_seq_id,
-        ):
-            # Test some basic invariants of this mocking technique
-            assert ctx == llama._ctx.ctx, "context does not match mock_llama"
-            return
-
-        def mock_kv_cache_seq_add(
-            ctx: llama_cpp.llama_context_p,
-            seq_id: llama_cpp.llama_seq_id,
-            pos0: llama_cpp.llama_pos,
-            pos1: llama_cpp.llama_pos,
-        ):
-            # Test some basic invariants of this mocking technique
-            assert ctx == llama._ctx.ctx, "context does not match mock_llama"
-            return
-
-        monkeypatch.setattr("llama_cpp.llama_cpp.llama_kv_cache_clear", mock_kv_cache_clear)
-        monkeypatch.setattr("llama_cpp.llama_cpp.llama_kv_cache_seq_rm", mock_kv_cache_seq_rm)
-        monkeypatch.setattr("llama_cpp.llama_cpp.llama_kv_cache_seq_cp", mock_kv_cache_seq_cp)
-        monkeypatch.setattr("llama_cpp.llama_cpp.llama_kv_cache_seq_keep", mock_kv_cache_seq_keep)
-        monkeypatch.setattr("llama_cpp.llama_cpp.llama_kv_cache_seq_add", mock_kv_cache_seq_add)
-
-    return setup_mock
-
-
-# def test_llama_patch(mock_llama):
-#     n_ctx = 128
-#     llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True, n_ctx=n_ctx)
-#     n_vocab = llama_cpp.llama_n_vocab(llama._model.model)
-#     assert n_vocab == 32000
-#
-#     text = "The quick brown fox"
-#     output_text = " jumps over the lazy dog."
-#     all_text = text + output_text
-#
-#     ## Test basic completion from bos until eos
-#     mock_llama(llama, all_text)
-#     completion = llama.create_completion("", max_tokens=36)
-#     assert completion["choices"][0]["text"] == all_text
-#     assert completion["choices"][0]["finish_reason"] == "stop"
-#
-#     ## Test basic completion until eos
-#     mock_llama(llama, all_text)
-#     completion = llama.create_completion(text, max_tokens=20)
-#     assert completion["choices"][0]["text"] == output_text
-#     assert completion["choices"][0]["finish_reason"] == "stop"
-#
-#     ## Test streaming completion until eos
-#     mock_llama(llama, all_text)
-#     chunks = list(llama.create_completion(text, max_tokens=20, stream=True))
-#     assert "".join(chunk["choices"][0]["text"] for chunk in chunks) == output_text
-#     assert chunks[-1]["choices"][0]["finish_reason"] == "stop"
-#
-#     ## Test basic completion until stop sequence
-#     mock_llama(llama, all_text)
-#     completion = llama.create_completion(text, max_tokens=20, stop=["lazy"])
-#     assert completion["choices"][0]["text"] == " jumps over the "
-#     assert completion["choices"][0]["finish_reason"] == "stop"
-#
-#     ## Test streaming completion until stop sequence
-#     mock_llama(llama, all_text)
-#     chunks = list(
-#         llama.create_completion(text, max_tokens=20, stream=True, stop=["lazy"])
-#     )
-#     assert (
-#         "".join(chunk["choices"][0]["text"] for chunk in chunks) == " jumps over the "
-#     )
-#     assert chunks[-1]["choices"][0]["finish_reason"] == "stop"
-#
-#     ## Test basic completion until length
-#     mock_llama(llama, all_text)
-#     completion = llama.create_completion(text, max_tokens=2)
-#     assert completion["choices"][0]["text"] == " jumps"
-#     assert completion["choices"][0]["finish_reason"] == "length"
-#
-#     ## Test streaming completion until length
-#     mock_llama(llama, all_text)
-#     chunks = list(llama.create_completion(text, max_tokens=2, stream=True))
-#     assert "".join(chunk["choices"][0]["text"] for chunk in chunks) == " jumps"
-#     assert chunks[-1]["choices"][0]["finish_reason"] == "length"
-
-
-def test_llama_pickle():
-    import pickle
-    import tempfile
-
-    fp = tempfile.TemporaryFile()
-    llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True)
-    pickle.dump(llama, fp)
-    fp.seek(0)
-    llama = pickle.load(fp)
-
-    assert llama
-    assert llama.ctx is not None
-
-    text = b"Hello World"
-
-    assert llama.detokenize(llama.tokenize(text)) == text
+def test_llama_cpp_version():
+    assert llama_cpp.__version__
 
 
-# def test_utf8(mock_llama):
-#     llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True, logits_all=True)
-#
-#     output_text = "😀"
-#
-#     ## Test basic completion with utf8 multibyte
-#     mock_llama(llama, output_text)
-#     completion = llama.create_completion("", max_tokens=4)
-#     assert completion["choices"][0]["text"] == output_text
-#
-#     ## Test basic completion with incomplete utf8 multibyte
-#     mock_llama(llama, output_text)
-#     completion = llama.create_completion("", max_tokens=1)
-#     assert completion["choices"][0]["text"] == ""
+@pytest.fixture
+def llama_cpp_model_path():
+    repo_id = "Qwen/Qwen2-0.5B-Instruct-GGUF"
+    filename = "qwen2-0_5b-instruct-q8_0.gguf"
+    model_path = hf_hub_download(repo_id, filename)
+    return model_path
+
+
+def test_real_model(llama_cpp_model_path):
+    import os
+    assert os.path.exists(llama_cpp_model_path)
+
+    params = llama_cpp.llama_model_default_params()
+    params.use_mmap = llama_cpp.llama_supports_mmap()
+    params.use_mlock = llama_cpp.llama_supports_mlock()
+    params.check_tensors = False
+
+    model = internals.LlamaModel(path_model=llama_cpp_model_path, params=params)
+
+    cparams = llama_cpp.llama_context_default_params()
+    cparams.n_ctx = 16
+    cparams.n_batch = 16
+    cparams.n_ubatch = 16
+    cparams.n_threads = multiprocessing.cpu_count()
+    cparams.n_threads_batch = multiprocessing.cpu_count()
+    cparams.logits_all = False
+    cparams.flash_attn = True
+
+    context = internals.LlamaContext(model=model, params=cparams)
+    tokens = model.tokenize(b"Hello, world!", add_bos=True, special=True)
+
+    assert tokens == [9707, 11, 1879, 0]
+
+    tokens = model.tokenize(b"The quick brown fox jumps", add_bos=True, special=True)
+
+    batch = internals.LlamaBatch(n_tokens=len(tokens), embd=0, n_seq_max=1)
+
+    seed = 1337
+    sampler = internals.LlamaSampler()
+    sampler.add_top_k(50)
+    sampler.add_top_p(0.9, 1)
+    sampler.add_temp(0.8)
+    sampler.add_dist(seed)
+
+    result = tokens
+    n_eval = 0
+    for _ in range(4):
+        batch.set_batch(tokens, n_past=n_eval, logits_all=False)
+        context.decode(batch)
+        n_eval += len(tokens)
+        token_id = sampler.sample(context, -1)
+        tokens = [token_id]
+        result += tokens
+
+    output = result[5:]
+    output_text = model.detokenize(output, special=True)
+    assert output_text == b" over the lazy dog"
+
+def test_real_llama(llama_cpp_model_path):
+    model = llama_cpp.Llama(
+        llama_cpp_model_path,
+        n_ctx=32,
+        n_batch=32,
+        n_ubatch=32,
+        n_threads=multiprocessing.cpu_count(),
+        n_threads_batch=multiprocessing.cpu_count(),
+        logits_all=False,
+        flash_attn=True,
+    )
 
+    output = model.create_completion(
+        "The quick brown fox jumps",
+        max_tokens=4,
+        top_k=50,
+        top_p=0.9,
+        temperature=0.8,
+        seed=1337
+    )
+    assert output["choices"][0]["text"] == " over the lazy dog"
+
+
+    output = model.create_completion(
+        "The capital of france is paris, 'true' or 'false'?:\n",
+        max_tokens=4,
+        top_k=50,
+        top_p=0.9,
+        temperature=0.8,
+        seed=1337,
+        grammar=llama_cpp.LlamaGrammar.from_string("""
+root ::= "true" | "false"
+""")
+    )
+    assert output["choices"][0]["text"] == "true"
 
-def test_llama_server():
-    from fastapi.testclient import TestClient
-    from llama_cpp.server.app import create_app, Settings
+    def logit_processor_func(input_ids, logits):
+        return logits * 1
 
-    settings = Settings(
-        model=MODEL,
-        vocab_only=True,
+    logit_processors = llama_cpp.LogitsProcessorList(
+        [logit_processor_func]
     )
-    app = create_app(settings)
-    client = TestClient(app)
-    response = client.get("/v1/models")
-    assert response.json() == {
-        "object": "list",
-        "data": [
-            {
-                "id": MODEL,
-                "object": "model",
-                "owned_by": "me",
-                "permissions": [],
-            }
-        ],
-    }
-
-
-@pytest.mark.parametrize(
-    "size_and_axis",
-    [
-        ((32_000,), -1),  # last token's next-token logits
-        ((10, 32_000), -1),  # many tokens' next-token logits, or batch of last tokens
-        ((4, 10, 32_000), -1),  # batch of texts
-    ],
-)
-@pytest.mark.parametrize("convert_to_list", [True, False])
-def test_logits_to_logprobs(size_and_axis, convert_to_list: bool, atol: float = 1e-7):
-    size, axis = size_and_axis
-    logits: np.ndarray = -np.random.uniform(low=0, high=60, size=size)
-    logits = logits.astype(np.single)
-    if convert_to_list:
-        # Currently, logits are converted from arrays to lists. This may change soon
-        logits = logits.tolist()
-    log_probs = llama_cpp.Llama.logits_to_logprobs(logits, axis=axis)
-    log_probs_correct = log_softmax(logits, axis=axis)
-    assert log_probs.dtype == np.single
-    assert log_probs.shape == size
-    assert np.allclose(log_probs, log_probs_correct, atol=atol)
-
 
-def test_llama_cpp_version():
-    assert llama_cpp.__version__
+    output = model.create_completion(
+        "The capital of france is par",
+        max_tokens=4,
+        top_k=50,
+        top_p=0.9,
+        temperature=0.8,
+        seed=1337,
+        logits_processor=logit_processors
+    )
+    assert output["choices"][0]["text"].lower().startswith("is")

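For reference, the high-level path exercised by test_real_llama can be reproduced outside pytest. The sketch below is assembled from the code in this diff and is not part of the commit; the exact completion text can vary across llama.cpp builds and hardware, which is why the tests pin seed=1337 and small sampling parameters.

# Standalone sketch mirroring test_real_llama. Assumptions: llama-cpp-python
# and huggingface_hub are installed and the Hugging Face Hub is reachable.
import multiprocessing

from huggingface_hub import hf_hub_download

import llama_cpp

# Same model the new fixture downloads; hf_hub_download caches it locally.
model_path = hf_hub_download(
    "Qwen/Qwen2-0.5B-Instruct-GGUF", "qwen2-0_5b-instruct-q8_0.gguf"
)

llm = llama_cpp.Llama(
    model_path,
    n_ctx=32,
    n_batch=32,
    n_ubatch=32,
    n_threads=multiprocessing.cpu_count(),
    n_threads_batch=multiprocessing.cpu_count(),
    logits_all=False,
    flash_attn=True,
)

output = llm.create_completion(
    "The quick brown fox jumps",
    max_tokens=4,
    top_k=50,
    top_p=0.9,
    temperature=0.8,
    seed=1337,
)
print(output["choices"][0]["text"])  # the test asserts " over the lazy dog"

The lower-level test_real_model drives the same generation loop through llama_cpp._internals (LlamaModel, LlamaContext, LlamaBatch, LlamaSampler) instead of the high-level Llama wrapper, so both API layers are covered against the same downloaded model.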