feat: Add Jinja2ChatFormatter · devilcoder01/llama-cpp-python@be09318 · GitHub

Commit be09318

feat: Add Jinja2ChatFormatter
1 parent 5a34c57 commit be09318

File tree

1 file changed: +188 −135 lines changed

llama_cpp/llama_chat_format.py

Lines changed: 188 additions & 135 deletions
@@ -121,14 +121,21 @@ def decorator(f: LlamaChatCompletionHandler):
 
 @dataclasses.dataclass
 class ChatFormatterResponse:
+    """Dataclass that stores completion parameters for a given chat format and
+    create_chat_completion request.
+
+    prompt contains the formatted prompt generated from the chat format and messages.
+    stop contains the stop token or list of stop tokens to use for the chat format."""
+
     prompt: str
     stop: Optional[Union[str, List[str]]] = None
 
 
 class ChatFormatter(Protocol):
     """Base Protocol for a chat formatter. A chat formatter is a function that
-    takes a list of messages and returns a formatted prompt. It can also return
-    a stop token or list of stop tokens to use for the completion."""
+    takes a list of messages and returns a chat format response which can be used
+    to generate a completion. The response can also include a stop token or list
+    of stop tokens to use for the completion."""
 
     def __call__(
         self,
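
For context, after this change a chat formatter is any callable with this __call__ signature that returns a ChatFormatterResponse. A minimal sketch of an implementation (the SimpleChatFormatter name and its "role: content" layout are illustrative, not part of this commit; it assumes the module's existing llama_types import):

    class SimpleChatFormatter:
        """Illustrative formatter: renders each message as a 'role: content' line."""

        def __call__(
            self,
            *,
            messages: List[llama_types.ChatCompletionRequestMessage],
            **kwargs: Any,
        ) -> ChatFormatterResponse:
            lines = [f"{m['role']}: {m['content']}" for m in messages]
            # Leave the assistant turn open so the model completes it.
            prompt = "\n".join(lines) + "\nassistant:"
            return ChatFormatterResponse(prompt=prompt, stop=None)
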
@@ -139,131 +146,43 @@ def __call__(
         ...
 
 
-### Utility functions for formatting chat prompts ###
-
-
-def _get_system_message(
-    messages: List[llama_types.ChatCompletionRequestMessage],
-) -> str:
-    """Get the first system message."""
-    for message in messages:
-        if message["role"] == "system":
-            return message["content"] or ""
-    return ""
-
-
-def _map_roles(
-    messages: List[llama_types.ChatCompletionRequestMessage],
-    role_map: Dict[str, str],
-) -> List[Tuple[str, Optional[str]]]:
-    """Map the message roles."""
-    output: List[Tuple[str, Optional[str]]] = []
-    for message in messages:
-        role = message["role"]
-        if role in role_map:
-            content: str | None = (
-                message["content"] if isinstance(message["content"], str) else None
-            )
-            output.append((role_map[role], content))
-    return output
-
-
-def _format_llama2(
-    system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str, sep2: str
-) -> str:
-    """Format the prompt with the llama2 style."""
-    seps = [sep, sep2]
-    ret = system_message + sep
-    for i, (role, message) in enumerate(messages):
-        if system_message and i == 0:
-            m = message or ""
-            ret += m + seps[i % 2]
-        elif message:
-            ret += role + message + " " + seps[i % 2]
-        else:
-            ret += role + " "
-    return ret
-
-
-def _format_add_colon_single(
-    system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str
-) -> str:
-    """Format the prompt with the add-colon-single style."""
-    ret = system_message + sep
-    for role, message in messages:
-        if message:
-            ret += role + ": " + message + sep
-        else:
-            ret += role + ":"
-    return ret
-
-
-def _format_add_colon_two(
-    system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str, sep2: str
-) -> str:
-    """Format the prompt with the add-colon-two style."""
-    seps = [sep, sep2]
-    ret = system_message + seps[0]
-    for i, (role, message) in enumerate(messages):
-        if message:
-            ret += role + ": " + message + seps[i % 2]
-        else:
-            ret += role + ":"
-    return ret
-
-
-def _format_no_colon_single(
-    system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str
-) -> str:
-    """Format the prompt with the no-colon-single style."""
-    ret = system_message
-    for role, message in messages:
-        if message:
-            ret += role + message + sep
-        else:
-            ret += role
-    return ret
-
-
-def _format_add_colon_space_single(
-    system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str
-) -> str:
-    """Format the prompt with the add-colon-space-single style."""
-    ret = system_message + sep
-    for role, message in messages:
-        if message:
-            ret += role + ": " + message + sep
-        else:
-            ret += role + ": "  # must be end with a space
-    return ret
-
+class Jinja2ChatFormatter(ChatFormatter):
+    def __init__(
+        self,
+        template: str,
+        eos_token: str,
+        bos_token: str,
+    ):
+        """A chat formatter that uses jinja2 templates to format the prompt."""
+        self.template = template
+        self.eos_token = eos_token
+        self.bos_token = bos_token
 
-def _format_chatml(
-    system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str
-) -> str:
-    """Format the prompt with the chatml style."""
-    ret = "" if system_message == "" else system_message + sep + "\n"
-    for role, message in messages:
-        if message:
-            ret += role + "\n" + message + sep + "\n"
-        else:
-            ret += role + "\n"
-    return ret
+        self._environment = jinja2.Environment(
+            loader=jinja2.BaseLoader(),
+            trim_blocks=True,
+            lstrip_blocks=True,
+        ).from_string(self.template)
 
+    def __call__(
+        self,
+        *,
+        messages: List[llama_types.ChatCompletionRequestMessage],
+        **kwargs: Any,
+    ) -> ChatFormatterResponse:
+        messages = [
+            *messages,
+            llama_types.ChatCompletionRequestAssistantMessage(
+                role="assistant", content=""
+            ),
+        ]
+        prompt = self._environment.render(
+            messages=messages, eos_token=self.eos_token, bos_token=self.bos_token
+        )
+        return ChatFormatterResponse(prompt=prompt, stop=[self.eos_token])
 
-def _format_chatglm3(
-    system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str
-) -> str:
-    """Format the prompt with the chatglm3 style."""
-    ret = ""
-    if system_message:
-        ret += system_message
-    for role, message in messages:
-        if message:
-            ret += role + "\n" + " " + message
-        else:
-            ret += role
-    return ret
+    def to_chat_handler(self) -> LlamaChatCompletionHandler:
+        return chat_formatter_to_chat_completion_handler(self)
 
 
 def _convert_text_completion_to_chat(
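
The heart of the commit: Jinja2ChatFormatter renders the prompt from a jinja2 template, appends an empty assistant message so the template can leave the final turn open, and reports the eos token as the stop sequence. A usage sketch (the ChatML-style template and special tokens here are illustrative, not part of this commit; real templates typically come from a model's tokenizer_config.json):

    from llama_cpp.llama_chat_format import Jinja2ChatFormatter

    # Illustrative ChatML-style template. Empty content (the appended
    # assistant message) renders as an open turn without the eos token.
    template = (
        "{% for message in messages %}"
        "<|im_start|>{{ message.role }}\n"
        "{{ message.content }}{% if message.content %}{{ eos_token }}\n{% endif %}"
        "{% endfor %}"
    )

    formatter = Jinja2ChatFormatter(
        template=template,
        eos_token="<|im_end|>",
        bos_token="<|im_start|>",
    )

    response = formatter(
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Hello!"},
        ]
    )
    print(response.prompt)  # ...ends with the open turn "<|im_start|>assistant\n"
    print(response.stop)    # ["<|im_end|>"]

    handler = formatter.to_chat_handler()  # wraps it as a chat completion handler
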
@@ -426,16 +345,6 @@ def chat_completion_handler(
     return chat_completion_handler
 
 
-def register_chat_format(name: str):
-    def decorator(f: ChatFormatter):
-        chat_completion_handler = chat_formatter_to_chat_completion_handler(f)
-        LlamaChatCompletionHandlerRegistry().register_chat_completion_handler(
-            name, chat_completion_handler
-        )
-        return f
-    return decorator
-
-
 def hf_autotokenizer_to_chat_formatter(
     pretrained_model_name_or_path: Union[str, os.PathLike[str]]
 ) -> ChatFormatter:
@@ -466,7 +375,9 @@ def hf_autotokenizer_to_chat_completion_handler(
     return chat_formatter_to_chat_completion_handler(chat_formatter)
 
 
-def hf_tokenizer_config_to_chat_formatter(tokenizer_config: Dict[str, Any]) -> ChatFormatter:
+def hf_tokenizer_config_to_chat_formatter(
+    tokenizer_config: Dict[str, Any]
+) -> ChatFormatter:
     assert isinstance(tokenizer_config, dict)
 
     assert "chat_template" in tokenizer_config
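
For reference, the converter consumes the parsed contents of a Hugging Face tokenizer_config.json. A hedged sketch of the expected shape (values are made up; only the chat_template assertion is visible in this hunk, and the bos_token/eos_token fields are an assumption about the surrounding code):

    # Illustrative parsed tokenizer_config.json (values are made up).
    tokenizer_config = {
        "bos_token": "<s>",    # assumed field
        "eos_token": "</s>",   # assumed field
        "chat_template": (
            "{% for message in messages %}"
            "{{ message.role }}: {{ message.content }}{{ eos_token }}\n"
            "{% endfor %}"
        ),
    }

    chat_formatter = hf_tokenizer_config_to_chat_formatter(tokenizer_config)
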
@@ -504,6 +415,7 @@ def format_autotokenizer(
             eos_token=eos_token,
         )
         return ChatFormatterResponse(prompt=prompt, stop=eos_token)
+
     return format_autotokenizer
 
 
@@ -514,6 +426,147 @@ def hf_tokenizer_config_to_chat_completion_handler(
     return chat_formatter_to_chat_completion_handler(chat_formatter)
 
 
+### Utility functions for formatting chat prompts ###
+
+
+def _get_system_message(
+    messages: List[llama_types.ChatCompletionRequestMessage],
+) -> str:
+    """Get the first system message."""
+    for message in messages:
+        if message["role"] == "system":
+            return message["content"] or ""
+    return ""
+
+
+def _map_roles(
+    messages: List[llama_types.ChatCompletionRequestMessage],
+    role_map: Dict[str, str],
+) -> List[Tuple[str, Optional[str]]]:
+    """Map the message roles."""
+    output: List[Tuple[str, Optional[str]]] = []
+    for message in messages:
+        role = message["role"]
+        if role in role_map:
+            content: str | None = (
+                message["content"] if isinstance(message["content"], str) else None
+            )
+            output.append((role_map[role], content))
+    return output
+
+
+def _format_llama2(
+    system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str, sep2: str
+) -> str:
+    """Format the prompt with the llama2 style."""
+    seps = [sep, sep2]
+    ret = system_message + sep
+    for i, (role, message) in enumerate(messages):
+        if system_message and i == 0:
+            m = message or ""
+            ret += m + seps[i % 2]
+        elif message:
+            ret += role + message + " " + seps[i % 2]
+        else:
+            ret += role + " "
+    return ret
+
+
+def _format_add_colon_single(
+    system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str
+) -> str:
+    """Format the prompt with the add-colon-single style."""
+    ret = system_message + sep
+    for role, message in messages:
+        if message:
+            ret += role + ": " + message + sep
+        else:
+            ret += role + ":"
+    return ret
+
+
+def _format_add_colon_two(
+    system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str, sep2: str
+) -> str:
+    """Format the prompt with the add-colon-two style."""
+    seps = [sep, sep2]
+    ret = system_message + seps[0]
+    for i, (role, message) in enumerate(messages):
+        if message:
+            ret += role + ": " + message + seps[i % 2]
+        else:
+            ret += role + ":"
+    return ret
+
+
+def _format_no_colon_single(
+    system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str
+) -> str:
+    """Format the prompt with the no-colon-single style."""
+    ret = system_message
+    for role, message in messages:
+        if message:
+            ret += role + message + sep
+        else:
+            ret += role
+    return ret
+
+
+def _format_add_colon_space_single(
+    system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str
+) -> str:
+    """Format the prompt with the add-colon-space-single style."""
+    ret = system_message + sep
+    for role, message in messages:
+        if message:
+            ret += role + ": " + message + sep
+        else:
+            ret += role + ": "  # must be end with a space
+    return ret
+
+
+def _format_chatml(
+    system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str
+) -> str:
+    """Format the prompt with the chatml style."""
+    ret = "" if system_message == "" else system_message + sep + "\n"
+    for role, message in messages:
+        if message:
+            ret += role + "\n" + message + sep + "\n"
+        else:
+            ret += role + "\n"
+    return ret
+
+
+def _format_chatglm3(
+    system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str
+) -> str:
+    """Format the prompt with the chatglm3 style."""
+    ret = ""
+    if system_message:
+        ret += system_message
+    for role, message in messages:
+        if message:
+            ret += role + "\n" + " " + message
+        else:
+            ret += role
+    return ret
+
+
+### Chat Formats ###
+
+
+def register_chat_format(name: str):
+    def decorator(f: ChatFormatter):
+        chat_completion_handler = chat_formatter_to_chat_completion_handler(f)
+        LlamaChatCompletionHandlerRegistry().register_chat_completion_handler(
+            name, chat_completion_handler
+        )
+        return f
+
+    return decorator
+
+
 # see https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/tokenization_llama.py
 # system prompt is "embedded" in the first message
 @register_chat_format("llama-2")
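
With register_chat_format now grouped with the formatting helpers under ### Chat Formats ###, a registered format reads as below. A minimal sketch (the "caps-colon" name, role map, and separator are illustrative, not part of this commit):

    @register_chat_format("caps-colon")  # illustrative format name
    def format_caps_colon(
        messages: List[llama_types.ChatCompletionRequestMessage],
        **kwargs: Any,
    ) -> ChatFormatterResponse:
        system_message = _get_system_message(messages)
        _messages = _map_roles(messages, {"user": "USER", "assistant": "ASSISTANT"})
        _messages.append(("ASSISTANT", None))  # leave the assistant turn open
        prompt = _format_add_colon_single(system_message, _messages, sep="\n")
        return ChatFormatterResponse(prompt=prompt)

Registration makes the format addressable by name through LlamaChatCompletionHandlerRegistry, e.g. via Llama(chat_format="caps-colon").
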

0 commit comments