`server`: fix tool-call of DeepSeek R1 Qwen, return reasoning_content (Command R7B & DeepSeek R1) unless `--reasoning-format none` · Pull Request #11607 · ggml-org/llama.cpp

server: fix tool-call of DeepSeek R1 Qwen, return reasoning_content (Command R7B & DeepSeek R1) unless --reasoning-format none #11607


Merged: 94 commits, Feb 13, 2025
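Per the title, two behaviors land here: tool calling is fixed for DeepSeek R1 (Qwen), and for templates with an explicit thinking phase (Command R7B, DeepSeek R1) the server returns the model's thoughts in a separate `reasoning_content` field of the assistant message, unless `--reasoning-format none` is passed. A minimal sketch of the resulting message shape, assuming the server's usual OpenAI-style chat response; the `get_weather` call is invented for illustration:

```cpp
#include <iostream>
#include <nlohmann/json.hpp>
using json = nlohmann::json;

int main() {
    // Hypothetical tool call (name/arguments invented for illustration).
    json tool_call = {
        {"type", "function"},
        {"function", {
            {"name", "get_weather"},
            {"arguments", "{\"location\": \"Paris\"}"},
        }},
    };

    // Default behaviour sketched by this PR: thoughts are split out into
    // `reasoning_content`, and `content` is null when there are tool calls
    // (see the "ensure content is null when there are tool calls" commits).
    json message = {
        {"role", "assistant"},
        {"content", nullptr},
        {"reasoning_content", "The user wants the weather, so I should call get_weather."},
        {"tool_calls", json::array({tool_call})},
    };

    // With `--reasoning-format none`, no extraction would be done and the raw
    // <think>…</think> text would stay inside `content` instead.
    std::cout << message.dump(2) << std::endl;
    return 0;
}
```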
Commits (94) — the diff further below shows the changes from a single commit (d52579a).
d3b60b8
minja: enhance backfill of templates w/o tools description (use examp…
Feb 3, 2025
87de852
pass vocab to common_chat_params_init
Feb 3, 2025
130ca22
DeepSeek R1: parse thoughts / return in separate field in API (non st…
Feb 3, 2025
04d511b
Avoid double bos w/ jinja
Feb 3, 2025
2834587
server/oai: ensure content is null when there are tool calls
Feb 3, 2025
c80cb30
update logs
Feb 3, 2025
0871628
rename tests
Feb 3, 2025
73d08d4
tool-call: allow `--jinja --chat-template chatml`
Feb 3, 2025
04be723
tool-call: fix command-r7b parsing when response is multiline
Feb 3, 2025
ae9d581
tool-calls: add DeepSeek R1 Qwen 7B to server test_hello_world
Feb 3, 2025
19bea4e
tell DS R1 not to overthink (weather test)
Feb 3, 2025
5e6f2a2
add deepseek models to server tool call section in readme
Feb 3, 2025
1e9acd2
tool-call: allow `--jinja --chat-template chatml`
Feb 3, 2025
77ae97e
Update test_tool_call.py
Feb 3, 2025
a76073c
minimize diffs
Feb 3, 2025
cf83623
fix typo
Feb 3, 2025
5d18d76
fix double bos issue (drop bos/eos tokens from jinja template)
Feb 3, 2025
aa98e59
fix bad merge
Feb 3, 2025
2b3c482
fix build / rm diff
Feb 3, 2025
4cb0e1d
Merge branch 'jinja-chatml' into r1-toolcall
Feb 3, 2025
b2dd490
add missing try catch around jinja parsing to default to chatml
Feb 3, 2025
08271b5
Merge branch 'jinja-chatml' into r1-toolcall
Feb 3, 2025
df3474e
tool-calls: r1: add missing <|tool▁calls▁end|> to grammar!
Feb 3, 2025
c397bd1
tweak delta logic
Feb 3, 2025
569610e
tool-calls: accommodate variety of wrong tool call opening tags both …
Feb 3, 2025
d73448d
Simplify default chatml logic
Feb 3, 2025
0be7f65
Merge branch 'jinja-chatml' into r1-toolcall
Feb 3, 2025
7dc271f
tool-calls: add deepseek r1 template + accommodate broken official te…
Feb 3, 2025
c6214ee
rm unneeded vocab
Feb 3, 2025
1c302e1
simpler hacky fixes for original broken template (+ fix minja example…
Feb 3, 2025
108da90
sync: minja https://github.com/google/minja/pull/46
Feb 3, 2025
bc6d910
Merge branch 'master' into r1-toolcall
Feb 3, 2025
11c1f0c
actually we want eos_token in the template to infer tool call example…
Feb 3, 2025
30ea359
update to minja's new api
Feb 3, 2025
bbd45bf
sync: minja
Feb 4, 2025
bff549d
simplify hack to fix original template's backfill from minja
Feb 4, 2025
ce28224
tool-call: r1: add one more trigger approx "<|tool calls begin|>"
Feb 4, 2025
e84ee88
r1: fix inadvertent newline in grammar before <|tool▁call▁end|>
Feb 4, 2025
18a11f4
tool-call: r1: fix grammar
Feb 4, 2025
9a6847c
move trigger_words init inside non-llguidance branch
Feb 4, 2025
a682d12
fix / test parsing of r1 parser
Feb 4, 2025
f0154a6
Fix / test models/templates/llama-cpp-deepseek-r1.jinja
Feb 4, 2025
326e700
update test_calc_result
Feb 4, 2025
78b47bb
fix test_calc_result
Feb 4, 2025
86994db
fix spaces
Feb 4, 2025
09caa63
`sync`: minja
Feb 4, 2025
b152729
Update test-chat.cpp
Feb 4, 2025
56a14dd
fix mistral chat test: need empty tokens
Feb 4, 2025
f12e350
Update chat.cpp
Feb 4, 2025
d43e4f6
Merge branch 'sync-minja-4' into r1-toolcall
Feb 4, 2025
812544a
server: check that content is null when we get tool_calls
Feb 4, 2025
d44eb95
tool-call: ensure we don't return content when there are tool calls /…
Feb 4, 2025
b6e14a4
fix mistral expectation
Feb 4, 2025
1f5ec59
ensure deepseek r1 thoughts parsed even w/o tool calls
Feb 4, 2025
438ce0b
fix test-chat
Feb 4, 2025
21f2071
Update chat.cpp
Feb 4, 2025
b5b117f
Merge branch 'sync-minja-4' into r1-toolcall
Feb 4, 2025
0db9881
Fix r1 grammar since we made <|tool▁calls▁begin|> optional (triggerin…
Feb 4, 2025
d1b6691
r1: revert making <|tool▁calls▁begin|> optional as somehow sampling t…
Feb 4, 2025
39c1d81
return thoughts in reasoning_content field
Feb 4, 2025
b2d1728
update readme section about common model tool call formats
Feb 4, 2025
933f7a1
Merge branch 'master' into r1-toolcall
Feb 4, 2025
5d60ceb
Update test_tool_call.py
Feb 4, 2025
1f1f06a
Merge branch 'master' into r1-toolcall
ochafik Feb 5, 2025
9d7c3cc
--think to force any model to return reasoning_content (or just parse…
Feb 5, 2025
d20c2ce
Merge branch 'r1-toolcall' of github.com:ochafik/llama.cpp into r1-to…
Feb 5, 2025
f3e9f8b
fix test_thoughts
Feb 5, 2025
3841a16
fix compiler warning about parens
Feb 5, 2025
e6d9b52
align Command R7B w/ --think / reasoning_content behaviour
Feb 5, 2025
39b50c3
Update README.md
Feb 5, 2025
0917e0a
fix --think arg env
Feb 5, 2025
098629d
disable some failing chatml tests
Feb 5, 2025
33efcb3
Update README.md
Feb 5, 2025
994301d
use existing string_strip
Feb 5, 2025
d1a0640
revert tool example backfill change - command 7rb just needs the righ…
Feb 5, 2025
cc2c712
Merge remote-tracking branch 'origin/master' into r1-toolcall
Feb 8, 2025
c0f972b
Use --reasoning-format, remove forced thinking for now
Feb 8, 2025
af63886
return reasoning_content before content
Feb 8, 2025
a59fde2
update model template / format mapping
Feb 8, 2025
b829cab
fix test-chat
Feb 8, 2025
95cddfd
rm thoughts from generic parser
Feb 9, 2025
e598e7a
sync: minja (https://github.com/google/minja/pull/52)
Feb 9, 2025
91542ca
tool-calls: allow r1 output to miss <think> opening tag (since latest…
Feb 9, 2025
8d82be9
sync: minja (https://github.com/ggerganov/llama.cpp/pull/11774)
Feb 9, 2025
30dcfaa
rm wrong warning in command-r parser (when normal text)
Feb 9, 2025
e1bff8f
update deepseek r1 templates (+ put update commands in ./scripts/get_…
Feb 9, 2025
a29dc92
fix server test_tool_calls.py
Feb 9, 2025
ea2f41e
add models/templates/README.md
Feb 9, 2025
8409bf1
fix test_calc_result & test_thoughts
Feb 9, 2025
01db429
fix test-chat (update delta to latest r1 template change)
Feb 9, 2025
37a4bb2
Merge remote-tracking branch 'origin/master' into r1-toolcall
Feb 12, 2025
d52579a
prefer json::at to operator[] in chat.cpp
Feb 13, 2025
4700245
Merge remote-tracking branch 'origin/master' into r1-toolcall
Feb 13, 2025
043cb99
Apply suggestions from code review
ochafik Feb 13, 2025
prefer json::at to operator[] in chat.cpp
Olivier Chafik committed Feb 13, 2025
commit d52579a9b5e3ae682ea31cf0dad32e92a822ee2b
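This commit replaces nlohmann::json bracket indexing with `at()` for read access throughout common/chat.cpp: `at()` works on const values and throws `json::out_of_range` with a descriptive message when a key is missing, whereas `operator[]` silently inserts a null value on a mutable object and is undefined behavior for a missing key on a const one. In the diff below, each old bracketed access (e.g. `tool_call["name"]`) appears immediately followed by its `at()` replacement. A minimal sketch of the difference, not part of the PR (the `get_weather` payload is made up):

```cpp
#include <cstdio>
#include <nlohmann/json.hpp>
using json = nlohmann::json;

int main() {
    json tool_call = {{"name", "get_weather"}};  // note: no "arguments" key

    // at() fails loudly on a missing key, even on const objects:
    try {
        const auto & args = tool_call.at("arguments");
        (void) args;
    } catch (const json::out_of_range & e) {
        std::printf("at(): %s\n", e.what());  // key 'arguments' not found
    }

    // operator[] on a *mutable* json silently inserts null instead of
    // failing, which can hide malformed tool calls:
    const auto & args = tool_call["arguments"];
    std::printf("operator[]: inserted %s, size now %zu\n",
                args.dump().c_str(), tool_call.size());  // null, size 2
    return 0;
}
```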
100 changes: 50 additions & 50 deletions common/chat.cpp
@@ -142,11 +142,11 @@ static common_chat_msg parse_prefixed_json_tool_call_array(const std::string& in
result.role = "assistant";
const auto process_tool_calls = [&](const json & tool_calls) {
for (const auto & tool_call : tool_calls) {
const auto & arguments = tool_call["arguments"];
const auto & arguments = tool_call.at("arguments");
result.tool_calls.push_back({
tool_call["name"],
tool_call.at("name"),
arguments.is_string() ? arguments.get<std::string>() : arguments.dump(),
tool_call.contains("id") ? tool_call["id"] : "",
tool_call.contains("id") ? tool_call.at("id") : "",
});
}
};
@@ -163,7 +163,7 @@ static common_chat_msg parse_prefixed_json_tool_call_array(const std::string& in

static void foreach_function(const json & tools, const std::function<void(const json &)> & fn) {
for (const auto & tool : tools) {
if (!tool.contains("type") || tool["type"] != "function" || !tool.contains("function")) {
if (!tool.contains("type") || tool.at("type") != "function" || !tool.contains("function")) {
LOG_INF("Skipping tool without function: %s", tool.dump(2).c_str());
continue;
}
@@ -198,27 +198,27 @@ static common_chat_params common_chat_params_init_generic(const common_chat_temp

auto tool_call_schemas = json::array();
foreach_function(inputs.tools, [&](const json & tool) {
const auto & function = tool["function"];
const auto & function = tool.at("function");
auto tool_schema = json {
{"type", "object"},
{"properties", {
{"name", {
{"type", "string"},
{"const", function["name"]},
{"const", function.at("name")},
}},
{"arguments", function["parameters"]},
{"arguments", function.at("parameters")},
}},
{"required", json::array({"name", "arguments"})},
};
if (function.contains("description")) {
tool_schema["description"] = function["description"];
tool_schema["description"] = function.at("description");
}
if (inputs.parallel_tool_calls) {
tool_schema["properties"]["id"] = {
tool_schema.at("properties")["id"] = {
{"type", "string"},
{"minLength", 4},
};
tool_schema["required"].push_back("id");
tool_schema.at("required").push_back("id");
}
tool_call_schemas.emplace_back(tool_schema);
});
@@ -283,21 +283,21 @@ static common_chat_msg common_chat_parse_generic(const std::string & input) {
common_chat_msg result;
result.role = "assistant";
if (data.contains("tool_calls")) {
for (const auto & tool_call : data["tool_calls"]) {
for (const auto & tool_call : data.at("tool_calls")) {
result.tool_calls.push_back({
tool_call["name"],
tool_call["arguments"].dump(),
tool_call.contains("id") ? tool_call["id"] : "",
tool_call.at("name"),
tool_call.at("arguments").dump(),
tool_call.contains("id") ? tool_call.at("id") : "",
});
}
} else if (data.contains("tool_call")) {
result.tool_calls.push_back({
data["tool_call"]["name"],
data["tool_call"]["arguments"].dump(),
data.at("tool_call").at("name"),
data.at("tool_call").at("arguments").dump(),
/* id= */ "",
});
} else if (data.contains("response")) {
const auto & response = data["response"];
const auto & response = data.at("response");
result.content = response.is_string() ? response.get<std::string>() : response.dump(2);
}
return result;
@@ -309,17 +309,17 @@ static common_chat_params common_chat_params_init_mistral_nemo(const common_chat
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
auto schemas = json::array();
foreach_function(inputs.tools, [&](const json & tool) {
const auto & function = tool["function"];
const auto & function = tool.at("function");
schemas.push_back({
{"type", "object"},
{"properties", {
// Important note: the model is probably trained to take a JSON stringified arguments value.
// It's hard to constrain that for now (while reusing the JSON schema conversion), so we're just expecting a plain object.
{"name", {
{"type", "string"},
{"const", function["name"]},
{"const", function.at("name")},
}},
{"arguments", function["parameters"]},
{"arguments", function.at("parameters")},
{"id", {
{"type", "string"},
// Nemo's template expects a 9-character alphanumeric ID.
@@ -354,7 +354,7 @@ static common_chat_params common_chat_params_init_command_r7b(const common_chat_
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
auto schemas = json::array();
foreach_function(inputs.tools, [&](const json & tool) {
const auto & function = tool["function"];
const auto & function = tool.at("function");
schemas.push_back({
{"type", "object"},
{"properties", {
@@ -365,9 +365,9 @@ }},
}},
{"tool_name", {
{"type", "string"},
{"const", function["name"]},
{"const", function.at("name")},
}},
{"parameters", function["parameters"]},
{"parameters", function.at("parameters")},
}},
{"required", json::array({"tool_call_id", "tool_name", "parameters"})},
});
@@ -392,11 +392,11 @@ static common_chat_params common_chat_params_init_command_r7b(const common_chat_
};
auto adjusted_messages = json::array();
for (const auto & msg : inputs.messages) {
auto has_reasoning_content = msg.contains("reasoning_content") && msg["reasoning_content"].is_string();
auto has_tool_calls = msg.contains("tool_calls") && msg["tool_calls"].is_array();
auto has_reasoning_content = msg.contains("reasoning_content") && msg.at("reasoning_content").is_string();
auto has_tool_calls = msg.contains("tool_calls") && msg.at("tool_calls").is_array();
if (has_reasoning_content && has_tool_calls) {
auto adjusted_message = msg;
adjusted_message["tool_plan"] = msg["reasoning_content"];
adjusted_message["tool_plan"] = msg.at("reasoning_content");
adjusted_message.erase("reasoning_content");
adjusted_messages.push_back(adjusted_message);
} else {
@@ -433,9 +433,9 @@ static common_chat_msg common_chat_parse_command_r7b(const std::string & input,
auto actions = json::parse(actions_str);
for (const auto & action : actions) {
result.tool_calls.push_back({
/* .name = */ action["tool_name"],
/* .arguments = */ action["parameters"].dump(),
/* .id = */ action["tool_call_id"],
/* .name = */ action.at("tool_name"),
/* .arguments = */ action.at("parameters").dump(),
/* .id = */ action.at("tool_call_id"),
});
}
} else if (std::regex_match(rest, match, response_regex)) {
@@ -448,7 +448,7 @@ }
}

static void expect_tool_parameters(const std::string & name, const json & parameters, const std::vector<std::string> & expected_properties) {
if (!parameters.is_object() || !parameters.contains("type") || parameters["type"] != "object" || !parameters.contains("properties") || !parameters.contains("required")) {
if (!parameters.is_object() || !parameters.contains("type") || parameters.at("type") != "object" || !parameters.contains("properties") || !parameters.contains("required")) {
throw std::runtime_error("Parameters of tool " + name + " must be an object w/ required properties");
}
const auto & parameters_properties = parameters.at("properties");
@@ -502,9 +502,9 @@ static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const com
};

foreach_function(inputs.tools, [&](const json & tool) {
const auto & function = tool["function"];
std::string name = function["name"];
auto parameters = function["parameters"];
const auto & function = tool.at("function");
std::string name = function.at("name");
auto parameters = function.at("parameters");
builder.resolve_refs(parameters);

// https://github.com/meta-llama/llama-stack/tree/main/llama_stack/providers/remote/tool_runtime
@@ -585,9 +585,9 @@ static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
std::vector<std::string> tool_rules;
foreach_function(inputs.tools, [&](const json & tool) {
const auto & function = tool["function"];
std::string name = function["name"];
auto parameters = function["parameters"];
const auto & function = tool.at("function");
std::string name = function.at("name");
auto parameters = function.at("parameters");
auto args_rule = builder.add_schema(name + "-args", parameters);
tool_rules.push_back(builder.add_rule(name + "-call",
"\"<|tool▁call▁begin|>function<|tool▁sep|>" + name + "\\n"
@@ -678,15 +678,15 @@ static common_chat_params common_chat_params_init_firefunction_v2(const common_c
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
auto schemas = json::array();
foreach_function(inputs.tools, [&](const json & tool) {
const auto & function = tool["function"];
const auto & function = tool.at("function");
schemas.push_back({
{"type", "object"},
{"properties", {
{"name", {
{"type", "string"},
{"const", function["name"]},
{"const", function.at("name")},
}},
{"arguments", function["parameters"]},
{"arguments", function.at("parameters")},
}},
{"required", json::array({"name", "arguments", "id"})},
});
@@ -724,9 +724,9 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_
std::vector<std::string> first_tool_rules;
std::vector<std::string> subsequent_tool_rules;
foreach_function(inputs.tools, [&](const json & tool) {
const auto & function = tool["function"];
std::string name = function["name"];
auto parameters = function["parameters"];
const auto & function = tool.at("function");
std::string name = function.at("name");
auto parameters = function.at("parameters");
auto args_rule = builder.add_schema(name + "-args", parameters);
first_tool_rules.push_back(builder.add_rule(name + "-call", "\"" + name + "\\n\" " + args_rule));
subsequent_tool_rules.push_back(builder.add_rule(name + "-call2", "\">>>" + name + "\\n\" " + args_rule));
@@ -806,9 +806,9 @@ static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(con
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
std::vector<std::string> tool_rules;
foreach_function(inputs.tools, [&](const json & tool) {
const auto & function = tool["function"];
const auto & parameters = function["parameters"];
std::string name = function["name"];
const auto & function = tool.at("function");
const auto & parameters = function.at("parameters");
std::string name = function.at("name");
if (name == "python" || name == "ipython") {
if (!parameters.contains("type")) {
throw std::runtime_error("Missing type in python tool");
@@ -879,9 +879,9 @@ static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
std::vector<std::string> tool_rules;
foreach_function(inputs.tools, [&](const json & tool) {
const auto & function = tool["function"];
std::string name = function["name"];
auto parameters = function["parameters"];
const auto & function = tool.at("function");
std::string name = function.at("name");
auto parameters = function.at("parameters");
builder.resolve_refs(parameters);
tool_rules.push_back(builder.add_schema(name + "-call", {
{"type", "object"},
@@ -929,9 +929,9 @@ static common_chat_msg common_chat_parse_hermes_2_pro(const std::string & input)
if (!parse_json(it, end, call)) {
throw std::runtime_error("Failed to parse json tool call");
}
const auto & arguments = call["arguments"];
const auto & arguments = call.at("arguments");
result.tool_calls.push_back({
call["name"],
call.at("name"),
arguments.dump(),
// arguments.is_string() ? arguments.get<std::string>() : arguments.dump(),
/* id= */ "",
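For context on the DeepSeek R1 hunk above (common_chat_params_init_deepseek_r1), the grammar constrains tool calls to R1's native tagged syntax, and several commits add trigger words for the various (sometimes malformed) opening tags the model emits. Roughly, an R1 completion with a tool call looks like the string below; the exact wrapping of the JSON arguments is defined by the grammar in the diff, and the `get_weather` call is invented for illustration:

```cpp
#include <string>

// Approximate shape of a DeepSeek R1 tool-call completion that the grammar
// and parser in this PR deal with (the tokens use the model's "▁" glyph;
// tool name and arguments are hypothetical):
const std::string r1_output =
    "<think>I need the forecast, so I should call get_weather.</think>\n"
    "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_weather\n"
    "```json\n"
    "{\"location\": \"Paris\"}\n"
    "```<|tool▁call▁end|><|tool▁calls▁end|>";
```

With the default reasoning format, the `<think>…</think>` span of such output would be surfaced as `reasoning_content` and the tagged block parsed into `tool_calls`, as sketched near the top of this page.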