8000 Add chat format test. · notwa/llama-cpp-python@9ae5819 · GitHub
[go: up one dir, main page]

Skip to content
< 10000 header class="HeaderMktg header-logged-out js-details-container js-header Details f4 py-3" role="banner" data-is-top="true" data-color-mode=light data-light-theme=light data-dark-theme=dark>

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 9ae5819

Browse files
committed
Add chat format test.
1 parent ce38dbd commit 9ae5819

File tree

2 files changed

+35
-10
lines changed

2 files changed

+35
-10
lines changed

llama_cpp/llama_chat_format.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -878,19 +878,21 @@ def format_chatml(
878878

879879

880880
@register_chat_format("mistral-instruct")
881-
def format_mistral(
881+
def format_mistral_instruct(
882882
messages: List[llama_types.ChatCompletionRequestMessage],
883883
**kwargs: Any,
884884
) -> ChatFormatterResponse:
885-
_roles = dict(user="[INST] ", assistant="[/INST]")
886-
_sep = " "
887-
system_template = """<s>{system_message}"""
888-
system_message = _get_system_message(messages)
889-
system_message = system_template.format(system_message=system_message)
890-
_messages = _map_roles(messages, _roles)
891-
_messages.append((_roles["assistant"], None))
892-
_prompt = _format_no_colon_single(system_message, _messages, _sep)
893-
return ChatFormatterResponse(prompt=_prompt)
885+
bos = "<s>"
886+
eos = "</s>"
887+
stop = eos
888+
prompt = bos
889+
for message in messages:
890+
if message["role"] == "user" and message["content"] is not None and isinstance(message["content"], str):
891+
prompt += "[INST] " + message["content"]
892+
elif message["role"] == "assistant" and message["content"] is not None and isinstance(message["content"], str):
893+
prompt += " [/INST]" + message["content"] + eos
894+
prompt += " [/INST]"
895+
return ChatFormatterResponse(prompt=prompt, stop=stop)
894896

895897

896898
@register_chat_format("chatglm3")

tests/test_llama_chat_format.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,33 @@
11
import json
22

3+
import jinja2
4+
35
from llama_cpp import (
46
ChatCompletionRequestUserMessage,
57
)
8+
import llama_cpp.llama_types as llama_types
9+
import llama_cpp.llama_chat_format as llama_chat_format
10+
611
from llama_cpp.llama_chat_format import hf_tokenizer_config_to_chat_formatter
712

13+
def test_mistral_instruct():
14+
chat_template = "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}"
15+
chat_formatter = jinja2.Template(chat_template)
16+
messages = [
17+
llama_types.ChatCompletionRequestUserMessage(role="user", content="Instruction"),
18+
llama_types.ChatCompletionRequestAssistantMessage(role="assistant", content="Model answer"),
19+
llama_types.ChatCompletionRequestUserMessage(role="user", content="Follow-up instruction"),
20+
]
21+
response = llama_chat_format.format_mistral_instruct(
22+
messages=messages,
23+
)
24+
reference = chat_formatter.render(
25+
messages=messages,
26+
bos_token="<s>",
27+
eos_token="</s>",
28+
)
29+
assert response.prompt == reference
30+
831

932
mistral_7b_tokenizer_config = """{
1033
"add_bos_token": true,

0 commit comments

Comments
 (0)
0