llama : add llama_max_parallel_sequences() · ggml-org/llama.cpp@eda2e13

Commit eda2e13

llama : add llama_max_parallel_sequences()
ggml-ci
1 parent 44856a7 commit eda2e13

5 files changed, +16 −3 lines

include/llama.h

Lines changed: 1 addition & 0 deletions
@@ -471,6 +471,7 @@ extern "C" {
     LLAMA_API int64_t llama_time_us(void);
 
     LLAMA_API size_t llama_max_devices(void);
+    LLAMA_API size_t llama_max_parallel_sequences(void);
 
     LLAMA_API bool llama_supports_mmap (void);
     LLAMA_API bool llama_supports_mlock (void);

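The new function lets client code query the sequence cap at runtime instead of hard-coding it. A minimal caller-side sketch, assuming only the API declared above; the n_requested value and the messages are illustrative, not from the commit:

#include <stdio.h>
#include "llama.h"

int main(void) {
    // Ask the library how many parallel sequences it supports.
    size_t n_max = llama_max_parallel_sequences();

    size_t n_requested = 128; // hypothetical user-supplied --parallel value
    if (n_requested > n_max) {
        fprintf(stderr, "requested %zu sequences, library supports at most %zu\n",
                n_requested, n_max);
        n_requested = n_max; // fall back to the supported maximum
    }
    printf("running with %zu parallel sequences\n", n_requested);
    return 0;
}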
src/llama-context.cpp

Lines changed: 6 additions & 1 deletion
@@ -25,7 +25,12 @@ llama_context::llama_context(
 
     const auto & hparams = model.hparams;
 
-    cparams.n_seq_max = std::max(1u, params.n_seq_max);
+    cparams.n_seq_max = std::max(1u, params.n_seq_max);
+    if (cparams.n_seq_max > LLAMA_MAX_PARALLEL_SEQUENCES) {
+        LLAMA_LOG_WARN("%s: n_seq_max (%d) is larger than the maximum supported (%d) - clamping\n", __func__, cparams.n_seq_max, LLAMA_MAX_PARALLEL_SEQUENCES);
+        cparams.n_seq_max = LLAMA_MAX_PARALLEL_SEQUENCES;
+    }
 
     cparams.n_threads       = params.n_threads;
     cparams.n_threads_batch = params.n_threads_batch;
     cparams.yarn_ext_factor = params.yarn_ext_factor;

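To illustrate the caller-visible effect, a hedged sketch assuming the existing llama.h entry points llama_context_default_params() and llama_init_from_model(); the log line is reconstructed from the format string above, not captured output:

// model is assumed to be a previously loaded llama_model *
llama_context_params params = llama_context_default_params();
params.n_seq_max = 128; // above LLAMA_MAX_PARALLEL_SEQUENCES (64)

llama_context * ctx = llama_init_from_model(model, params);
// expected warning, approximately:
//   llama_context: n_seq_max (128) is larger than the maximum supported (64) - clamping
// the context is created with n_seq_max == 64 rather than failing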
src/llama-cparams.cpp

Lines changed: 4 additions & 0 deletions
@@ -1 +1,5 @@
 #include "llama-cparams.h"
+
+size_t llama_max_parallel_sequences(void) {
+    return LLAMA_MAX_PARALLEL_SEQUENCES;
+}

src/llama-cparams.h

Lines changed: 2 additions & 0 deletions
@@ -4,6 +4,8 @@
 
 #include <cstdint>
 
+#define LLAMA_MAX_PARALLEL_SEQUENCES 64
+
 struct llama_cparams {
     uint32_t n_ctx; // context size used during inference
     uint32_t n_batch;

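With the constant in a shared header, internal translation units can enforce the cap at compile time instead of the runtime assert suggested by the TODO in llama-kv-cells.h below. A minimal sketch; MY_SEQ_SLOTS is a hypothetical internal constant, not part of the commit:

#include "llama-cparams.h"

// Hypothetical internal constant that must stay within the library cap.
constexpr unsigned MY_SEQ_SLOTS = 64;
static_assert(MY_SEQ_SLOTS <= LLAMA_MAX_PARALLEL_SEQUENCES,
              "slot count exceeds LLAMA_MAX_PARALLEL_SEQUENCES");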
src/llama-kv-cells.h

Lines changed: 3 additions & 2 deletions
@@ -1,6 +1,7 @@
 #pragma once
 
 #include "llama.h"
+#include "llama-cparams.h"
 
 #include <bitset>
 #include <cassert>
@@ -119,7 +120,7 @@ class llama_kv_cells_unified {
         seq[i].reset(seq_id);
 
         if (seq[i].none()) {
-            pos[i]= -1;
+            pos[i] = -1;
 
             used--;
@@ -267,6 +268,6 @@ class llama_kv_cells_unified {
     std::vector<llama_pos> shift;
 
     // TODO: assert n_seq_max <= 64
-    std::vector<std::bitset<64>> seq;
+    std::vector<std::bitset<LLAMA_MAX_PARALLEL_SEQUENCES>> seq;
 };

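The macro's main consumer is the last hunk above: each KV cell records which sequences reference it in a std::bitset, whose width must be fixed at compile time. A simplified standalone sketch of that bookkeeping pattern, not the actual llama_kv_cells_unified implementation:

#include <bitset>
#include <vector>

constexpr size_t MAX_SEQ = 64; // stands in for LLAMA_MAX_PARALLEL_SEQUENCES

struct kv_cells {
    std::vector<std::bitset<MAX_SEQ>> seq; // per-cell sequence membership bits
    std::vector<int>                  pos; // per-cell position, -1 when unused
    int used = 0;                          // number of occupied cells

    // Remove seq_id from cell i; release the cell once nothing references it.
    void seq_rm(size_t i, size_t seq_id) {
        seq[i].reset(seq_id);
        if (seq[i].none()) {
            pos[i] = -1;
            used--;
        }
    }
};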