8000 feat: Add Google's Gemma formatting via `chat_format="gemma"` (#1210) · coderonion/llama-cpp-python@251a8a2 · GitHub
[go: up one dir, main page]

Skip to content

Commit 251a8a2

Browse files
alvarobarttabetlen
andauthored
feat: Add Google's Gemma formatting via chat_format="gemma" (abetlen#1210)
* Add Google's Gemma formatting via `chat_format="gemma"` * Replace `raise ValueError` with `logger.debug` Co-authored-by: Andrei <abetlen@gmail.com> --------- Co-authored-by: Andrei <abetlen@gmail.com>
1 parent eebb102 commit 251a8a2

File tree

1 file changed

+21
-0
lines changed

1 file changed

+21
-0
lines changed

llama_cpp/llama_chat_format.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import llama_cpp.llama_types as llama_types
1515
import llama_cpp.llama_grammar as llama_grammar
1616

17+
from ._logger import logger
1718
from ._utils import suppress_stdout_stderr, Singleton
1819

1920
### Common Chat Templates and Special Tokens ###
@@ -993,6 +994,26 @@ def format_saiga(
993994
return ChatFormatterResponse(prompt=_prompt.strip())
994995

995996

997+
# Chat format for Google's Gemma models, see more details and available models:
998+
# https://huggingface.co/collections/google/gemma-release-65d5efbccdbb8c4202ec078b
999+
@register_chat_format("gemma")
1000+
def format_gemma(
1001+
messages: List[llama_types.ChatCompletionRequestMessage],
1002+
**kwargs: Any,
1003+
) -> ChatFormatterResponse:
1004+
system_message = _get_system_message(messages)
1005+
if system_message is not None and system_message != "":
1006+
logger.debug(
1007+
"`role='system'` messages are not allowed on Google's Gemma models."
1008+
)
1009+
_roles = dict(user="<start_of_turn>user\n", assistant="<start_of_turn>model\n")
1010+
_sep = "<end_of_turn>\n"
1011+
_messages = _map_roles(messages, _roles)
1012+
_messages.append((_roles["assistant"], None))
1013+
_prompt = _format_no_colon_single(system_message="", messages=_messages, sep=_sep)
1014+
return ChatFormatterResponse(prompt=_prompt, stop=_sep)
1015+
1016+
9961017
# Tricky chat formats that require custom chat handlers
9971018

9981019

0 commit comments

Comments
 (0)
0