server : fix first message identification (#13634) · baseweight/llama.cpp@e99ceff · GitHub
[go: up one dir, main page]

Skip to content

Commit e99ceff

Browse files
doringeman authored and p1-0tr committed
server : fix first message identification (ggml-org#13634)
* server : fix first message identification When using the OpenAI SDK (https://github.com/openai/openai-node/blob/master/src/lib/ChatCompletionStream.ts#L623-L626) we noticed that the expected assistant role is missing in the first streaming message. Fix this by correctly checking for the first message. Co-authored-by: Piotr Stankiewicz <piotr.stankiewicz@docker.com> Signed-off-by: Dorin Geman <dorin.geman@docker.com> * server : Fix checks for first role message for stream=True Co-authored-by: Piotr Stankiewicz <piotr.stankiewicz@docker.com> Signed-off-by: Dorin Geman <dorin.geman@docker.com> --------- Signed-off-by: Dorin Geman <dorin.geman@docker.com> Co-authored-by: Piotr Stankiewicz <piotr.stankiewicz@docker.com>
1 parent 7b46cf2 commit e99ceff

File tree

2 files changed

+53
-21
lines changed

2 files changed

+53
-21
lines changed

tools/server/server.cpp

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -951,7 +951,7 @@ struct server_task_result_cmpl_partial : server_task_result {
951951
}
952952

953953
json to_json_oaicompat_chat() {
954-
bool first = n_decoded == 0;
954+
bool first = n_decoded == 1;
955955
std::time_t t = std::time(0);
956956
json choices;
957957

@@ -962,15 +962,18 @@ struct server_task_result_cmpl_partial : server_task_result {
962962
{"delta", json{{"role", "assistant"}}}}});
963963
} else {
964964
// We have to send this as two updates to conform to openai behavior
965+
// initial_ret is the role message for stream=True
965966
json initial_ret = json{{"choices", json::array({json{
966967
{"finish_reason", nullptr},
967968
{"index", 0},
968969
{"delta", json{
969-
{"role", "assistant"}
970+
{"role", "assistant"},
971+
{"content", ""}
970972
}}}})},
971973
{"created", t},
972974
{"id", oaicompat_cmpl_id},
973975
{"model", oaicompat_model},
976+
{"system_fingerprint", build_info},
974977
{"object", "chat.completion.chunk"}};
975978

976979
json second_ret = json{
@@ -982,8 +985,19 @@ struct server_task_result_cmpl_partial : server_task_result {
982985
{"created", t},
983986
{"id", oaicompat_cmpl_id},
984987
{"model", oaicompat_model},
988+
{"system_fingerprint", build_info},
985989
{"object", "chat.completion.chunk"}};
986990

991+
if (prob_output.probs.size() > 0) {
992+
second_ret["choices"][0]["logprobs"] = json{
993+
{"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)},
994+
};
995+
}
996+
997+
if (timings.prompt_n >= 0) {
998+
second_ret.push_back({"timings", timings.to_json()});
999+
}
1000+
9871001
return std::vector<json>({initial_ret, second_ret});
9881002
}
9891003
} else {

tools/server/tests/unit/test_chat_completion.py

Lines changed: 37 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,14 @@ def test_chat_completion_stream(system_prompt, user_prompt, max_tokens, re_conte
7171
})
7272
content = ""
7373
last_cmpl_id = None
74-
for data in res:
74+
for i, data in enumerate(res):
7575
choice = data["choices"][0]
76+
if i == 0:
77+
# Check first role message for stream=True
78+
assert choice["delta"]["content"] == ""
79+
assert choice["delta"]["role"] == "assistant"
80+
else:
81+
assert "role" not in choice["delta"]
7682
assert data["system_fingerprint"].startswith("b")
7783
assert "gpt-3.5" in data["model"] # DEFAULT_OAICOMPAT_MODEL, maybe changed in the future
7884
if last_cmpl_id is None:
@@ -242,12 +248,18 @@ def test_chat_completion_with_timings_per_token():
242248
"stream": True,
243249
"timings_per_token": True,
244250
})
245-
for data in res:
246-
assert "timings" in data
247-
assert "prompt_per_second" in data["timings"]
248-
assert "predicted_per_second" in data["timings"]
249-
assert "predicted_n" in data["timings"]
250-
assert data["timings"]["predicted_n"] <= 10
251+
for i, data in enumerate(res):
252+
if i == 0:
253+
# Check first role message for stream=True
254+
assert data["choices"][0]["delta"]["content"] == ""
255+
assert data["choices"][0]["delta"]["role"] == "assistant"
256+
else:
257+
assert "role" not in data["choices"][0]["delta"]
258+
assert "timings" in data
259+
assert "prompt_per_second" in data["timings"]
260+
assert "predicted_per_second" in data["timings"]
261+
assert "predicted_n" in data["timings"]
262+
assert data["timings"]["predicted_n"] <= 10
251263

252264

253265
def test_logprobs():
@@ -295,17 +307,23 @@ def test_logprobs_stream():
295307
)
296308
output_text = ''
297309
aggregated_text = ''
298-
for data in res:
310+
for i, data in enumerate(res):
299311
choice = data.choices[0]
300-
if choice.finish_reason is None:
301-
if choice.delta.content:
302-
output_text += choice.delta.content
303-
assert choice.logprobs is not None
304-
assert choice.logprobs.content is not None
305-
for token in choice.logprobs.content:
306-
aggregated_text += token.token
307-
assert token.logprob <= 0.0
308-
assert token.bytes is not None
309-
assert token.top_logprobs is not None
310-
assert len(token.top_logprobs) > 0
312+
if i == 0:
313+
# Check first role message for stream=True
314+
assert choice.delta.content == ""
315+
assert choice.delta.role == "assistant"
316+
else:
317+
assert choice.delta.role is None
318+
if choice.finish_reason is None:
319+
if choice.delta.content:
320+
output_text += choice.delta.content
321+
assert choice.logprobs is not None
322+
assert choice.logprobs.content is not None
323+
for token in choice.logprobs.content:
324+
aggregated_text += token.token
325+
assert token.logprob <= 0.0
326+
assert token.bytes is not None
327+
assert token.top_logprobs is not None
328+
assert len(token.top_logprobs) > 0
311329
assert aggregated_text == output_text

0 commit comments

Comments (0)