support min_length for turbomind backend (#961) · InternLM/lmdeploy@3f1c691 · GitHub


Commit 3f1c691

support min_length for turbomind backend (#961)

* support min_length
* fix lint
* disable by default
* fix step
* use min_new_tokens

1 parent 9088817 · commit 3f1c691

File tree

7 files changed, +47 -15 lines

lmdeploy/messages.py

Lines changed: 4 additions & 1 deletion
@@ -30,6 +30,8 @@ class GenerationConfig:
         random_seed (int): Seed used when sampling a token
         stop_words (List[str]): Words that stop generating further tokens
         bad_words (List[str]): Words that the engine will never generate
+        min_new_tokens (int): The minimum numbers of tokens to generate,
+            ignoring the number of tokens in the prompt.
     """
 
     n: int = 1
@@ -42,6 +44,7 @@ class GenerationConfig:
     random_seed: int = None
     stop_words: List[str] = None
     bad_words: List[str] = None
+    min_new_tokens: int = None
 
 
 @dataclass
@@ -65,7 +68,7 @@ def From(gen_config: GenerationConfig, tokenizer: Tokenizer):
         >>> tokenizer = Tokenizer('internlm/internlm-chat-7b')
         >>> gen_config = GenerationConfig(stop_words=['<eoa>'])
         >>> gen_config = EngineGenerationConfig.From(gen_config, tokenizer)
-        """ # noqa E501
+        """ # noqa E501
 
     def special_word_token_ids(words):
         if words is not None:
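
With this field in place, callers can request a floor on generated tokens through the normal GenerationConfig path. A minimal usage sketch, assuming the high-level pipeline API and a hypothetical model name (neither is part of this diff):

```python
from lmdeploy import pipeline
from lmdeploy.messages import GenerationConfig

# Hypothetical model name; any turbomind-backed chat model would do.
pipe = pipeline('internlm/internlm-chat-7b')

# Require at least 32 generated tokens; prompt tokens are not counted.
# min_new_tokens defaults to None, i.e. the penalty is disabled by default.
gen_config = GenerationConfig(min_new_tokens=32)
print(pipe(['Say hi briefly.'], gen_config=gen_config))
```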

lmdeploy/turbomind/turbomind.py

Lines changed: 4 additions & 0 deletions
@@ -648,6 +648,10 @@ def _broadcast_np(data, dtype, shape=(batch_size, )):
         inputs['input_embeddings'] = input_embeddings
         inputs['input_embedding_ranges'] = input_embedding_ranges
 
+        if gen_config.min_new_tokens is not None:
+            inputs['min_length'] = _broadcast_np(gen_config.min_new_tokens,
+                                                 np.int32)
+
         bad_words = []
         if gen_config.bad_words is not None:
             bad_words.extend(gen_config.bad_words)
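
The `_broadcast_np` helper itself is not shown in this hunk; judging from its use here it replicates the scalar across the batch. A rough, self-contained approximation (its exact signature is an assumption):

```python
import numpy as np

def _broadcast_np(data, dtype, shape=(4, )):
    """Approximation: replicate a scalar value across the batch dimension."""
    return np.full(shape, data, dtype=dtype)

# With min_new_tokens=32 and a batch of 4 sequences, turbomind receives one
# 'min_length' value per sequence.
print(_broadcast_np(32, np.int32))  # -> [32 32 32 32]
```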

src/turbomind/kernels/sampling_penalty_kernels.cu

Lines changed: 2 additions & 3 deletions
@@ -497,9 +497,8 @@ __global__ void batchApplyMinLengthPenalty(T* logits,
                                            const int vocab_size_padded)
 {
     int bid = threadIdx.x + blockIdx.x * blockDim.x;  // batch index
-    // We need +1 because sequence_lengths = max_input_length + num_gen_tokens - 1,
-    // which is equal to the length of k/v caches.
-    if (sequence_lengths[bid] + 1 - max_input_length < min_lengths[bid]) {
+    // In decoder, sequence_lengths means length of sequence that has kv cache already computed
+    if (sequence_lengths[bid] + 1 < min_lengths[bid]) {
         T mask_val = (std::is_same<T, half>::value) ? -65504.0f : -FLT_MAX;
         logits[bid * vocab_size_padded + end_ids[bid]] = mask_val;
     }
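
In plain terms, the kernel keeps the end-of-sequence token unreachable until a sequence has hit its minimum length. A NumPy sketch of the same per-sequence check (illustrative only, not the CUDA code):

```python
import numpy as np

def apply_min_length_penalty(logits, sequence_lengths, min_lengths, end_ids):
    """Mask the EOS logit while a sequence is shorter than its minimum length.

    logits: (batch, vocab) float array
    sequence_lengths: per-sequence count of tokens whose kv cache is computed
    min_lengths: per-sequence minimum total length (prompt + min_new_tokens)
    end_ids: per-sequence EOS token id
    """
    mask_val = -np.finfo(logits.dtype).max
    for bid in range(logits.shape[0]):
        # Same condition as the updated kernel; +1 accounts for the token
        # currently being generated.
        if sequence_lengths[bid] + 1 < min_lengths[bid]:
            logits[bid, end_ids[bid]] = mask_val
    return logits

# Tiny example: a 5-token prompt with min_new_tokens=8 keeps EOS (id 2) masked
# until enough tokens have been generated.
logits = np.zeros((1, 8), dtype=np.float32)
print(apply_min_length_penalty(logits, sequence_lengths=[6],
                               min_lengths=[13], end_ids=[2]))
```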

src/turbomind/layers/sampling_layers/BaseSamplingLayer.cc

Lines changed: 22 additions & 11 deletions
@@ -45,6 +45,7 @@ void BaseSamplingLayer<T>::allocateBuffer(size_t batch_size, Tensor top_k, Tenso
     repetition_penalty_ = (float*)std::realloc((void*)repetition_penalty_, batch_size * sizeof(float));
     min_lengths_        = (int*)std::realloc((void*)min_lengths_, batch_size * sizeof(int));
     skip_decode_        = (bool*)std::realloc((void*)skip_decode_, batch_size * sizeof(bool));
+    context_length_     = (int*)std::realloc((void*)context_length_, batch_size * sizeof(int));
 
     is_allocate_buffer_ = true;
 }
@@ -63,6 +64,7 @@ void BaseSamplingLayer<T>::freeBuffer()
         std::free(repetition_penalty_);
         std::free(min_lengths_);
         std::free(skip_decode_);
+        std::free(context_length_);
         is_allocate_buffer_ = false;
     }
 }
@@ -161,16 +163,23 @@ void BaseSamplingLayer<T>::setup(const size_t batch_size, const size_t beam_widt
         repetition_penalty_type_ = RepetitionPenaltyType::None;
     }
 
-    const int default_min_length = 0;
-    Tensor    min_lengths = runtime_args->at("min_length", Tensor(MEMORY_CPU, TYPE_INT32, {1}, &default_min_length));
-    if (min_lengths.size() == 1) {
-        int minlen = min_lengths.getVal<int>();
-        deviceFill(min_lengths_buf_, batch_size, minlen, stream_);
-        std::fill_n(min_lengths_, batch_size, minlen);
+    // min_length
+    if (runtime_args->isExist("min_length")) {
+        Tensor min_lengths     = runtime_args->at("min_length");
+        Tensor context_lengths = runtime_args->at("context_length");
+        Tensor prompt_lengths  = runtime_args->at("prompt_length");
+        auto   p1              = min_lengths.getPtr<int>();
+        auto   p2              = prompt_lengths.getPtr<int>();
+        for (int i = 0; i < batch_size; i++) {
+            min_lengths_[i] = p1[i] + p2[i];
+        }
+        cudaAutoCpy(min_lengths_buf_, min_lengths_, batch_size, stream_);
+        std::copy_n(context_lengths.getPtr<int>(), batch_size, context_length_);
     }
     else {
-        cudaAutoCpy(min_lengths_buf_, min_lengths.getPtr<int>(), batch_size, stream_);
-        std::copy_n(min_lengths.getPtr<int>(), batch_size, min_lengths_);
+        std::fill_n(min_lengths_, batch_size, 0);
+        deviceFill(min_lengths_buf_, batch_size, 0, stream_);
+        std::fill_n(context_length_, batch_size, 0);
     }
 }
 
@@ -300,10 +309,12 @@ void BaseSamplingLayer<T>::forward(TensorMap* output_tensors, TensorMap* input_t
         }
     }
 
-    const int  num_generated_tokens = step - max_input_length;
-    const int* min_lengths          = min_lengths_ + ite * local_batch_size;
+    const int        num_generated_tokens = step - max_input_length;
+    const int*       min_lengths          = min_lengths_ + ite * local_batch_size;
+    std::vector<int> index(local_batch_size);
+    std::iota(index.begin(), index.end(), 0);
     const bool invoke_min_length_penalty = std::any_of(
-        min_lengths, min_lengths + local_batch_size, [&](int min_length) { return min_length > num_generated_tokens; });
+        index.begin(), index.end(), [&](int i) { return min_lengths[i] > context_length_[i] + num_generated_tokens; });
     if (invoke_min_length_penalty) {
         FT_CHECK_WITH_INFO(input_tensors->isExist("end_id"), "Need end_id to apply min length penlaty");
         invokeMinLengthPenalty(logits,
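
Two things happen here: setup() folds the prompt length into the stored minimum (min_lengths_[i] = p1[i] + p2[i]), and forward() only launches the penalty kernel while at least one sequence in the mini-batch still needs it. A small host-side sketch of that gate, with names taken from the diff (Python used purely for illustration):

```python
def should_invoke_min_length_penalty(min_lengths, context_lengths,
                                     step, max_input_length):
    """True if any sequence is still below prompt_length + min_new_tokens.

    min_lengths here already include the prompt length, mirroring
    min_lengths_[i] = p1[i] + p2[i] in setup().
    """
    num_generated_tokens = step - max_input_length
    return any(m > c + num_generated_tokens
               for m, c in zip(min_lengths, context_lengths))

# Example: min_new_tokens=8 with a 5-token prompt -> min_length=13.
# After 3 generated tokens the penalty kernel still has to run.
assert should_invoke_min_length_penalty([13], [5], step=8, max_input_length=5)
```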

src/turbomind/layers/sampling_layers/BaseSamplingLayer.h

Lines changed: 1 addition & 0 deletions
@@ -47,6 +47,7 @@ class BaseSamplingLayer: public DynamicDecodeBaseLayer {
     int*  min_lengths_ = nullptr;
     bool* skip_decode_ = nullptr;
     bool  skip_any_    = false;
+    int*  context_length_ = nullptr;
 
     RepetitionPenaltyType repetition_penalty_type_ = RepetitionPenaltyType::None;

src/turbomind/models/llama/LlamaBatch.cc

Lines changed: 12 additions & 0 deletions
@@ -328,6 +328,7 @@ void LlamaBatch<T>::ProcessInferRequests(const Requests& requests)
         }
 
         // total context length (history + input)
+        state.h_prompt_length[idx]  = output_ids - output_ids_base;
         state.h_context_length[idx] = output_ids - output_ids_base;
         state.h_finished[idx]       = false;
 
@@ -698,6 +699,7 @@ void LlamaBatch<T>::CopyState(const std::vector<std::tuple<BatchState*, BatchSta
     }
 
     for (const auto& [s, d, si, di] : desc) {
+        d->h_prompt_length[di]  = s->h_prompt_length[si];
         d->h_context_length[di] = s->h_context_length[si];
         d->h_finished[di]       = s->h_finished[si];
         d->h_rope_theta[di]     = s->h_rope_theta[si];
@@ -774,6 +776,7 @@ void LlamaBatch<T>::AllocatePersistantBuffer(size_t max_batch_size)
     h_bad_words_ =
         (int*)allocator_->reMalloc(h_bad_words_, sizeof(int) * max_batch_size * kMaxStopBadWordsLen, true, true);
 
+    h_min_length_    = (int*)allocator_->reMalloc(h_min_length_, sizeof(int) * max_batch_size, true, true);
     h_runtime_top_k_ = (int*)allocator_->reMalloc(h_runtime_top_k_, sizeof(int) * max_batch_size, true, true);
     h_runtime_top_p_ = (float*)allocator_->reMalloc(h_runtime_top_p_, sizeof(float) * max_batch_size, true, true);
     h_temperature_   = (float*)allocator_->reMalloc(h_temperature_, sizeof(float) * max_batch_size, true, true);
@@ -796,6 +799,7 @@ void LlamaBatch<T>::AllocatePersistantBuffer(size_t max_batch_size)
     sampling_params_ = {
         {"stop_words_list", (std::byte*)h_stop_words_, (std::byte*)d_stop_words_},
         {"bad_words_list", (std::byte*)h_bad_words_, (std::byte*)d_bad_words_},
+        {"min_length", (std::byte*)h_min_length_, nullptr},
         {"runtime_top_k", (std::byte*)h_runtime_top_k_, nullptr},
         {"runtime_top_p", (std::byte*)h_runtime_top_p_, nullptr},
         {"temperature", (std::byte*)h_temperature_, nullptr},
@@ -828,6 +832,8 @@ void LlamaBatch<T>::AllocatePersistantBuffer(size_t max_batch_size)
         (uintptr_t*)allocator_->reMalloc(h_v_block_ptrs_, sizeof(uintptr_t) * max_block_count, false, true);
 
     for (auto& s : states_) {
+        s.h_prompt_length =
+            (int*)allocator_->reMalloc(s.h_prompt_length, sizeof(int) * max_batch_size, false, true);
         s.h_context_length =
             (int*)allocator_->reMalloc(s.h_context_length, sizeof(int) * max_batch_size, false, true);
         s.h_finished = (bool*)allocator_->reMalloc(s.h_finished, sizeof(bool) * max_batch_size * 2, false, true);
@@ -1060,6 +1066,12 @@ void LlamaBatch<T>::InitializeSampling(const GenerationState& g)
         }
     }
 
+    // MinLengthPenalty
+    if (inputs.isExist("min_length")) {
+        inputs.insert({"prompt_length", {MEMORY_CPU, TYPE_INT32, {(size_t)batch_size}, state_->h_prompt_length}});
+        inputs.insert({"context_length", {MEMORY_CPU, TYPE_INT32, {(size_t)batch_size}, state_->h_context_length}});
+    }
+
     // init for eos
     std::fill_n(h_end_ids_buf_, batch_size, model_->end_id_);
     Copy(h_end_ids_buf_, batch_size, d_end_ids_buf_);
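
Put together, when "min_length" is present the sampling layer receives three aligned per-sequence CPU arrays. Shown below as a plain Python dict purely for illustration (the real structure is a turbomind TensorMap of int32 tensors; the values describe a hypothetical batch of two requests):

```python
# Hypothetical batch: min_new_tokens=8, prompts of 5 and 12 tokens.
runtime_args = {
    'min_length':     [8, 8],    # gen_config.min_new_tokens, broadcast per sequence
    'prompt_length':  [5, 12],   # history + input tokens, never the generated ones
    'context_length': [5, 12],   # equals prompt_length when the request is enqueued
}
# BaseSamplingLayer<T>::setup() then stores
#   min_lengths_[i]    = min_length[i] + prompt_length[i]
#   context_length_[i] = context_length[i]
```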

src/turbomind/models/llama/LlamaBatch.h

Lines changed: 2 additions & 0 deletions
@@ -20,6 +20,7 @@
 namespace turbomind {
 
 struct BatchState {
+    int*  h_prompt_length;  // history + input, ignore generated
     int*  h_context_length;
     bool* h_finished;
 
@@ -249,6 +250,7 @@ class LlamaBatch {
     uintptr_t* h_k_block_ptrs_{};
     uintptr_t* h_v_block_ptrs_{};
 
+    int*   h_min_length_{};
     int*   h_runtime_top_k_{};
     float* h_runtime_top_p_{};
     float* h_temperature_{};

0 commit comments