@@ -121,14 +121,21 @@ def decorator(f: LlamaChatCompletionHandler):

 @dataclasses.dataclass
 class ChatFormatterResponse:
+    """Dataclass that stores completion parameters for a given chat format and
+    create_chat_completion request.
+
+    prompt contains the formatted prompt generated from the chat format and messages.
+    stop contains the stop token or list of stop tokens to use for the chat format."""
+
     prompt: str
     stop: Optional[Union[str, List[str]]] = None


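# Illustrative sketch (not part of this diff): how a handler consumes a
# ChatFormatterResponse. Assumes a llama_cpp.Llama instance `llama`; the prompt
# string is made up:
#
#     resp = ChatFormatterResponse(prompt="<s>[INST] Hi [/INST]", stop="</s>")
#     output = llama.create_completion(prompt=resp.prompt, stop=resp.stop)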
 class ChatFormatter(Protocol):
     """Base Protocol for a chat formatter. A chat formatter is a function that
-    takes a list of messages and returns a formatted prompt. It can also return
-    a stop token or list of stop tokens to use for the completion."""
+    takes a list of messages and returns a chat format response which can be used
+    to generate a completion. The response can also include a stop token or list
+    of stop tokens to use for the completion."""

     def __call__(
         self,
@@ -139,131 +146,43 @@ def __call__(
         ...
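# Illustrative sketch: because ChatFormatter is a Protocol, any callable with a
# matching signature satisfies it; no subclassing is required. A trivial
# formatter might look like:
#
#     def format_plain(
#         *, messages: List[llama_types.ChatCompletionRequestMessage], **kwargs: Any
#     ) -> ChatFormatterResponse:
#         prompt = "".join(str(m["content"] or "") for m in messages)
#         return ChatFormatterResponse(prompt=prompt)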


-### Utility functions for formatting chat prompts ###
-
-
-def _get_system_message(
-    messages: List[llama_types.ChatCompletionRequestMessage],
-) -> str:
-    """Get the first system message."""
-    for message in messages:
-        if message["role"] == "system":
-            return message["content"] or ""
-    return ""
-
-
-def _map_roles(
-    messages: List[llama_types.ChatCompletionRequestMessage],
-    role_map: Dict[str, str],
-) -> List[Tuple[str, Optional[str]]]:
-    """Map the message roles."""
-    output: List[Tuple[str, Optional[str]]] = []
-    for message in messages:
-        role = message["role"]
-        if role in role_map:
-            content: str | None = (
-                message["content"] if isinstance(message["content"], str) else None
-            )
-            output.append((role_map[role], content))
-    return output
-
-
-def _format_llama2(
-    system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str, sep2: str
-) -> str:
-    """Format the prompt with the llama2 style."""
-    seps = [sep, sep2]
-    ret = system_message + sep
-    for i, (role, message) in enumerate(messages):
-        if system_message and i == 0:
-            m = message or ""
-            ret += m + seps[i % 2]
-        elif message:
-            ret += role + message + " " + seps[i % 2]
-        else:
-            ret += role + " "
-    return ret
-
-
-def _format_add_colon_single(
-    system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str
-) -> str:
-    """Format the prompt with the add-colon-single style."""
-    ret = system_message + sep
-    for role, message in messages:
-        if message:
-            ret += role + ": " + message + sep
-        else:
-            ret += role + ":"
-    return ret
-
-
-def _format_add_colon_two(
-    system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str, sep2: str
-) -> str:
-    """Format the prompt with the add-colon-two style."""
-    seps = [sep, sep2]
-    ret = system_message + seps[0]
-    for i, (role, message) in enumerate(messages):
-        if message:
-            ret += role + ": " + message + seps[i % 2]
-        else:
-            ret += role + ":"
-    return ret
-
-
-def _format_no_colon_single(
-    system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str
-) -> str:
-    """Format the prompt with the no-colon-single style."""
-    ret = system_message
-    for role, message in messages:
-        if message:
-            ret += role + message + sep
-        else:
-            ret += role
-    return ret
-
-
-def _format_add_colon_space_single(
-    system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str
-) -> str:
-    """Format the prompt with the add-colon-space-single style."""
-    ret = system_message + sep
-    for role, message in messages:
-        if message:
-            ret += role + ": " + message + sep
-        else:
-            ret += role + ": "  # must end with a space
-    return ret
-
-
-def _format_chatml(
-    system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str
-) -> str:
-    """Format the prompt with the chatml style."""
-    ret = "" if system_message == "" else system_message + sep + "\n"
-    for role, message in messages:
-        if message:
-            ret += role + "\n" + message + sep + "\n"
-        else:
-            ret += role + "\n"
-    return ret
-
-
-def _format_chatglm3(
-    system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str
-) -> str:
-    """Format the prompt with the chatglm3 style."""
-    ret = ""
-    if system_message:
-        ret += system_message
-    for role, message in messages:
-        if message:
-            ret += role + "\n" + " " + message
-        else:
-            ret += role
-    return ret
+class Jinja2ChatFormatter(ChatFormatter):
+    def __init__(
+        self,
+        template: str,
+        eos_token: str,
+        bos_token: str,
+    ):
+        """A chat formatter that uses jinja2 templates to format the prompt."""
+        self.template = template
+        self.eos_token = eos_token
+        self.bos_token = bos_token
+
+        self._environment = jinja2.Environment(
+            loader=jinja2.BaseLoader(),
+            trim_blocks=True,
+            lstrip_blocks=True,
+        ).from_string(self.template)
+
+    def __call__(
+        self,
+        *,
+        messages: List[llama_types.ChatCompletionRequestMessage],
+        **kwargs: Any,
+    ) -> ChatFormatterResponse:
+        messages = [
+            *messages,
+            llama_types.ChatCompletionRequestAssistantMessage(
+                role="assistant", content=""
+            ),
+        ]
+        prompt = self._environment.render(
+            messages=messages, eos_token=self.eos_token, bos_token=self.bos_token
+        )
+        return ChatFormatterResponse(prompt=prompt, stop=[self.eos_token])
+
+    def to_chat_handler(self) -> LlamaChatCompletionHandler:
+        return chat_formatter_to_chat_completion_handler(self)
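# Usage sketch (the template below is a hand-written ChatML-style example and
# the special tokens are assumptions for the model in use):
#
#     formatter = Jinja2ChatFormatter(
#         template=(
#             "{% for m in messages %}"
#             "<|im_start|>{{ m.role }}\n{{ m.content }}<|im_end|>\n"
#             "{% endfor %}"
#         ),
#         eos_token="<|im_end|>",
#         bos_token="<s>",
#     )
#     response = formatter(messages=[{"role": "user", "content": "Hi"}])
#     handler = formatter.to_chat_handler()  # pass as chat_handler to Llama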


 def _convert_text_completion_to_chat(
@@ -426,16 +345,6 @@ def chat_completion_handler(
     return chat_completion_handler


-def register_chat_format(name: str):
-    def decorator(f: ChatFormatter):
-        chat_completion_handler = chat_formatter_to_chat_completion_handler(f)
-        LlamaChatCompletionHandlerRegistry().register_chat_completion_handler(
-            name, chat_completion_handler
-        )
-        return f
-    return decorator
-
-
 def hf_autotokenizer_to_chat_formatter(
     pretrained_model_name_or_path: Union[str, os.PathLike[str]]
 ) -> ChatFormatter:
@@ -466,7 +375,9 @@ def hf_autotokenizer_to_chat_completion_handler(
     return chat_formatter_to_chat_completion_handler(chat_formatter)


-def hf_tokenizer_config_to_chat_formatter(tokenizer_config: Dict[str, Any]) -> ChatFormatter:
+def hf_tokenizer_config_to_chat_formatter(
+    tokenizer_config: Dict[str, Any]
+) -> ChatFormatter:
     assert isinstance(tokenizer_config, dict)

     assert "chat_template" in tokenizer_config
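# The argument is the parsed tokenizer_config.json of a Hugging Face model; a
# minimal illustrative shape (the template itself is made up):
#
#     config = {
#         "chat_template": "{% for m in messages %}{{ m.content }}{% endfor %}",
#         "bos_token": "<s>",
#         "eos_token": "</s>",
#     }
#     formatter = hf_tokenizer_config_to_chat_formatter(config)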
@@ -504,6 +415,7 @@ def format_autotokenizer(
             eos_token=eos_token,
         )
         return ChatFormatterResponse(prompt=prompt, stop=eos_token)
+
     return format_autotokenizer


@@ -514,6 +426,147 @@ def hf_tokenizer_config_to_chat_completion_handler(
     return chat_formatter_to_chat_completion_handler(chat_formatter)


+### Utility functions for formatting chat prompts ###
+
+
+def _get_system_message(
+    messages: List[llama_types.ChatCompletionRequestMessage],
+) -> str:
+    """Get the first system message."""
+    for message in messages:
+        if message["role"] == "system":
+            return message["content"] or ""
+    return ""
+
+
+def _map_roles(
+    messages: List[llama_types.ChatCompletionRequestMessage],
+    role_map: Dict[str, str],
+) -> List[Tuple[str, Optional[str]]]:
+    """Map the message roles."""
+    output: List[Tuple[str, Optional[str]]] = []
+    for message in messages:
+        role = message["role"]
+        if role in role_map:
+            content: str | None = (
+                message["content"] if isinstance(message["content"], str) else None
+            )
+            output.append((role_map[role], content))
+    return output
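# Worked example: roles missing from role_map are dropped, and non-string
# content is mapped to None:
#
#     _map_roles(
#         [{"role": "system", "content": "S"}, {"role": "user", "content": "Hi"}],
#         {"user": "USER", "assistant": "ASSISTANT"},
#     )
#     # -> [("USER", "Hi")]  (the system message has no mapping here)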
+
+
+def _format_llama2(
+    system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str, sep2: str
+) -> str:
+    """Format the prompt with the llama2 style."""
+    seps = [sep, sep2]
+    ret = system_message + sep
+    for i, (role, message) in enumerate(messages):
+        if system_message and i == 0:
+            m = message or ""
+            ret += m + seps[i % 2]
+        elif message:
+            ret += role + message + " " + seps[i % 2]
+        else:
+            ret += role + " "
+    return ret
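# Call sketch, mirroring how the "llama-2" format registered below uses this
# helper: when a system message is present, the first user turn is appended
# after it without re-emitting the role marker:
#
#     _format_llama2(
#         "<s>[INST] <<SYS>>\nBe brief.\n<</SYS>>\n\n",
#         [("<s>[INST]", "Hi"), ("[/INST]", None)],
#         " ",
#         "</s>",
#     )
#     # -> "<s>[INST] <<SYS>>\nBe brief.\n<</SYS>>\n\n Hi [/INST] "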
+
+
+def _format_add_colon_single(
+    system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str
+) -> str:
+    """Format the prompt with the add-colon-single style."""
+    ret = system_message + sep
+    for role, message in messages:
+        if message:
+            ret += role + ": " + message + sep
+        else:
+            ret += role + ":"
+    return ret
+
+
+def _format_add_colon_two(
+    system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str, sep2: str
+) -> str:
+    """Format the prompt with the add-colon-two style."""
+    seps = [sep, sep2]
+    ret = system_message + seps[0]
+    for i, (role, message) in enumerate(messages):
+        if message:
+            ret += role + ": " + message + seps[i % 2]
+        else:
+            ret += role + ":"
+    return ret
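# Worked examples for the two colon styles (role names illustrative):
#
#     _format_add_colon_single("SYS", [("USER", "Hi"), ("ASSISTANT", None)], "\n")
#     # -> "SYS\nUSER: Hi\nASSISTANT:"
#
#     _format_add_colon_two("SYS", [("USER", "Hi"), ("ASSISTANT", None)], " ", "</s>")
#     # -> "SYS USER: Hi ASSISTANT:"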
+
+
+def _format_no_colon_single(
+    system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str
+) -> str:
+    """Format the prompt with the no-colon-single style."""
+    ret = system_message
+    for role, message in messages:
+        if message:
+            ret += role + message + sep
+        else:
+            ret += role
+    return ret
+
+
+def _format_add_colon_space_single(
+    system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str
+) -> str:
+    """Format the prompt with the add-colon-space-single style."""
+    ret = system_message + sep
+    for role, message in messages:
+        if message:
+            ret += role + ": " + message + sep
+        else:
+            ret += role + ": "  # must end with a space
+    return ret
+
+
+def _format_chatml(
+    system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str
+) -> str:
+    """Format the prompt with the chatml style."""
+    ret = "" if system_message == "" else system_message + sep + "\n"
+    for role, message in messages:
+        if message:
+            ret += role + "\n" + message + sep + "\n"
+        else:
+            ret += role + "\n"
+    return ret
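# Worked example with ChatML-style markers (illustrative; a trailing assistant
# role with no content leaves the prompt open for generation):
#
#     _format_chatml(
#         "<|im_start|>system\nBe brief.",
#         [("<|im_start|>user", "Hi"), ("<|im_start|>assistant", None)],
#         "<|im_end|>",
#     )
#     # -> "<|im_start|>system\nBe brief.<|im_end|>\n"
#     #    "<|im_start|>user\nHi<|im_end|>\n"
#     #    "<|im_start|>assistant\n"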
+
+
+def _format_chatglm3(
+    system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str
+) -> str:
+    """Format the prompt with the chatglm3 style."""
+    ret = ""
+    if system_message:
+        ret += system_message
+    for role, message in messages:
+        if message:
+            ret += role + "\n" + " " + message
+        else:
+            ret += role
+    return ret
+
+
+### Chat Formats ###
+
+
+def register_chat_format(name: str):
+    def decorator(f: ChatFormatter):
+        chat_completion_handler = chat_formatter_to_chat_completion_handler(f)
+        LlamaChatCompletionHandlerRegistry().register_chat_completion_handler(
+            name, chat_completion_handler
+        )
+        return f
+
+    return decorator
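# Usage sketch: registering a hypothetical "plain" format by name, built from
# the helpers above; once registered it is selectable via chat_format="plain":
#
#     @register_chat_format("plain")
#     def format_plain(
#         messages: List[llama_types.ChatCompletionRequestMessage],
#         **kwargs: Any,
#     ) -> ChatFormatterResponse:
#         prompt = _format_add_colon_single(
#             _get_system_message(messages),
#             _map_roles(messages, {"user": "USER", "assistant": "ASSISTANT"}),
#             "\n",
#         )
#         return ChatFormatterResponse(prompt=prompt)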
+
+
 # see https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/tokenization_llama.py
 # system prompt is "embedded" in the first message
 @register_chat_format("llama-2")