server: fix/test add_generation_prompt (#13770) · ochafik/llama.cpp@d785f9c · GitHub
[go: up one dir, main page]

Skip to content

Commit d785f9c

Browse files
ochafik and ochafik authored
server: fix/test add_generation_prompt (ggml-org#13770)
Co-authored-by: ochafik <ochafik@google.com>
1 parent 4032ca4 commit d785f9c

File tree

2 files changed

+26
-0
lines changed

2 files changed

+26
-0
lines changed

tools/server/tests/unit/test_template.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,3 +47,28 @@ def test_date_inside_prompt(template_name: str, format: str, tools: list[dict]):
4747

4848
today_str = datetime.date.today().strftime(format)
4949
assert today_str in prompt, f"Expected today's date ({today_str}) in content ({prompt})"
50+
51+
52+
@pytest.mark.parametrize("add_generation_prompt", [False, True])
@pytest.mark.parametrize("template_name,expected_generation_prompt", [
    ("meta-llama-Llama-3.3-70B-Instruct", "<|start_header_id|>assistant<|end_header_id|>"),
])
def test_add_generation_prompt(template_name: str, expected_generation_prompt: str, add_generation_prompt: bool):
    """Check that ``/apply-template`` honours the ``add_generation_prompt`` flag.

    The server is started with Jinja enabled and the chat template file named
    by ``template_name``. A single user message is posted to
    ``/apply-template`` together with the flag, and the returned prompt is
    inspected:

    * flag true  -> ``expected_generation_prompt`` must appear in the prompt;
    * flag false -> ``expected_generation_prompt`` must not appear in it.
    """
    global server
    server.jinja = True
    server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
    server.start(timeout_seconds=TIMEOUT_SERVER_START)

    payload = {
        "messages": [
            {"role": "user", "content": "What is today?"},
        ],
        "add_generation_prompt": add_generation_prompt,
    }
    response = server.make_request("POST", "/apply-template", data=payload)
    assert response.status_code == 200

    rendered = response.body["prompt"]
    marker_present = expected_generation_prompt in rendered

    # Handle the "flag off" case first so the positive check reads as the main path.
    if not add_generation_prompt:
        assert not marker_present, f"Did not expect generation prompt ({expected_generation_prompt}) in content ({rendered})"
        return

    assert marker_present, f"Expected generation prompt ({expected_generation_prompt}) in content ({rendered})"

tools/server/utils.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -731,6 +731,7 @@ static json oaicompat_chat_params_parse(
731731
inputs.grammar = grammar;
732732
inputs.use_jinja = opt.use_jinja;
733733
inputs.parallel_tool_calls = json_value(body, "parallel_tool_calls", false);
734+
inputs.add_generation_prompt = json_value(body, "add_generation_prompt", true);
734735
inputs.reasoning_format = opt.reasoning_format;
735736
if (!inputs.tools.empty() && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && body.contains("grammar")) {
736737
throw std::runtime_error("Cannot use custom grammar constraints with tools.");

0 commit comments

Comments
 (0)
0