Configurable Chat Formats (#711) · GPTprojects/llama-cpp-python@3bca770 · GitHub

Commit 3bca770

Configurable Chat Formats (abetlen#711)
* Add configurable default chat completion format.
* Remove chat_template file to avoid circular import
* Update llama_types
* Add chat format
1 parent a945404 commit 3bca770

File tree

2 files changed: +330 −19 lines changed
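For context, a minimal sketch of how the new parameter is meant to be used from client code (the model path and messages below are placeholders, not part of this commit):

from llama_cpp import Llama

# chat_format selects the prompt template applied inside create_chat_completion;
# "llama-2" is the default introduced by this commit.
llm = Llama(model_path="./models/llama-2-7b-chat.gguf", chat_format="llama-2")

response = llm.create_chat_completion(
    messages=[
        {"role": "system", "content": "You are a concise assistant."},
        {"role": "user", "content": "List three uses of a Makefile."},
    ],
)
print(response["choices"][0]["message"]["content"])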

llama_cpp/llama.py

Lines changed: 38 additions & 19 deletions
@@ -24,6 +24,7 @@
 from . import llama_cpp
 from .llama_types import *
 from .llama_grammar import LlamaGrammar
+from . import llama_chat_format
 
 import numpy as np
 import numpy.typing as npt
@@ -243,6 +244,8 @@ def __init__(
         lora_path: Optional[str] = None,
         # Backend Params
         numa: bool = False,
+        # Chat Format Params
+        chat_format: str = "llama-2",
         # Misc
         verbose: bool = True,
         # Extra Params
@@ -273,6 +276,7 @@ def __init__(
             lora_base: Optional path to base model, useful if using a quantized base model and you want to apply LoRA to an f16 model.
             lora_path: Path to a LoRA file to apply to the model.
             numa: Enable NUMA support. (NOTE: The initial value of this parameter is used for the remainder of the program as this value is set in llama_backend_init)
+            chat_format: String specifying the chat format to use when calling create_chat_completion.
             verbose: Print verbose output to stderr.
             kwargs: Unused keyword arguments (for additional backwards compatibility).
@@ -388,6 +392,8 @@ def __init__(
 
         if self.verbose:
             print(llama_cpp.llama_print_system_info().decode("utf-8"), file=sys.stderr)
+
+        self.chat_format = chat_format
 
         self._n_vocab = self.n_vocab()
         self._n_ctx = self.n_ctx()
@@ -1565,9 +1571,21 @@ def _convert_text_completion_chunks_to_chat(
             ],
         }
 
+    def _convert_completion_to_chat(
+        self,
+        completion_or_chunks: Union[Completion, Iterator[CompletionChunk]],
+        stream: bool = False,
+    ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
+        if stream:
+            chunks: Iterator[CompletionChunk] = completion_or_chunks  # type: ignore
+            return self._convert_text_completion_chunks_to_chat(chunks)
+        else:
+            completion: Completion = completion_or_chunks  # type: ignore
+            return self._convert_text_completion_to_chat(completion)
+
     def create_chat_completion(
         self,
-        messages: List[ChatCompletionMessage],
+        messages: List[ChatCompletionRequestMessage],
         functions: Optional[List[ChatCompletionFunction]] = None,
         function_call: Optional[Union[str, ChatCompletionFunctionCall]] = None,
         temperature: float = 0.2,
@@ -1602,26 +1620,28 @@ def create_chat_completion(
         Returns:
             Generated chat completion or a stream of chat completion chunks.
         """
-        stop = (
-            stop if isinstance(stop, list) else [stop] if isinstance(stop, str) else []
-        )
-        chat_history = "".join(
-            f'### {"Human" if message["role"] == "user" else "Assistant"}:{message["content"]}'
-            for message in messages
+
+        format = llama_chat_format.get_chat_format(self.chat_format)
+        result = format(
+            messages=messages,
         )
-        PROMPT = chat_history + "### Assistant:"
-        PROMPT_STOP = ["### Assistant:", "### Human:"]
-        completion_or_chunks = self(
-            prompt=PROMPT,
-            stop=PROMPT_STOP + stop,
+        prompt = result.prompt
+        if result.stop is not None:
+            stop = [] if stop is None else [stop] if isinstance(stop, str) else stop
+            rstop = result.stop if isinstance(result.stop, list) else [result.stop]
+            stop = stop + rstop
+
+        completion_or_chunks = self.create_completion(
+            prompt=prompt,
             temperature=temperature,
             top_p=top_p,
             top_k=top_k,
             stream=stream,
+            stop=stop,
             max_tokens=max_tokens,
-            repeat_penalty=repeat_penalty,
             presence_penalty=presence_penalty,
             frequency_penalty=frequency_penalty,
+            repeat_penalty=repeat_penalty,
             tfs_z=tfs_z,
             mirostat_mode=mirostat_mode,
             mirostat_tau=mirostat_tau,
@@ -1630,12 +1650,7 @@ def create_chat_completion(
             logits_processor=logits_processor,
             grammar=grammar,
         )
-        if stream:
-            chunks: Iterator[CompletionChunk] = completion_or_chunks  # type: ignore
-            return self._convert_text_completion_chunks_to_chat(chunks)
-        else:
-            completion: Completion = completion_or_chunks  # type: ignore
-            return self._convert_text_completion_to_chat(completion)
+        return self._convert_completion_to_chat(completion_or_chunks, stream=stream)  # type: ignore
 
     def __del__(self):
         if hasattr(self, "model") and self.model is not None:
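The second file in this commit, the new chat format module, is not reproduced on this page, but the calls above pin down its contract: llama_chat_format.get_chat_format(name) must return a callable that accepts messages= and returns an object exposing .prompt (the rendered prompt string) and .stop (an optional stop string or list). The sketch below is illustrative only; the names FormatResult, format_human_assistant, and the "human-assistant" registry key are hypothetical, and the example formatter simply reproduces the "### Human:" / "### Assistant:" layout that the removed lines used to hard-code.

from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union

@dataclass
class FormatResult:
    # Shape consumed by create_chat_completion above: result.prompt / result.stop.
    prompt: str
    stop: Optional[Union[str, List[str]]] = None

def format_human_assistant(messages: List[Dict[str, Any]], **kwargs: Any) -> FormatResult:
    # Rebuilds the prompt layout that create_chat_completion previously hard-coded.
    chat_history = "".join(
        f'### {"Human" if message["role"] == "user" else "Assistant"}:{message["content"]}'
        for message in messages
    )
    return FormatResult(
        prompt=chat_history + "### Assistant:",
        stop=["### Assistant:", "### Human:"],
    )

# A name-to-formatter registry like this would let get_chat_format(self.chat_format)
# resolve the callable invoked in create_chat_completion.
_CHAT_FORMATS = {"human-assistant": format_human_assistant}

def get_chat_format(name: str):
    return _CHAT_FORMATS[name]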
@@ -1675,6 +1690,8 @@ def __getstate__(self):
             lora_path=self.lora_path,
             # Backend Params
             numa=self.numa,
+            # Chat Format Params
+            chat_format=self.chat_format,
             # Misc
             verbose=self.verbose,
         )
@@ -1708,6 +1725,8 @@ def __setstate__(self, state):
             lora_path=state["lora_path"],
             # Backend Params
             numa=state["numa"],
+            # Chat Format Params
+            chat_format=state["chat_format"],
             # Misc
             verbose=state["verbose"],
         )
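Because chat_format is now carried through __getstate__ and __setstate__, the selected format should survive pickling along with the other construction parameters. A quick sanity check (hypothetical model path; note that unpickling re-initializes the model, so this reloads the weights):

import pickle

llm = Llama(model_path="./models/llama-2-7b-chat.gguf", chat_format="llama-2")
restored = pickle.loads(pickle.dumps(llm))
assert restored.chat_format == llm.chat_format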

0 commit comments

Comments
 (0)
0