"""
This is an example implementation of main.cpp from llama.cpp
Quirks:
 * It's not exactly alike since this port is designed around programmatic I/O
 * Input is always echoed if on, so it should be turned off when using "input()"
 * The first antiprompt should be the user prompt, like "\nUser:",
   because it's added when n_predict is reached (aka generation ended prematurely)
 * n_predict can be set to -1 for unlimited length responses
"""
import llama_cpp

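# convert a ctypes llama_token array into a plain Python list of ints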
def toIntArray(lst):
    return [int(i) for i in lst]

# A LLaMA interactive session
class LLaMAInteract:
    def __init__(self,
        primer: str = "",
        model: str = "./models/30B/ggml-model-q4_0.bin",
        n_ctx: int = 1024,
        seed: int = 0,
        n_threads: int = 8,
        antiprompt: list[str] = [],
        input_echo: bool = True,
        n_predict: int = 20,
        n_batch: int = 8,
        repeat_last_n: int = 64,
        top_k: int = 50,
        top_p: float = 1.0,
        temp: float = 1.0,
        repeat_penalty: float = 1.0,
    ) -> None:
        # input args
        self.n_threads = n_threads
        self.input_echo = input_echo
        self.n_predict = n_predict
        self.n_batch = n_batch
        self.repeat_last_n = repeat_last_n
        self.top_k = top_k
        self.top_p = top_p
        self.temp = temp
        self.repeat_penalty = repeat_penalty
        self.n_ctx = n_ctx
        self.seed = seed

        # runtime args
        self.input_consumed = 0
        self.embd = []
        self.embd_inp = []
        self.n_past = 0
        self.first_antiprompt = []
        self.remaining_tokens = self.n_predict
        self.output_echo = input_echo

        # model load
        self.lparams = llama_cpp.llama_context_default_params()
        self.lparams.n_ctx = self.n_ctx
        self.lparams.seed = self.seed
        self.ctx = llama_cpp.llama_init_from_file(model.encode("utf8"), self.lparams)

        # determine the required inference memory per token:
        tmp = [0, 1, 2, 3]
        llama_cpp.llama_eval(self.ctx, (llama_cpp.c_int * len(tmp))(*tmp), len(tmp), 0, self.n_threads)

        # determine newline token
        self.llama_token_newline = (llama_cpp.llama_token * 1)()
        llama_cpp.llama_tokenize(self.ctx, b"\n", self.llama_token_newline, len(self.llama_token_newline), False)
        self.llama_token_newline = toIntArray(self.llama_token_newline)

        # primer feed
        if len(primer) > 0:
            self.input(primer)
        self.n_keep = len(self.embd_inp)

        # create internal context
        self.n_ctx = int(llama_cpp.llama_n_ctx(self.ctx))
        self.last_n_tokens = [0] * self.n_ctx  # TODO: deque doesn't support slices

        # determine antiprompt tokens
        for i in antiprompt:
            d_antiprompt = (llama_cpp.llama_token * (len(i) + 1))()
            n_antiprompt = llama_cpp.llama_tokenize(self.ctx, i.encode("utf8"), d_antiprompt, len(d_antiprompt), False)
            self.first_antiprompt.append(toIntArray(d_antiprompt[:n_antiprompt]))

    # if an antiprompt is present
    def use_antiprompt(self):
        return len(self.first_antiprompt) > 0

    def generate(self):
        while self.remaining_tokens > 0 or self.use_antiprompt():
            # predict
            if len(self.embd) > 0:
                # infinite text generation via context swapping
                # if we run out of context:
                # - take the n_keep first tokens from the original prompt (via n_past)
                # - take half of the last (n_ctx - n_keep) tokens and recompute the logits in a batch
                if self.n_past + len(self.embd) > self.n_ctx:
                    n_left = self.n_past - self.n_keep
                    self.n_past = self.n_keep

                    # insert n_left/2 tokens at the start of embd from last_n_tokens
                    _insert = self.last_n_tokens[
                        -(int(n_left / 2) - len(self.embd)):-len(self.embd)
                    ]
                    self.embd[:len(_insert)] = _insert
                    # TODO: still untested

                if llama_cpp.llama_eval(
                    self.ctx, (llama_cpp.llama_token * len(self.embd))(*self.embd), len(self.embd), self.n_past, self.n_threads
                ) != 0:
                    raise Exception("Failed to llama_eval!")

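            # the evaluated tokens are now part of the context; advance n_past and clear the batch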
            self.n_past += len(self.embd)
            self.embd = []
            if len(self.embd_inp) <= self.input_consumed:
                # out of user input, sample next token
                _arr = self.last_n_tokens[-min(self.repeat_last_n, self.n_past):]
                id = llama_cpp.llama_sample_top_p_top_k(
                    self.ctx,
                    (llama_cpp.llama_token * len(_arr))(*_arr),
                    len(_arr),
                    self.top_k,
                    self.top_p,
                    self.temp,
                    self.repeat_penalty,
                )
                self.last_n_tokens.pop(0)
                self.last_n_tokens.append(int(id))

                # replace end of text token with newline token when in interactive mode
                if id == llama_cpp.llama_token_eos() and self.use_antiprompt():
                    id = self.llama_token_newline[0]
                    # tokenize and inject first reverse prompt
                    self.embd_inp += self.first_antiprompt[0]

                # add it to the context
                self.embd.append(int(id))

                # echo this to console
                self.output_echo = True

                # decrement remaining sampling budget
                self.remaining_tokens -= 1
            else:
                # output to console if input echo is on
                self.output_echo = self.input_echo

                # some user input remains from prompt or interaction, forward it to processing
                while len(self.embd_inp) > self.input_consumed:
                    self.embd.append(int(self.embd_inp[self.input_consumed]))
                    self.last_n_tokens.pop(0)
                    self.last_n_tokens.append(int(self.embd_inp[self.input_consumed]))
                    self.input_consumed += 1
                    if len(self.embd) >= self.n_batch:
                        break

            # display tokens
            if self.output_echo:
                for id in self.embd:
                    yield id

            # if antiprompt is present, stop
            if self.use_antiprompt() and len(self.embd_inp) <= self.input_consumed:
                for i in self.first_antiprompt:
                    if i == self.last_n_tokens[-len(i):]:
                        return

            # if end of generation
            if len(self.embd) > 0 and self.embd[-1] == llama_cpp.llama_token_eos():
                break

            # respect n_predict even if antiprompt is present
            if self.use_antiprompt() and self.remaining_tokens <= 0 and self.n_predict != -1:
                self.embd_inp += self.first_antiprompt[0]
                break

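    # decode and yield the tokens currently held in the context window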
    def past(self):
        for id in self.last_n_tokens[-self.n_past:]:
            yield llama_cpp.llama_token_to_str(self.ctx, id).decode("utf-8")

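    # tokenize a prompt and append it to the pending model input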
    def input(self, prompt: str):
        embd_arr = (llama_cpp.llama_token * (len(prompt) + 1))()
        n_of_tok = llama_cpp.llama_tokenize(self.ctx, prompt.encode("utf8"), embd_arr, len(embd_arr), True)
        self.embd_inp += toIntArray(embd_arr[:n_of_tok])

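    # reset the sampling budget and stream generated tokens as decoded text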
    def output(self):
        self.remaining_tokens = self.n_predict
        for id in self.generate():
            yield llama_cpp.llama_token_to_str(self.ctx, id).decode("utf-8")

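# example usage: an interactive chat session built on the class above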
if __name__ == "__main__":
    from datetime import datetime

    USER_NAME = "User"
    AI_NAME = "ChatLLaMa"

    time_now = datetime.now()
    prompt = f"""Text transcript of a never ending dialog, where {USER_NAME} interacts with an AI assistant named {AI_NAME}.
{AI_NAME} is helpful, kind, honest, friendly, good at writing and never fails to answer {USER_NAME}'s requests immediately and with details and precision.
There are no annotations like (30 seconds passed...) or (to himself), just what {USER_NAME} and {AI_NAME} say aloud to each other.
The dialog lasts for years, the entirety of it is shared below. It's 10000 pages long.
The transcript only includes text, it does not include markup like HTML and Markdown.

{USER_NAME}: Hello, {AI_NAME}!
{AI_NAME}: Hello {USER_NAME}! How may I help you today?
{USER_NAME}: What time is it?
{AI_NAME}: It is {time_now.strftime("%H:%M")}.
{USER_NAME}: What year is it?
{AI_NAME}: We are in {time_now.strftime("%Y")}.
{USER_NAME}: What is a cat?
{AI_NAME}: A cat is a domestic species of small carnivorous mammal. It is the only domesticated species in the family Felidae.
{USER_NAME}: Name a color.
{AI_NAME}: Blue
{USER_NAME}:"""

    print("Loading model...")
    ll = LLaMAInteract(prompt,
        model="./models/30B/ggml-model-q4_0.bin",
        n_ctx=2048,
        antiprompt=[f"\n{USER_NAME}:"],
        repeat_last_n=256,
        n_predict=2048,
        temp=0.7, top_p=0.5, top_k=40, repeat_penalty=1.17647,
    )
    print("Loaded model!")

    for i in ll.output():
        print(i, end="", flush=True)
    ll.input_echo = False

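    # interactive loop: read user input, feed it to the model, and stream the reply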
    inp = lambda x: f" {x}\n"
    while True:
        ll.input(inp(input(' ')))
        for i in ll.output():
            print(i, end="", flush=True)