llama : validate seq id batch input (#13809) · ggml-org/llama.cpp@4f81b33 · GitHub
[go: up one dir, main page]

Skip to content

Commit 4f81b33

Browse files
authored
llama : validate seq id batch input (#13809)
* llama : validate seq id batch input

  ggml-ci

* cont : fix the fix

  ggml-ci
1 parent cdf94a1 commit 4f81b33

File tree

1 file changed

+13
-1
lines changed

1 file changed

+13
-1
lines changed

src/llama-context.cpp

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -693,12 +693,18 @@ int llama_context::encode(llama_batch & inp_batch) {
693693

694694
GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
695695

696+
// TODO: move the validation to the llama_batch_allocr
696697
if (batch.token) {
697698
for (int32_t i = 0; i < n_tokens; ++i) {
698699
if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
699700
LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
700701
return -1;
701702
}
703+
704+
if (batch.seq_id && (batch.seq_id[i][0] < 0 || batch.seq_id[i][0] >= LLAMA_MAX_PARALLEL_SEQUENCES)) {
705+
LLAMA_LOG_ERROR("%s: invalid seq_id[%d] = %d > %d\n", __func__, i, batch.seq_id[i][0], LLAMA_MAX_PARALLEL_SEQUENCES);
706+
throw -1;
707+
}
702708
}
703709
}
704710

@@ -887,11 +893,17 @@ int llama_context::decode(llama_batch & inp_batch) {
887893

888894
GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
889895

896+
// TODO: move the validation to the llama_batch_allocr
890897
if (batch.token) {
891898
for (int64_t i = 0; i < n_tokens_all; ++i) {
892899
if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
893900
LLAMA_LOG_ERROR("%s: invalid token[%" PRId64 "] = %d\n", __func__, i, batch.token[i]);
894-
throw std::runtime_error("invalid token");
901+
return -1;
902+
}
903+
904+
if (batch.seq_id && (batch.seq_id[i][0] < 0 || batch.seq_id[i][0] >= LLAMA_MAX_PARALLEL_SEQUENCES)) {
905+
LLAMA_LOG_ERROR("%s: invalid seq_id[%" PRId64 "] = %d >= %d\n", __func__, i, batch.seq_id[i][0], LLAMA_MAX_PARALLEL_SEQUENCES);
906+
return -1;
895907
}
896908
}
897909
}

0 commit comments

Comments
 (0)
0