@@ -33,6 +33,7 @@ def __init__(self,
33
33
top_p : float = 1. ,
34
34
temp : float = 1.0 ,
35
35
repeat_penalty : float = 1 ,
36
+ init_break : bool = True ,
36
37
instruct_inp_prefix : str = "\n \n ### Instruction:\n \n " ,
37
38
instruct_inp_suffix : str = "\n \n ### Response:\n \n " ,
38
39
) -> None :
@@ -48,6 +49,7 @@ def __init__(self,
48
49
self .top_p = top_p
49
50
self .temp = temp
50
51
self .repeat_penalty = repeat_penalty
52
+ self .init_break = init_break
51
53
52
54
# runtime args
53
55
self .input_consumed = 0
@@ -81,9 +83,6 @@ def __init__(self,
81
83
if (len (primer ) > 0 ):
82
84
self .embd_inp += self ._tokenize (primer )
83
85
84
- # break immediately if using instruct
85
- self .init_break = self .instruct
86
-
87
86
# number of tokens to keep when resetting context
88
87
if (self .n_keep < 0 or self .n_keep > len (self .embd_inp ) or self .instruct ):
89
88
self .n_keep = len (self .embd_inp )
@@ -182,13 +181,14 @@ def generate(self):
182
181
if (len (self .embd_inp ) <= self .input_consumed ):
183
182
# if antiprompt is present, stop
184
183
if (self .use_antiprompt ()):
185
- for i in self .first_antiprompt :
186
- if i == self .last_n_tokens [- len (i ):]:
187
- return
184
+ if True in [
185
+ i == self .last_n_tokens [- len (i ):]
186
+ for i in self .first_antiprompt
187
+ ]:
188
+ break
188
189
189
190
# if we are using instruction mode, and we have processed the initial prompt
190
191
if (self .init_break ):
191
- self .init_break = False
192
192
break
193
193
194
194
# if end of generation
@@ -201,6 +201,8 @@ def generate(self):
201
201
self .embd_inp += self .first_antiprompt [0 ]
202
202
break
203
203
204
+ self .init_break = False
205
+
204
206
def __enter__ (self ):
205
207
return self
206
208
0 commit comments