File tree Expand file tree Collapse file tree 1 file changed +9
-6
lines changed Expand file tree Collapse file tree 1 file changed +9
-6
lines changed Original file line number Diff line number Diff line change @@ -152,11 +152,13 @@ def __init__(
152
152
template : str ,
153
153
eos_token : str ,
154
154
bos_token : str ,
155
+ add_generation_prompt : bool = True ,
155
156
):
156
157
"""A chat formatter that uses jinja2 templates to format the prompt."""
157
158
self .template = template
158
159
self .eos_token = eos_token
159
160
self .bos_token = bos_token
161
+ self .add_generation_prompt = add_generation_prompt
160
162
161
163
self ._environment = jinja2 .Environment (
162
164
loader = jinja2 .BaseLoader (),
@@ -170,12 +172,13 @@ def __call__(
170
172
messages : List [llama_types .ChatCompletionRequestMessage ],
171
173
** kwargs : Any ,
172
174
) -> ChatFormatterResponse :
173
- messages = [
174
- * messages ,
175
- llama_types .ChatCompletionRequestAssistantMessage (
176
- role = "assistant" , content = ""
177
- ),
178
- ]
175
+ if self .add_generation_prompt :
176
+ messages = [
177
+ * messages ,
178
+ llama_types .ChatCompletionRequestAssistantMessage (
179
+ role = "assistant" , content = ""
180
+ ),
181
+ ]
179
182
prompt = self ._environment .render (
180
183
messages = messages , eos_token = self .eos_token , bos_token = self .bos_token
181
184
)
You can’t perform that action at this time.
0 commit comments