|
| 1 | +import ctypes |
| 2 | + |
1 | 3 | import pytest
|
| 4 | + |
2 | 5 | import llama_cpp
|
3 | 6 |
|
4 | 7 | MODEL = "./vendor/llama.cpp/models/ggml-vocab-llama.gguf"
|
@@ -36,19 +39,20 @@ def test_llama_cpp_tokenization():
|
36 | 39 |
|
37 | 40 |
|
38 | 41 | def test_llama_patch(monkeypatch):
|
39 |
| - llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True) |
| 42 | + n_ctx = 128 |
| 43 | + llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True, n_ctx=n_ctx) |
40 | 44 | n_vocab = llama_cpp.llama_n_vocab(llama._model.model)
|
| 45 | + assert n_vocab == 32000 |
41 | 46 |
|
42 | 47 | ## Set up mock function
|
43 |
| - def mock_eval(*args, **kwargs): |
| 48 | + def mock_decode(*args, **kwargs): |
44 | 49 | return 0
|
45 | 50 |
|
46 | 51 | def mock_get_logits(*args, **kwargs):
|
47 |
| - return (llama_cpp.c_float * n_vocab)( |
48 |
| - *[llama_cpp.c_float(0) for _ in range(n_vocab)] |
49 |
| - ) |
| 52 | + size = n_vocab * n_ctx |
| 53 | + return (llama_cpp.c_float * size)() |
50 | 54 |
|
51 |
| - monkeypatch.setattr("llama_cpp.llama_cpp.llama_decode", mock_eval) |
| 55 | + monkeypatch.setattr("llama_cpp.llama_cpp.llama_decode", mock_decode) |
52 | 56 | monkeypatch.setattr("llama_cpp.llama_cpp.llama_get_logits", mock_get_logits)
|
53 | 57 |
|
54 | 58 | output_text = " jumps over the lazy dog."
|
@@ -126,19 +130,19 @@ def test_llama_pickle():
|
126 | 130 |
|
127 | 131 |
|
128 | 132 | def test_utf8(monkeypatch):
|
129 |
| - llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True) |
| 133 | + n_ctx = 512 |
| 134 | + llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True, n_ctx=n_ctx, logits_all=True) |
130 | 135 | n_vocab = llama.n_vocab()
|
131 | 136 |
|
132 | 137 | ## Set up mock function
|
133 |
| - def mock_eval(*args, **kwargs): |
| 138 | + def mock_decode(*args, **kwargs): |
134 | 139 | return 0
|
135 | 140 |
|
136 | 141 | def mock_get_logits(*args, **kwargs):
|
137 |
| - return (llama_cpp.c_float * n_vocab)( |
138 |
| - *[llama_cpp.c_float(0) for _ in range(n_vocab)] |
139 |
| - ) |
| 142 | + size = n_vocab * n_ctx |
| 143 | + return (llama_cpp.c_float * size)() |
140 | 144 |
|
141 |
| - monkeypatch.setattr("llama_cpp.llama_cpp.llama_decode", mock_eval) |
| 145 | + monkeypatch.setattr("llama_cpp.llama_cpp.llama_decode", mock_decode) |
142 | 146 | monkeypatch.setattr("llama_cpp.llama_cpp.llama_get_logits", mock_get_logits)
|
143 | 147 |
|
144 | 148 | output_text = "😀"
|
|
0 commit comments