@@ -47,7 +47,6 @@ def test_llama_cpp_tokenization():
 @pytest.fixture
 def mock_llama(monkeypatch):
     def setup_mock(llama: llama_cpp.Llama, output_text: str):
-        llama.reset()
         n_vocab = llama.n_vocab()
         output_tokens = llama.tokenize(
             output_text.encode("utf-8"), add_bos=True, special=True
@@ -59,28 +58,41 @@ def mock_decode(ctx: llama_cpp.llama_context_p, batch: llama_cpp.llama_batch):
             nonlocal n
             nonlocal last_n_tokens
             # Test some basic invariants of this mocking technique
-            assert ctx == llama._ctx.ctx
-            assert llama.n_tokens == n
-            assert batch.n_tokens > 0
-            n += batch.n_tokens
+            assert ctx == llama._ctx.ctx, "context does not match mock_llama"
+            assert batch.n_tokens > 0, "no tokens in batch"
+            assert all(
+                batch.n_seq_id[i] == 1 for i in range(batch.n_tokens)
+            ), "n_seq > 1 not supported by mock_llama"
+            assert all(
+                batch.seq_id[i][0] == 0 for i in range(batch.n_tokens)
+            ), "n_seq > 1 not supported by mock_llama"
+            assert batch.logits[
+                batch.n_tokens - 1
+            ], "logits not allocated for last token"
+            # Update the mock context state
+            n = max(batch.pos[i] for i in range(batch.n_tokens)) + 1
             last_n_tokens = batch.n_tokens
             return 0

         def mock_get_logits(*args, **kwargs):
-            nonlocal last_n_tokens
-            size = n_vocab * last_n_tokens
-            return (llama_cpp.c_float * size)()
-
-        def mock_sample(*args, **kwargs):
             nonlocal n
-            if n < len(output_tokens):
-                return output_tokens[n]
-            else:
-                return llama.token_eos()
+            nonlocal last_n_tokens
+            assert n > 0, "mock_llama_decode not called"
+            assert last_n_tokens > 0, "mock_llama_decode not called"
+            logits = (llama_cpp.c_float * (last_n_tokens * n_vocab))(-100.0)
+            for logits_idx, output_idx in enumerate(
+                range(n - last_n_tokens + 1, n + 1)
+            ):
+                if output_idx < len(output_tokens):
+                    logits[
+                        logits_idx * n_vocab + output_tokens[output_idx]
+                    ] = 100.0
+                else:
+                    logits[logits_idx * n_vocab + llama.token_eos()] = 100.0
+            return logits

         monkeypatch.setattr("llama_cpp.llama_cpp.llama_decode", mock_decode)
         monkeypatch.setattr("llama_cpp.llama_cpp.llama_get_logits", mock_get_logits)
-        monkeypatch.setattr("llama_cpp.llama_cpp.llama_sample_token", mock_sample)

     return setup_mock

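The index arithmetic in mock_get_logits assumes llama.cpp's logits layout: llama_get_logits exposes one row of n_vocab floats per evaluated token, so the score for token t in row i lives at index i * n_vocab + t. A minimal sketch of the greedy readout this layout supports (illustrative only; greedy_pick is not a llama_cpp API):

def greedy_pick(logits, row, n_vocab):
    # Read one vocab-sized row from the flat logits buffer and return
    # the argmax token id, i.e. what a greedy sampler would choose.
    row_scores = logits[row * n_vocab : (row + 1) * n_vocab]
    return max(range(n_vocab), key=lambda t: row_scores[t])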
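For illustration, a hypothetical test using this fixture might look like the sketch below. MODEL_PATH, the prompt, and the continuation are assumptions, not part of this commit. setup_mock is primed with the full expected text (prompt plus continuation): mock_decode tracks batch positions, and mock_get_logits then scores the next expected token at +100.0 (and EOS once the text is exhausted), so sampling reproduces the continuation deterministically regardless of which sampler the library uses.

def test_mocked_completion(mock_llama):
    # Hypothetical usage sketch; MODEL_PATH is an assumed constant.
    llama = llama_cpp.Llama(model_path=MODEL_PATH, vocab_only=True)
    prompt = "The quick brown fox"
    continuation = " jumps over the lazy dog."
    # Prime the mock with prompt + continuation: prompt tokens are
    # consumed first, then each decode step boosts the next token.
    mock_llama(llama, prompt + continuation)
    completion = llama.create_completion(prompt, max_tokens=32)
    assert completion["choices"][0]["text"] == continuation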