@@ -693,12 +693,18 @@ int llama_context::encode(llama_batch & inp_batch) {
 
     GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
 
+    // TODO: move the validation to the llama_batch_allocr
     if (batch.token) {
         for (int32_t i = 0; i < n_tokens; ++i) {
             if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
                 LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
                 return -1;
             }
+
+            if (batch.seq_id && (batch.seq_id[i][0] < 0 || batch.seq_id[i][0] >= LLAMA_MAX_PARALLEL_SEQUENCES)) {
+                LLAMA_LOG_ERROR("%s: invalid seq_id[%d] = %d >= %d\n", __func__, i, batch.seq_id[i][0], LLAMA_MAX_PARALLEL_SEQUENCES);
+                return -1;
+            }
         }
     }
 
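With this change, `encode()` rejects an out-of-range `seq_id` up front and returns `-1`, the same way it already handles invalid tokens. Below is a minimal caller sketch of how the check surfaces through the public API; the helper name `encode_one` and the already-loaded `ctx` are assumptions for illustration, while `llama_batch_init`, `llama_encode`, and `llama_batch_free` are the existing llama.cpp batch API:

```cpp
// Hypothetical helper (not part of this commit): submit a single token on
// sequence `seq` and report the result. With the validation above, a bad
// seq_id is rejected immediately instead of reaching the KV cache.
#include "llama.h"
#include <cstdio>

static int32_t encode_one(llama_context * ctx, llama_token tok, llama_seq_id seq) {
    // room for 1 token with 1 seq_id per token
    llama_batch batch = llama_batch_init(/*n_tokens*/ 1, /*embd*/ 0, /*n_seq_max*/ 1);

    batch.n_tokens     = 1;
    batch.token[0]     = tok;
    batch.pos[0]       = 0;
    batch.n_seq_id[0]  = 1;
    batch.seq_id[0][0] = seq; // negative or >= LLAMA_MAX_PARALLEL_SEQUENCES now fails fast

    const int32_t ret = llama_encode(ctx, batch); // -1 on invalid input (bad token or seq_id)
    if (ret != 0) {
        fprintf(stderr, "encode failed with status %d\n", ret);
    }

    llama_batch_free(batch);
    return ret;
}
```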
@@ -887,11 +893,17 @@ int llama_context::decode(llama_batch & inp_batch) {
 
     GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
 
+    // TODO: move the validation to the llama_batch_allocr
     if (batch.token) {
         for (int64_t i = 0; i < n_tokens_all; ++i) {
             if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
                 LLAMA_LOG_ERROR("%s: invalid token[%" PRId64 "] = %d\n", __func__, i, batch.token[i]);
-                throw std::runtime_error("invalid token");
+                return -1;
+            }
+
+            if (batch.seq_id && (batch.seq_id[i][0] < 0 || batch.seq_id[i][0] >= LLAMA_MAX_PARALLEL_SEQUENCES)) {
+                LLAMA_LOG_ERROR("%s: invalid seq_id[%" PRId64 "] = %d >= %d\n", __func__, i, batch.seq_id[i][0], LLAMA_MAX_PARALLEL_SEQUENCES);
+                return -1;
             }
         }
     }
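`decode()` gets the same `seq_id` check, and its token validation now reports failure through the return value instead of throwing `std::runtime_error` across the C API boundary. A sketch of the call-site pattern this supports, assuming a hypothetical helper `decode_prompt` and a loaded `ctx`; the status codes follow the documented `llama_decode` convention (0 on success, positive when no KV slot is available, negative on invalid input):

```cpp
// Hypothetical call-site sketch: invalid input comes back as a status code,
// so a plain C-style check works and no try/catch is needed around llama_decode().
#include "llama.h"
#include <cstdio>

static bool decode_prompt(llama_context * ctx, const llama_token * toks, int32_t n) {
    llama_batch batch = llama_batch_init(n, /*embd*/ 0, /*n_seq_max*/ 1);

    batch.n_tokens = n;
    for (int32_t i = 0; i < n; ++i) {
        batch.token[i]     = toks[i];
        batch.pos[i]       = i;
        batch.n_seq_id[i]  = 1;
        batch.seq_id[i][0] = 0;          // all tokens on sequence 0
        batch.logits[i]    = i == n - 1; // request logits for the last token only
    }

    const int32_t ret = llama_decode(ctx, batch); // 0 = ok, 1 = no KV slot, < 0 = invalid input
    if (ret != 0) {
        fprintf(stderr, "llama_decode failed with status %d\n", ret);
    }

    llama_batch_free(batch);
    return ret == 0;
}
```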