@@ -14,6 +14,7 @@
 import llama_cpp.llama_types as llama_types
 import llama_cpp.llama_grammar as llama_grammar

+from ._logger import logger
 from ._utils import suppress_stdout_stderr, Singleton


 ### Common Chat Templates and Special Tokens ###
@@ -993,6 +994,26 @@ def format_saiga(
     return ChatFormatterResponse(prompt=_prompt.strip())


+# Chat format for Google's Gemma models, see more details and available models:
+# https://huggingface.co/collections/google/gemma-release-65d5efbccdbb8c4202ec078b
+@register_chat_format("gemma")
+def format_gemma(
+    messages: List[llama_types.ChatCompletionRequestMessage],
+    **kwargs: Any,
+) -> ChatFormatterResponse:
+    system_message = _get_system_message(messages)
+    if system_message is not None and system_message != "":
+        logger.debug(
+            "`role='system'` messages are not allowed on Google's Gemma models."
+        )
+    _roles = dict(user="<start_of_turn>user\n", assistant="<start_of_turn>model\n")
+    _sep = "<end_of_turn>\n"
+    _messages = _map_roles(messages, _roles)
+    _messages.append((_roles["assistant"], None))
+    _prompt = _format_no_colon_single(system_message="", messages=_messages, sep=_sep)
+    return ChatFormatterResponse(prompt=_prompt, stop=_sep)
+
+
 # Tricky chat formats that require custom chat handlers

