Automatically set chat format from gguf (#1110) · devilcoder01/llama-cpp-python@da003d8 · GitHub

Commit da003d8

Automatically set chat format from gguf (abetlen#1110)
* Use jinja formatter to load chat format from gguf
* Fix off-by-one error in metadata loader
* Implement chat format auto-detection
1 parent 059f6b3 commit da003d8

File tree: 4 files changed, +68 −7 lines

llama_cpp/_internals.py
llama_cpp/llama.py
llama_cpp/llama_chat_format.py
llama_cpp/server/settings.py


llama_cpp/_internals.py

Lines changed: 2 additions & 2 deletions
@@ -216,13 +216,13 @@ def metadata(self) -> Dict[str, str]:
         for i in range(llama_cpp.llama_model_meta_count(self.model)):
             nbytes = llama_cpp.llama_model_meta_key_by_index(self.model, i, buffer, buffer_size)
             if nbytes > buffer_size:
-                buffer_size = nbytes
+                buffer_size = nbytes + 1
                 buffer = ctypes.create_string_buffer(buffer_size)
                 nbytes = llama_cpp.llama_model_meta_key_by_index(self.model, i, buffer, buffer_size)
             key = buffer.value.decode("utf-8")
             nbytes = llama_cpp.llama_model_meta_val_str_by_index(self.model, i, buffer, buffer_size)
             if nbytes > buffer_size:
-                buffer_size = nbytes
+                buffer_size = nbytes + 1
                 buffer = ctypes.create_string_buffer(buffer_size)
                 nbytes = llama_cpp.llama_model_meta_val_str_by_index(self.model, i, buffer, buffer_size)
             value = buffer.value.decode("utf-8")
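The commit message calls this an off-by-one error: presumably the llama.cpp metadata getters follow snprintf semantics and report the string length excluding the terminating NUL, so re-reading into a buffer of exactly nbytes bytes would drop the final character. A minimal, self-contained sketch of the grow-and-retry pattern (fake_meta_val is a hypothetical stand-in, not the real C binding):

# A minimal sketch (not the library's API): fake_meta_val stands in for
# llama_model_meta_val_str_by_index and, like snprintf, returns the full
# string length *excluding* the terminating NUL byte.
import ctypes

def fake_meta_val(buf: ctypes.Array, buf_size: int) -> int:
    value = b"llama-2-7b-chat.Q4_K_M"          # 22 bytes of metadata value
    buf.value = value[: max(buf_size - 1, 0)]  # truncate, keep room for NUL
    return len(value)                          # length without the NUL

buffer_size = 8
buffer = ctypes.create_string_buffer(buffer_size)
nbytes = fake_meta_val(buffer, buffer_size)
if nbytes > buffer_size:
    # The fix grows the buffer to nbytes + 1 so the NUL terminator still fits;
    # growing to nbytes alone would silently truncate the last character.
    buffer_size = nbytes + 1
    buffer = ctypes.create_string_buffer(buffer_size)
    nbytes = fake_meta_val(buffer, buffer_size)
print(buffer.value.decode("utf-8"))  # -> llama-2-7b-chat.Q4_K_M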

llama_cpp/llama.py

Lines changed: 36 additions & 1 deletion
@@ -87,7 +87,7 @@ def __init__(
         # Backend Params
         numa: bool = False,
         # Chat Format Params
-        chat_format: str = "llama-2",
+        chat_format: Optional[str] = None,
         chat_handler: Optional[llama_chat_format.LlamaChatCompletionHandler] = None,
         # Misc
         verbose: bool = True,
@@ -343,6 +343,41 @@ def __init__(
         if self.verbose:
             print(f"Model metadata: {self.metadata}", file=sys.stderr)

+        if self.chat_format is None and self.chat_handler is None and "tokenizer.chat_template" in self.metadata:
+            chat_format = llama_chat_format.guess_chat_format_from_gguf_metadata(self.metadata)
+
+            if chat_format is not None:
+                self.chat_format = chat_format
+                if self.verbose:
+                    print(f"Guessed chat format: {chat_format}", file=sys.stderr)
+            else:
+                template = self.metadata["tokenizer.chat_template"]
+                try:
+                    eos_token_id = int(self.metadata["tokenizer.ggml.eos_token_id"])
+                except:
+                    eos_token_id = self.token_eos()
+                try:
+                    bos_token_id = int(self.metadata["tokenizer.ggml.bos_token_id"])
+                except:
+                    bos_token_id = self.token_bos()
+
+                eos_token = self.detokenize([eos_token_id]).decode("utf-8")
+                bos_token = self.detokenize([bos_token_id]).decode("utf-8")
+
+                if self.verbose:
+                    print(f"Using chat template: {template}", file=sys.stderr)
+                    print(f"Using chat eos_token: {eos_token}", file=sys.stderr)
+                    print(f"Using chat bos_token: {bos_token}", file=sys.stderr)
+
+                self.chat_handler = llama_chat_format.Jinja2ChatFormatter(
+                    template=template,
+                    eos_token=eos_token,
+                    bos_token=bos_token
+                ).to_chat_handler()
+
+        if self.chat_format is None and self.chat_handler is None:
+            self.chat_format = "llama-2"
+
     @property
     def ctx(self) -> llama_cpp.llama_context_p:
         assert self._ctx.ctx is not None
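A usage sketch of the new constructor behaviour, assuming a placeholder GGUF path: with chat_format and chat_handler left unset, a guessed format is used when the embedded template matches a known one, otherwise the raw tokenizer.chat_template is rendered by Jinja2ChatFormatter, and only when neither applies does the format fall back to "llama-2".

# A usage sketch (hypothetical model path): with chat_format left unset, the
# constructor now reads tokenizer.chat_template from the GGUF metadata.
from llama_cpp import Llama

llm = Llama(model_path="./openhermes-2.5-mistral-7b.Q4_K_M.gguf", verbose=True)
# verbose=True prints either "Guessed chat format: chatml" or the raw
# tokenizer.chat_template handed to Jinja2ChatFormatter.

out = llm.create_chat_completion(
    messages=[
        {"role": "system", "content": "You are a concise assistant."},
        {"role": "user", "content": "Name one use of GGUF metadata."},
    ],
    max_tokens=64,
)
print(out["choices"][0]["message"]["content"])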

llama_cpp/llama_chat_format.py

Lines changed: 28 additions & 2 deletions
@@ -14,6 +14,20 @@

 from ._utils import suppress_stdout_stderr, Singleton

+### Common Chat Templates and Special Tokens ###
+
+# Source: https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/blob/main/tokenizer_config.json
+CHATML_CHAT_TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
+CHATML_BOS_TOKEN = "<s>"
+CHATML_EOS_TOKEN = "<|im_end|>"
+
+# Source: https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1/blob/main/tokenizer_config.json
+MISTRAL_INSTRUCT_CHAT_TEMPLATE = "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token + ' ' }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}"
+MISTRAL_INSTRUCT_BOS_TOKEN = "<s>"
+MISTRAL_INSTRUCT_EOS_TOKEN = "</s>"
+
+
+### Chat Completion Handler ###

 class LlamaChatCompletionHandler(Protocol):
     """Base Protocol for a llama chat completion handler.
@@ -118,7 +132,6 @@ def decorator(f: LlamaChatCompletionHandler):

 ### Chat Formatter ###

-
 @dataclasses.dataclass
 class ChatFormatterResponse:
     """Dataclass that stores completion parameters for a given chat format and
@@ -440,7 +453,20 @@ def hf_tokenizer_config_to_chat_completion_handler(
     return chat_formatter_to_chat_completion_handler(chat_formatter)


+def guess_chat_format_from_gguf_metadata(metadata: Dict[str, str]) -> Optional[str]:
+    if "tokenizer.chat_template" not in metadata:
+        return None
+
+    if metadata["tokenizer.chat_template"] == CHATML_CHAT_TEMPLATE:
+        return "chatml"
+
+    if metadata["tokenizer.chat_template"] == MISTRAL_INSTRUCT_CHAT_TEMPLATE:
+        return "mistral-instruct"
+
+    return None
+
 ### Utility functions for formatting chat prompts ###
+# TODO: Replace these with jinja2 templates


 def _get_system_message(
@@ -929,7 +955,6 @@ def format_openchat(
     _prompt = _format_chatml(system_message, _messages, _sep)
     return ChatFormatterResponse(prompt=_prompt, stop=_sep)

-
 # Chat format for Saiga models, see more details and available models:
 # https://huggingface.co/collections/IlyaGusev/saiga2-saigamistral-6505d4ccc3d1e53166b636cd
 @register_chat_format("saiga")
@@ -951,6 +976,7 @@ def format_saiga(
     _prompt += "<s>bot"
     return ChatFormatterResponse(prompt=_prompt.strip())

+# Tricky chat formats that require custom chat handlers

 @register_chat_completion_handler("functionary")
 def functionary_chat_handler(
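A small sketch of the new guess_chat_format_from_gguf_metadata helper; the metadata dicts below are hand-written stand-ins for what Llama.metadata returns, and detection is an exact string comparison against the templates defined above:

# A small sketch of the new helper; the dicts are hand-written stand-ins for
# Llama.metadata, and matching is an exact string comparison.
from llama_cpp import llama_chat_format

chatml_meta = {"tokenizer.chat_template": llama_chat_format.CHATML_CHAT_TEMPLATE}
unknown_meta = {"tokenizer.chat_template": "{{ messages }}"}

print(llama_chat_format.guess_chat_format_from_gguf_metadata(chatml_meta))   # chatml
print(llama_chat_format.guess_chat_format_from_gguf_metadata(unknown_meta))  # None -> caller falls back to Jinja2ChatFormatter
print(llama_chat_format.guess_chat_format_from_gguf_metadata({}))            # None (no template key)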

llama_cpp/server/settings.py

Lines changed: 2 additions & 2 deletions
@@ -113,8 +113,8 @@ class ModelSettings(BaseSettings):
         description="Enable NUMA support.",
     )
     # Chat Format Params
-    chat_format: str = Field(
-        default="llama-2",
+    chat_format: Optional[str] = Field(
+        default=None,
         description="Chat format to use.",
     )
     clip_model_path: Optional[str] = Field(
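A sketch of the server-side effect, using a placeholder model path: chat_format now defaults to None, so the server only pins a format when one is supplied explicitly (for example via the generated --chat_format argument) and otherwise defers to the auto-detection in Llama.__init__.

# A sketch of the server-side effect (placeholder model path, requires the
# llama-cpp-python[server] extra): a chat format is only forced when given.
from llama_cpp.server.settings import ModelSettings

settings = ModelSettings(model="./model.gguf")
print(settings.chat_format)  # None -> Llama() auto-detects from the GGUF

settings = ModelSettings(model="./model.gguf", chat_format="chatml")
print(settings.chat_format)  # chatml -> an explicit format still wins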

0 commit comments