@@ -1,10 +1,16 @@
 import ctypes
+import multiprocessing
 
 import numpy as np
-import pytest
 from scipy.special import log_softmax
 
+from huggingface_hub import hf_hub_download
+
+import pytest
+
 import llama_cpp
+import llama_cpp._internals as internals
+
 
 MODEL = "./vendor/llama.cpp/models/ggml-vocab-llama-spm.gguf"
 
@@ -46,248 +52,118 @@ def test_llama_cpp_tokenization():
     assert text == llama.detokenize(tokens)
 
 
-@pytest.fixture
-def mock_llama(monkeypatch):
-    def setup_mock(llama: llama_cpp.Llama, output_text: str):
-        n_ctx = llama.n_ctx()
-        n_vocab = llama.n_vocab()
-        output_tokens = llama.tokenize(
-            output_text.encode("utf-8"), add_bos=True, special=True
-        )
-        logits = (ctypes.c_float * (n_vocab * n_ctx))()
-        for i in range(n_ctx):
-            output_idx = i + 1  # logits for first tokens predict second token
-            if output_idx < len(output_tokens):
-                logits[i * n_vocab + output_tokens[output_idx]] = 100.0
-            else:
-                logits[i * n_vocab + llama.token_eos()] = 100.0
-        n = 0
-        last_n_tokens = 0
-
-        def mock_decode(ctx: llama_cpp.llama_context_p, batch: llama_cpp.llama_batch):
-            # Test some basic invariants of this mocking technique
-            assert ctx == llama._ctx.ctx, "context does not match mock_llama"
-            assert batch.n_tokens > 0, "no tokens in batch"
-            assert all(
-                batch.n_seq_id[i] == 1 for i in range(batch.n_tokens)
-            ), "n_seq >1 not supported by mock_llama"
-            assert all(
-                batch.seq_id[i][0] == 0 for i in range(batch.n_tokens)
-            ), "n_seq >1 not supported by mock_llama"
-            assert batch.logits[
-                batch.n_tokens - 1
-            ], "logits not allocated for last token"
-            # Update the mock context state
-            nonlocal n
-            nonlocal last_n_tokens
-            n = max(batch.pos[i] for i in range(batch.n_tokens)) + 1
-            last_n_tokens = batch.n_tokens
-            return 0
-
-        def mock_get_logits(ctx: llama_cpp.llama_context_p):
-            # Test some basic invariants of this mocking technique
-            assert ctx == llama._ctx.ctx, "context does not match mock_llama"
-            assert n > 0, "mock_llama_decode not called"
-            assert last_n_tokens > 0, "mock_llama_decode not called"
-            # Return view of logits for last_n_tokens
-            return (ctypes.c_float * (last_n_tokens * n_vocab)).from_address(
-                ctypes.addressof(logits)
-                + (n - last_n_tokens) * n_vocab * ctypes.sizeof(ctypes.c_float)
-            )
-
-        monkeypatch.setattr("llama_cpp.llama_cpp.llama_decode", mock_decode)
-        monkeypatch.setattr("llama_cpp.llama_cpp.llama_get_logits", mock_get_logits)
-
-        def mock_kv_cache_clear(ctx: llama_cpp.llama_context_p):
-            # Test some basic invariants of this mocking technique
-            assert ctx == llama._ctx.ctx, "context does not match mock_llama"
-            return
-
-        def mock_kv_cache_seq_rm(
-            ctx: llama_cpp.llama_context_p,
-            seq_id: llama_cpp.llama_seq_id,
-            pos0: llama_cpp.llama_pos,
-            pos1: llama_cpp.llama_pos,
-        ):
-            # Test some basic invariants of this mocking technique
-            assert ctx == llama._ctx.ctx, "context does not match mock_llama"
-            return
-
-        def mock_kv_cache_seq_cp(
-            ctx: llama_cpp.llama_context_p,
-            seq_id_src: llama_cpp.llama_seq_id,
-            seq_id_dst: llama_cpp.llama_seq_id,
-            pos0: llama_cpp.llama_pos,
-            pos1: llama_cpp.llama_pos,
-        ):
-            # Test some basic invariants of this mocking technique
-            assert ctx == llama._ctx.ctx, "context does not match mock_llama"
-            return
-
-        def mock_kv_cache_seq_keep(
-            ctx: llama_cpp.llama_context_p,
-            seq_id: llama_cpp.llama_seq_id,
-        ):
-            # Test some basic invariants of this mocking technique
-            assert ctx == llama._ctx.ctx, "context does not match mock_llama"
-            return
-
-        def mock_kv_cache_seq_add(
-            ctx: llama_cpp.llama_context_p,
-            seq_id: llama_cpp.llama_seq_id,
-            pos0: llama_cpp.llama_pos,
-            pos1: llama_cpp.llama_pos,
-        ):
-            # Test some basic invariants of this mocking technique
-            assert ctx == llama._ctx.ctx, "context does not match mock_llama"
-            return
-
-        monkeypatch.setattr("llama_cpp.llama_cpp.llama_kv_cache_clear", mock_kv_cache_clear)
-        monkeypatch.setattr("llama_cpp.llama_cpp.llama_kv_cache_seq_rm", mock_kv_cache_seq_rm)
-        monkeypatch.setattr("llama_cpp.llama_cpp.llama_kv_cache_seq_cp", mock_kv_cache_seq_cp)
-        monkeypatch.setattr("llama_cpp.llama_cpp.llama_kv_cache_seq_keep", mock_kv_cache_seq_keep)
-        monkeypatch.setattr("llama_cpp.llama_cpp.llama_kv_cache_seq_add", mock_kv_cache_seq_add)
-
-    return setup_mock
-
-
-# def test_llama_patch(mock_llama):
-#     n_ctx = 128
-#     llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True, n_ctx=n_ctx)
-#     n_vocab = llama_cpp.llama_n_vocab(llama._model.model)
-#     assert n_vocab == 32000
-#
-#     text = "The quick brown fox"
-#     output_text = " jumps over the lazy dog."
-#     all_text = text + output_text
-#
-#     ## Test basic completion from bos until eos
-#     mock_llama(llama, all_text)
-#     completion = llama.create_completion("", max_tokens=36)
-#     assert completion["choices"][0]["text"] == all_text
-#     assert completion["choices"][0]["finish_reason"] == "stop"
-#
-#     ## Test basic completion until eos
-#     mock_llama(llama, all_text)
-#     completion = llama.create_completion(text, max_tokens=20)
-#     assert completion["choices"][0]["text"] == output_text
-#     assert completion["choices"][0]["finish_reason"] == "stop"
-#
-#     ## Test streaming completion until eos
-#     mock_llama(llama, all_text)
-#     chunks = list(llama.create_completion(text, max_tokens=20, stream=True))
-#     assert "".join(chunk["choices"][0]["text"] for chunk in chunks) == output_text
-#     assert chunks[-1]["choices"][0]["finish_reason"] == "stop"
-#
-#     ## Test basic completion until stop sequence
-#     mock_llama(llama, all_text)
-#     completion = llama.create_completion(text, max_tokens=20, stop=["lazy"])
-#     assert completion["choices"][0]["text"] == " jumps over the "
-#     assert completion["choices"][0]["finish_reason"] == "stop"
-#
-#     ## Test streaming completion until stop sequence
-#     mock_llama(llama, all_text)
-#     chunks = list(
-#         llama.create_completion(text, max_tokens=20, stream=True, stop=["lazy"])
-#     )
-#     assert (
-#         "".join(chunk["choices"][0]["text"] for chunk in chunks) == " jumps over the "
-#     )
-#     assert chunks[-1]["choices"][0]["finish_reason"] == "stop"
-#
-#     ## Test basic completion until length
-#     mock_llama(llama, all_text)
-#     completion = llama.create_completion(text, max_tokens=2)
-#     assert completion["choices"][0]["text"] == " jumps"
-#     assert completion["choices"][0]["finish_reason"] == "length"
-#
-#     ## Test streaming completion until length
-#     mock_llama(llama, all_text)
-#     chunks = list(llama.create_completion(text, max_tokens=2, stream=True))
-#     assert "".join(chunk["choices"][0]["text"] for chunk in chunks) == " jumps"
-#     assert chunks[-1]["choices"][0]["finish_reason"] == "length"
-
-
-def test_llama_pickle():
-    import pickle
-    import tempfile
-
-    fp = tempfile.TemporaryFile()
-    llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True)
-    pickle.dump(llama, fp)
-    fp.seek(0)
-    llama = pickle.load(fp)
-
-    assert llama
-    assert llama.ctx is not None
-
-    text = b"Hello World"
-
-    assert llama.detokenize(llama.tokenize(text)) == text
+def test_llama_cpp_version():
+    assert llama_cpp.__version__
 
 
-# def test_utf8(mock_llama):
-#     llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True, logits_all=True)
-#
-#     output_text = "😀"
-#
-#     ## Test basic completion with utf8 multibyte
-#     mock_llama(llama, output_text)
-#     completion = llama.create_completion("", max_tokens=4)
-#     assert completion["choices"][0]["text"] == output_text
-#
-#     ## Test basic completion with incomplete utf8 multibyte
-#     mock_llama(llama, output_text)
-#     completion = llama.create_completion("", max_tokens=1)
-#     assert completion["choices"][0]["text"] == ""
+@pytest.fixture
+def llama_cpp_model_path():
+    repo_id = "Qwen/Qwen2-0.5B-Instruct-GGUF"
+    filename = "qwen2-0_5b-instruct-q8_0.gguf"
+    model_path = hf_hub_download(repo_id, filename)
+    return model_path
+
+
+def test_real_model(llama_cpp_model_path):
+    import os
+    assert os.path.exists(llama_cpp_model_path)
+
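+    # Load the model through the low-level internals wrappers so the C-side
+    # model/context parameter structs are exercised directly.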
+    params = llama_cpp.llama_model_default_params()
+    params.use_mmap = llama_cpp.llama_supports_mmap()
+    params.use_mlock = llama_cpp.llama_supports_mlock()
+    params.check_tensors = False
+
+    model = internals.LlamaModel(path_model=llama_cpp_model_path, params=params)
+
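+    # A deliberately tiny context (16 tokens) keeps the test cheap; the prompt
+    # plus the four sampled tokens still fit within it.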
+    cparams = llama_cpp.llama_context_default_params()
+    cparams.n_ctx = 16
+    cparams.n_batch = 16
+    cparams.n_ubatch = 16
+    cparams.n_threads = multiprocessing.cpu_count()
+    cparams.n_threads_batch = multiprocessing.cpu_count()
+    cparams.logits_all = False
+    cparams.flash_attn = True
+
+    context = internals.LlamaContext(model=model, params=cparams)
+    tokens = model.tokenize(b"Hello, world!", add_bos=True, special=True)
+
+    assert tokens == [9707, 11, 1879, 0]
+
+    tokens = model.tokenize(b"The quick brown fox jumps", add_bos=True, special=True)
+
+    batch = internals.LlamaBatch(n_tokens=len(tokens), embd=0, n_seq_max=1)
+
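+    # Sampler chain: top-k, then top-p, then temperature, then a seeded
+    # distribution sampler so the continuation is reproducible.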
+    seed = 1337
+    sampler = internals.LlamaSampler()
+    sampler.add_top_k(50)
+    sampler.add_top_p(0.9, 1)
+    sampler.add_temp(0.8)
+    sampler.add_dist(seed)
+
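+    # Manual generation loop: decode the pending tokens, advance n_past by the
+    # number evaluated, sample one token from the last position, and feed it back.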
+    result = tokens
+    n_eval = 0
+    for _ in range(4):
+        batch.set_batch(tokens, n_past=n_eval, logits_all=False)
+        context.decode(batch)
+        n_eval += len(tokens)
+        token_id = sampler.sample(context, -1)
+        tokens = [token_id]
+        result += tokens
+
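+    # Drop the prompt (which tokenizes to five tokens here) and detokenize only
+    # the four sampled tokens.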
+    output = result[5:]
+    output_text = model.detokenize(output, special=True)
+    assert output_text == b" over the lazy dog"
+
+def test_real_llama(llama_cpp_model_path):
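+    # Same checkpoint driven through the high-level Llama API, which handles
+    # tokenization, sampling, and detokenization internally.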
+    model = llama_cpp.Llama(
+        llama_cpp_model_path,
+        n_ctx=32,
+        n_batch=32,
+        n_ubatch=32,
+        n_threads=multiprocessing.cpu_count(),
+        n_threads_batch=multiprocessing.cpu_count(),
+        logits_all=False,
+        flash_attn=True,
+    )
 
+    output = model.create_completion(
+        "The quick brown fox jumps",
+        max_tokens=4,
+        top_k=50,
+        top_p=0.9,
+        temperature=0.8,
+        seed=1337
+    )
+    assert output["choices"][0]["text"] == " over the lazy dog"
+
+
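+    # The GBNF grammar constrains decoding so the only possible completions are
+    # the literal strings "true" and "false".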
+    output = model.create_completion(
+        "The capital of france is paris, 'true' or 'false'?:\n",
+        max_tokens=4,
+        top_k=50,
+        top_p=0.9,
+        temperature=0.8,
+        seed=1337,
+        grammar=llama_cpp.LlamaGrammar.from_string("""
+root ::= "true" | "false"
+""")
+    )
+    assert output["choices"][0]["text"] == "true"
 
-def test_llama_server():
-    from fastapi.testclient import TestClient
-    from llama_cpp.server.app import create_app, Settings
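+    # Identity logits processor: multiplying the logits by 1 leaves sampling
+    # unchanged while still exercising the logits_processor code path.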
+    def logit_processor_func(input_ids, logits):
+        return logits * 1
 
-    settings = Settings(
-        model=MODEL,
-        vocab_only=True,
+    logit_processors = llama_cpp.LogitsProcessorList(
+        [logit_processor_func]
     )
-    app = create_app(settings)
-    client = TestClient(app)
-    response = client.get("/v1/models")
-    assert response.json() == {
-        "object": "list",
-        "data": [
-            {
-                "id": MODEL,
-                "object": "model",
-                "owned_by": "me",
-                "permissions": [],
-            }
-        ],
-    }
-
-
-@pytest.mark.parametrize(
-    "size_and_axis",
-    [
-        ((32_000,), -1),  # last token's next-token logits
-        ((10, 32_000), -1),  # many tokens' next-token logits, or batch of last tokens
-        ((4, 10, 32_000), -1),  # batch of texts
-    ],
-)
-@pytest.mark.parametrize("convert_to_list", [True, False])
-def test_logits_to_logprobs(size_and_axis, convert_to_list: bool, atol: float = 1e-7):
-    size, axis = size_and_axis
-    logits: np.ndarray = -np.random.uniform(low=0, high=60, size=size)
-    logits = logits.astype(np.single)
-    if convert_to_list:
-        # Currently, logits are converted from arrays to lists. This may change soon
-        logits = logits.tolist()
-    log_probs = llama_cpp.Llama.logits_to_logprobs(logits, axis=axis)
-    log_probs_correct = log_softmax(logits, axis=axis)
-    assert log_probs.dtype == np.single
-    assert log_probs.shape == size
-    assert np.allclose(log_probs, log_probs_correct, atol=atol)
-
 
-def test_llama_cpp_version():
-    assert llama_cpp.__version__
+    output = model.create_completion(
+        "The capital of france is par",
+        max_tokens=4,
+        top_k=50,
+        top_p=0.9,
+        temperature=0.8,
+        seed=1337,
+        logits_processor=logit_processors
+    )
+    assert output["choices"][0]["text"].lower().startswith("is")