@@ -3341,6 +3341,37 @@ struct server_context {
             common_set_adapter_lora(ctx, slot_batched->lora);
         }
 
+        const bool do_encode = (params_base.embedding || params_base.reranking);
+
+        // pad the batch so that batch.n_tokens >= n_slots
+        // TODO: temporary workaround for https://github.com/ggml-org/llama.cpp/issues/13689
+        if (do_encode) {
+            const int n_slots = slots.size();
+
+            if (batch.n_tokens < n_slots) {
+                std::set<llama_seq_id> seq_ids;
+                for (int j = 0; j < batch.n_tokens; ++j) {
+                    seq_ids.insert(batch.seq_id[j][0]);
+                }
+
+                // find unused sequence id
+                llama_seq_id seq_id = -1;
+                for (int i = 0; i < n_slots; ++i) {
+                    if (seq_ids.find(i) == seq_ids.end()) {
+                        seq_id = i;
+                    }
+                }
+
+                const int n_add = n_slots - batch.n_tokens;
+
+                SRV_WRN("adding %d dummy tokens to the batch, seq_id = %d\n", n_add, seq_id);
+
+                for (int j = 0; j < n_add; ++j) {
+                    common_batch_add(batch, 0, j, { seq_id }, false);
+                }
+            }
+        }
+
         // process the created batch of tokens
         for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
             const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);
@@ -3357,7 +3388,7 @@ struct server_context {
 
             int ret = 0;
 
-            if (params_base.embedding || params_base.reranking) {
+            if (do_encode) {
                 ret = llama_encode(ctx, batch_view);
             } else {
                 ret = llama_decode(ctx, batch_view);
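For reference, below is a minimal standalone sketch of the padding step from the first hunk: collect the sequence ids already present in the batch, pick a slot index that none of them uses, and append dummy tokens on that sequence until the batch holds at least n_slots entries. It uses plain int in place of llama_seq_id/llama_token, a small struct plus std::vector in place of llama_batch, and printf in place of SRV_WRN; the Tok and pad_batch names are illustrative and not part of llama.cpp.

// Standalone sketch of the batch-padding workaround (illustrative names only).
#include <cstdio>
#include <set>
#include <vector>

struct Tok {
    int  token;   // dummy token id (0 in the workaround)
    int  pos;     // position within the padding sequence
    int  seq_id;  // sequence the token is assigned to
    bool logits;  // whether output is requested for this token (false for padding)
};

// Pad `batch` with dummy tokens until it holds at least n_slots entries,
// assigning them to a sequence id that no real token in the batch uses.
static void pad_batch(std::vector<Tok> & batch, int n_slots) {
    if ((int) batch.size() >= n_slots) {
        return;
    }

    // collect the sequence ids already present in the batch
    std::set<int> seq_ids;
    for (const Tok & t : batch) {
        seq_ids.insert(t.seq_id);
    }

    // pick a slot index that is not used as a sequence id
    int seq_id = -1;
    for (int i = 0; i < n_slots; ++i) {
        if (seq_ids.find(i) == seq_ids.end()) {
            seq_id = i;
        }
    }

    const int n_add = n_slots - (int) batch.size();

    std::printf("adding %d dummy tokens to the batch, seq_id = %d\n", n_add, seq_id);

    for (int j = 0; j < n_add; ++j) {
        batch.push_back({ /*token =*/ 0, /*pos =*/ j, seq_id, /*logits =*/ false });
    }
}

int main() {
    // e.g. 4 slots but only 2 real tokens, occupying sequences 0 and 1
    std::vector<Tok> batch = { { 17, 0, 0, true }, { 23, 0, 1, true } };
    pad_batch(batch, 4); // appends 2 dummy tokens on an unused sequence (3 here)
    std::printf("batch.n_tokens = %zu\n", batch.size());
    return 0;
}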