 import time
 import multiprocessing
 from typing import List, Optional
+from collections import deque
 
 from . import llama_cpp
 
@@ -46,9 +47,6 @@ def __init__(
         """
         self.model_path = model_path
 
-        self.last_n = 64
-        self.max_chunk_size = 32
-
         self.params = llama_cpp.llama_context_default_params()
         self.params.n_ctx = n_ctx
         self.params.n_parts = n_parts
@@ -59,9 +57,10 @@ def __init__(
         self.params.use_mlock = use_mlock
         self.params.embedding = embedding
 
-        self.n_threads = n_threads or multiprocessing.cpu_count()
+        self.last_n = 64
+        self.max_chunk_size = n_ctx
 
-        self.tokens = (llama_cpp.llama_token * self.params.n_ctx)()
+        self.n_threads = n_threads or multiprocessing.cpu_count()
 
         if not os.path.exists(model_path):
             raise ValueError(f"Model path does not exist: {model_path}")
@@ -70,6 +69,65 @@ def __init__(
             self.model_path.encode("utf-8"), self.params
         )
 
+    def tokenize(self, text: bytes) -> List[int]:
+        """Tokenize a string.
+
+        Args:
+            text: The utf-8 encoded string to tokenize.
+
+        Returns:
+            A list of tokens.
+        """
+        n_ctx = llama_cpp.llama_n_ctx(self.ctx)
+        tokens = (llama_cpp.llama_token * n_ctx)()
+        n_tokens = llama_cpp.llama_tokenize(
+            self.ctx,
+            text,
+            tokens,
+            n_ctx,
+            True,
+        )
+        if n_tokens < 0:
+            raise RuntimeError(f"Failed to tokenize: text=\"{text}\" n_tokens={n_tokens}")
+        return list(tokens[:n_tokens])
+
+    def detokenize(self, tokens: List[int]) -> bytes:
+        """Detokenize a list of tokens.
+
+        Args:
+            tokens: The list of tokens to detokenize.
+
+        Returns:
+            The detokenized string.
+        """
+        output = b""
+        for token in tokens:
+            output += llama_cpp.llama_token_to_str(self.ctx, token)
+        return output
+
+
+    def _eval(self, tokens: List[int], n_past):
+        rc = llama_cpp.llama_eval(
+            self.ctx,
+            (llama_cpp.llama_token * len(tokens))(*tokens),
+            len(tokens),
+            n_past,
+            self.n_threads,
+        )
+        if rc != 0:
+            raise RuntimeError(f"Failed to evaluate: {rc}")
+
+    def _sample(self, last_n_tokens, top_p, top_k, temp, repeat_penalty):
+        return llama_cpp.llama_sample_top_p_top_k(
+            self.ctx,
+            (llama_cpp.llama_token * len(last_n_tokens))(*last_n_tokens),
+            len(last_n_tokens),
+            top_k=top_k,
+            top_p=top_p,
+            temp=temp,
+            repeat_penalty=repeat_penalty,
+        )
+
     def __call__(
         self,
         prompt: str,
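Note: a minimal usage sketch of the tokenize/detokenize helpers added above (not part of the diff; the model path and the `Llama` export name are assumptions):

# Round-trip sketch for the helpers above; path and class export are assumed, not shown in this diff.
from llama_cpp import Llama

llm = Llama(model_path="./models/ggml-model.bin")   # hypothetical local GGML model file
tokens = llm.tokenize(b"Hello, world!")             # List[int]; llama_tokenize is called with add_bos=True
text = llm.detokenize(tokens)                       # bytes reassembled via llama_token_to_str
print(tokens, text)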
@@ -106,61 +164,38 @@ def __call__(
         """
         text = b""
         finish_reason = "length"
-        completion_tokens = 0
+        completion_tokens = []
+        last_n_tokens = deque([0] * self.last_n, maxlen=self.last_n)
 
-        if stop is not None:
-            stop = [s.encode("utf-8") for s in stop]
-
-        prompt_tokens = llama_cpp.llama_tokenize(
-            self.ctx,
-            prompt.encode("utf-8"),
-            self.tokens,
-            llama_cpp.llama_n_ctx(self.ctx),
-            True,
-        )
-        if prompt_tokens < 0:
-            raise RuntimeError(f"Failed to tokenize prompt: {prompt_tokens}")
+        prompt_tokens = self.tokenize(prompt.encode("utf-8"))
 
-        if prompt_tokens + max_tokens > self.params.n_ctx:
+        if len(prompt_tokens) + max_tokens > llama_cpp.llama_n_ctx(self.ctx):
             raise ValueError(
                 f"Requested tokens exceed context window of {llama_cpp.llama_n_ctx(self.ctx)}"
             )
 
         # Process prompt in chunks to avoid running out of memory
-        for i in range(0, prompt_tokens, self.max_chunk_size):
-            chunk = self.tokens[i : min(prompt_tokens, i + self.max_chunk_size)]
-            rc = llama_cpp.llama_eval(
-                self.ctx,
-                (llama_cpp.llama_token * len(chunk))(*chunk),
-                len(chunk),
-                max(0, i - 1),
-                self.n_threads,
-            )
-            if rc != 0:
-                raise RuntimeError(f"Failed to evaluate prompt: {rc}")
+        for i in range(0, len(prompt_tokens), self.max_chunk_size):
+            chunk = prompt_tokens[i : min(len(prompt_tokens), i + self.max_chunk_size)]
+            self._eval(chunk, n_past=i)
 
-        for i in range(max_tokens):
-            tokens_seen = prompt_tokens + completion_tokens
-            last_n_tokens = [0] * max(0, self.last_n - tokens_seen) + [
-                self.tokens[j]
-                for j in range(max(tokens_seen - self.last_n, 0), tokens_seen)
-            ]
+        if stop is not None:
+            stop = [s.encode("utf-8") for s in stop]
 
-            token = llama_cpp.llama_sample_top_p_top_k(
-                self.ctx,
-                (llama_cpp.llama_token * len(last_n_tokens))(*last_n_tokens),
-                len(last_n_tokens),
-                top_k=top_k,
+        for i in range(max_tokens):
+            token = self._sample(
+                last_n_tokens,
                 top_p=top_p,
+                top_k=top_k,
                 temp=temperature,
-                repeat_penalty=repeat_penalty,
+                repeat_penalty=repeat_penalty
             )
             if token == llama_cpp.llama_token_eos():
                 finish_reason = "stop"
                 break
-            text += llama_cpp.llama_token_to_str(self.ctx, token)
-            self.tokens[prompt_tokens + i] = token
-            completion_tokens += 1
+            text += self.detokenize([token])
+            last_n_tokens.append(token)
+            completion_tokens.append(token)
 
             any_stop = [s for s in stop if s in text]
             if len(any_stop) > 0:
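Note: the manual last_n_tokens window above is replaced by collections.deque with maxlen, which drops the oldest entry automatically as new tokens are appended. A standalone sketch of that sliding-window behaviour (generic values, independent of llama.cpp):

from collections import deque

last_n = 4
window = deque([0] * last_n, maxlen=last_n)   # seeded with zeros, as in __call__ above
for token in [11, 22, 33, 44, 55]:
    window.append(token)                      # oldest entry falls off once maxlen is reached
print(list(window))                           # [22, 33, 44, 55]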
@@ -169,15 +204,7 @@ def __call__(
                 finish_reason = "stop"
                 break
 
-            rc = llama_cpp.llama_eval(
-                self.ctx,
-                (llama_cpp.llama_token * 1)(self.tokens[prompt_tokens + i]),
-                1,
-                prompt_tokens + completion_tokens,
-                self.n_threads,
-            )
-            if rc != 0:
-                raise RuntimeError(f"Failed to evaluate next token: {rc}")
+            self._eval([token], len(prompt_tokens) + len(completion_tokens))
 
         text = text.decode("utf-8")
 
@@ -206,9 +233,9 @@ def __call__(
                 }
             ],
             "usage": {
-                "prompt_tokens": prompt_tokens,
-                "completion_tokens": completion_tokens,
-                "total_tokens": prompt_tokens + completion_tokens,
+                "prompt_tokens": len(prompt_tokens),
+                "completion_tokens": len(completion_tokens),
+                "total_tokens": len(prompt_tokens) + len(completion_tokens),
             },
         }
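Note: end to end, the refactored __call__ can be exercised roughly as follows (a sketch only; the model path, the `Llama` export name, and the exact shape of the returned "choices" entries are assumptions not shown in this diff):

from llama_cpp import Llama

llm = Llama(model_path="./models/ggml-model.bin")   # hypothetical local GGML model file
result = llm(
    "Q: Name the planets in the solar system. A:",
    max_tokens=32,
    stop=["Q:", "\n"],
    temperature=0.8,
)
print(result["usage"])    # token counts now come from len(prompt_tokens) / len(completion_tokens)
print(result["choices"])  # generated text and finish_reason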