@@ -19,6 +19,9 @@ def __init__(
     ):
         self.model_path = model_path
 
+        self.last_n = 64
+        self.max_chunk_size = 32
+
         self.params = llama_cpp.llama_context_default_params()
         self.params.n_ctx = n_ctx
         self.params.n_parts = n_parts
@@ -59,21 +62,32 @@ def __call__(
             self.ctx, prompt.encode("utf-8"), self.tokens, self.params.n_ctx, True
         )
 
-        if prompt_tokens + max_tokens > self.params.n_ctx:
+        if prompt_tokens + max_tokens > llama_cpp.llama_n_ctx(self.ctx):
             raise ValueError(
                 f"Requested tokens exceed context window of {self.params.n_ctx}"
             )
 
-        for i in range(prompt_tokens):
-            llama_cpp.llama_eval(
-                self.ctx, (llama_cpp.c_int * 1)(self.tokens[i]), 1, i, self.n_threads
+        # Process prompt in chunks to avoid running out of memory
+        for i in range(0, prompt_tokens, self.max_chunk_size):
+            chunk = self.tokens[i : min(prompt_tokens, i + self.max_chunk_size)]
+            rc = llama_cpp.llama_eval(
+                self.ctx,
+                (llama_cpp.llama_token * len(chunk))(*chunk),
+                len(chunk),
+                max(0, i - 1),
+                self.n_threads,
             )
+            if rc != 0:
+                raise RuntimeError(f"Failed to evaluate prompt: {rc}")
 
         for i in range(max_tokens):
+            tokens_seen = prompt_tokens + completion_tokens
+            last_n_tokens = [0] * max(0, self.last_n - tokens_seen) + [self.tokens[j] for j in range(max(tokens_seen - self.last_n, 0), tokens_seen)]
+
             token = llama_cpp.llama_sample_top_p_top_k(
                 self.ctx,
-                self.tokens,
-                prompt_tokens + completion_tokens,
+                (llama_cpp.llama_token * len(last_n_tokens))(*last_n_tokens),
+                len(last_n_tokens),
                 top_k=top_k,
                 top_p=top_p,
                 temp=temperature,
@@ -82,7 +96,6 @@ def __call__(
             if token == llama_cpp.llama_token_eos():
                 finish_reason = "stop"
                 break
-            # text += llama_cpp.llama_token_to_str(self.ctx, token).decode("utf-8")
             text += llama_cpp.llama_token_to_str(self.ctx, token)
             self.tokens[prompt_tokens + i] = token
             completion_tokens += 1
@@ -96,7 +109,7 @@ def __call__(
 
             llama_cpp.llama_eval(
                 self.ctx,
-                (llama_cpp.c_int * 1)(self.tokens[prompt_tokens + i]),
+                (llama_cpp.llama_token * 1)(self.tokens[prompt_tokens + i]),
                 1,
                 prompt_tokens + completion_tokens,
                 self.n_threads,
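
Illustrative aside (not part of the commit): the last_n_tokens expression added in the sampling loop always yields exactly self.last_n entries. It left-pads with 0 while fewer than last_n tokens have been generated so far, and otherwise keeps only the most recent last_n tokens, so llama_sample_top_p_top_k receives a fixed-size window of recent tokens (presumably the repeat-penalty window of the old llama.cpp sampling API). A minimal standalone sketch of that windowing logic, with a hypothetical helper name:

# Standalone sketch, not part of llama_cpp: reproduces the fixed-length
# last_n window that the diff builds inline before sampling.
def build_last_n_window(tokens, tokens_seen, last_n=64):
    # Left-pad with 0 until tokens_seen reaches last_n ...
    padding = [0] * max(0, last_n - tokens_seen)
    # ... then keep only the most recent last_n tokens.
    recent = [tokens[j] for j in range(max(tokens_seen - last_n, 0), tokens_seen)]
    return padding + recent

# The window length is always last_n, however many tokens have been seen.
assert build_last_n_window([7, 8, 9], tokens_seen=3, last_n=5) == [0, 0, 7, 8, 9]
assert build_last_n_window(list(range(10)), tokens_seen=10, last_n=4) == [6, 7, 8, 9]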