repetition penalty output ids · InternLM/lmdeploy@62f4224 · GitHub
Commit 62f4224

repetition penalty output ids

1 parent 955fd24 commit 62f4224
File tree: 6 files changed, +37 -8 lines changed

src/turbomind/kernels/sampling_penalty_kernels.cu

Lines changed: 18 additions & 3 deletions
@@ -371,13 +371,15 @@ __global__ void batchApplyRepetitionPenalty(T* logits,
                                             const int* output_ids,
                                             const int batch_size,
                                             const int vocab_size,
+                                            const int* prompt_lengths,
                                             const int* input_lengths,
                                             const int max_input_length,
                                             const int step)
 {
-    const int   batch_idx    = blockIdx.x;
-    const float penalty      = penalties[batch_idx];
-    const int   input_length = input_lengths != nullptr ? input_lengths[batch_idx] : max_input_length;
+    const int   batch_idx     = blockIdx.x;
+    const float penalty       = penalties[batch_idx];
+    const int   input_length  = input_lengths != nullptr ? input_lengths[batch_idx] : max_input_length;
+    const int   prompt_length = prompt_lengths != nullptr ? prompt_lengths[batch_idx] : 0;

     penalty_workspace += batch_idx * step * 2;
     float* penalty_logits = (float*)penalty_workspace;
@@ -388,6 +390,10 @@ __global__ void batchApplyRepetitionPenalty(T* logits,
     // Phase 1. Find indices to penalize and keep the penalized values.
     // A vocab id can appear multiple times but should be penalized once.
     for (int index = threadIdx.x; index < step; index += blockDim.x) {
+        // skip prompt
+        if (index < prompt_length) {
+            continue;
+        }
         // Skip the padding tokens in input sequences.
         if (index >= input_length && index < max_input_length) {
             continue;
@@ -414,6 +420,10 @@ __global__ void batchApplyRepetitionPenalty(T* logits,

     // Phase 2. Replace a logit value by the penalized one.
     for (int index = threadIdx.x; index < step; index += blockDim.x) {
+        // skip prompt
+        if (index < prompt_length) {
+            continue;
+        }
         // Skip the padding tokens in input sequences.
         if (index >= input_length && index < max_input_length) {
             continue;
@@ -430,6 +440,7 @@ void invokeBatchApplyRepetitionPenalty(T* logits,
                                        const int batch_size,
                                        const int local_batch_size,
                                        const int vocab_size,
+                                       const int* prompt_lengths,
                                        const int* input_lengths,
                                        const int max_input_length,
                                        const int step,
@@ -451,6 +462,7 @@ void invokeBatchApplyRepetitionPenalty(T* logits,
                                        output_ids,
                                        batch_size,
                                        vocab_size,
+                                       prompt_lengths,
                                        input_lengths,
                                        max_input_length,
                                        step);
@@ -463,6 +475,7 @@ void invokeBatchApplyRepetitionPenalty(T* logits,
                                        output_ids,
                                        batch_size,
                                        vocab_size,
+                                       prompt_lengths,
                                        input_lengths,
                                        max_input_length,
                                        step);
@@ -479,6 +492,7 @@ template void invokeBatchApplyRepetitionPenalty(float* logits,
                                                 const int batch_size,
                                                 const int local_batch_size,
                                                 const int vocab_size,
+                                                const int* prompt_lengths,
                                                 const int* input_lengths,
                                                 const int max_input_length,
                                                 const int step,
@@ -492,6 +506,7 @@ template void invokeBatchApplyRepetitionPenalty(half* logits,
                                                 const int batch_size,
                                                 const int local_batch_size,
                                                 const int vocab_size,
+                                                const int* prompt_lengths,
                                                 const int* input_lengths,
                                                 const int max_input_length,
                                                 const int step,

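The net effect of the kernel change: the first step token positions of each sequence now fall into four ranges, and only non-prompt input tokens and generated output ids are penalized. Below is a minimal host-side C++ sketch of that selection rule; the function name and the driver in main are hypothetical, only the two skip conditions mirror the kernel above.

#include <cstdio>

// A minimal stand-alone mirror of the index filtering in
// batchApplyRepetitionPenalty (this helper is hypothetical). A position is
// penalized unless it is (a) part of the prompt prefix -- new in this
// commit -- or (b) padding between a short input and max_input_length.
static bool isPenalized(int index, int prompt_length, int input_length, int max_input_length)
{
    if (index < prompt_length) {
        return false;  // skip prompt
    }
    if (index >= input_length && index < max_input_length) {
        return false;  // skip padding tokens in input sequences
    }
    return true;  // remaining input tokens and all generated output ids
}

int main()
{
    const int prompt_length = 4, input_length = 6, max_input_length = 8, step = 12;
    for (int index = 0; index < step; ++index) {
        std::printf("index %2d -> %s\n",
                    index,
                    isPenalized(index, prompt_length, input_length, max_input_length) ? "penalize" : "skip");
    }
    return 0;
}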
src/turbomind/kernels/sampling_penalty_kernels.h

Lines changed: 1 addition & 0 deletions
@@ -45,6 +45,7 @@ void invokeBatchApplyRepetitionPenalty(T* logits,
                                        const int batch_size,
                                        const int local_batch_size,
                                        const int vocab_size,
+                                       const int* prompt_lengths,
                                        const int* input_lengths,
                                        const int max_input_length,
                                        const int step,

src/turbomind/layers/sampling_layers/BaseSamplingLayer.cc

Lines changed: 12 additions & 0 deletions
@@ -39,6 +39,8 @@ void BaseSamplingLayer<T>::allocateBuffer(size_t batch_size, Tensor top_k, Tenso
         allocator_->reMalloc(runtime_logits_buf_, sizeof(T) * batch_size * vocab_size_padded_, false));
     skip_decode_buf_ =
         reinterpret_cast<bool*>(allocator_->reMalloc(skip_decode_buf_, sizeof(bool) * batch_size, false));
+    prompt_lengths_buf_ =
+        reinterpret_cast<int*>(allocator_->reMalloc(prompt_lengths_buf_, sizeof(int) * batch_size, false));

     // host buffers.
     temperature_ = (float*)std::realloc((void*)temperature_, batch_size * sizeof(float));
@@ -59,6 +61,7 @@ void BaseSamplingLayer<T>::freeBuffer()
     allocator_->free((void**)(&temperature_buf_));
     allocator_->free((void**)(&repetition_penalty_buf_));
     allocator_->free((void**)(&min_lengths_buf_));
+    allocator_->free((void**)(&prompt_lengths_buf_));
     allocator_->free((void**)(&runtime_logits_buf_));
     allocator_->free((void**)(&skip_decode_buf_));
     std::free(temperature_);
@@ -164,6 +167,14 @@ void BaseSamplingLayer<T>::setup(const size_t batch_size, const size_t beam_widt
         repetition_penalty_type_ = RepetitionPenaltyType::None;
     }

+    if (runtime_args->isExist("prompt_length")) {
+        Tensor prompt_lengths = runtime_args->at("prompt_length");
+        cudaAutoCpy(prompt_lengths_buf_, prompt_lengths.getPtr<int>(), batch_size, stream_);
+    }
+    else {
+        deviceFill(prompt_lengths_buf_, batch_size, 0, stream_);
+    }
+
     // min_length
     if (runtime_args->isExist("min_length")) {
         Tensor min_lengths = runtime_args->at("min_length");
@@ -304,6 +315,7 @@ void BaseSamplingLayer<T>::forward(TensorMap* output_tensors, TensorMap* input_t
                                           batch_size,
                                           local_batch_size,
                                           vocab_size_padded_,
+                                          prompt_lengths_buf_ + ite * local_batch_size,
                                           input_tensors->at("input_lengths", Tensor{MEMORY_GPU, TYPE_INT32, {}, nullptr}).getPtr<int>(),
                                           max_input_length,
                                           step,
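The setup path guarantees the kernel's per-batch lookup is always valid: prompt lengths are copied to the device when the caller provides a "prompt_length" tensor, and zero-filled otherwise, which degrades to the old penalize-everything behavior. A minimal sketch of that fallback, using std::vector and std::optional as hypothetical stand-ins for the device buffer, cudaAutoCpy, deviceFill, and the runtime_args lookup:

#include <algorithm>
#include <optional>
#include <vector>

// Hypothetical host-side analogue of the new branch in
// BaseSamplingLayer<T>::setup: copy the caller's per-sequence prompt
// lengths when "prompt_length" exists, otherwise zero-fill the buffer.
std::vector<int> setupPromptLengths(const std::optional<std::vector<int>>& runtime_prompt_lengths,
                                    size_t batch_size)
{
    std::vector<int> prompt_lengths_buf(batch_size, 0);
    if (runtime_prompt_lengths) {
        std::copy_n(runtime_prompt_lengths->begin(), batch_size, prompt_lengths_buf.begin());
    }
    return prompt_lengths_buf;  // all zeros means "skip nothing"
}

int main()
{
    auto provided = setupPromptLengths(std::vector<int>{5, 5, 0}, 3);  // caller passed prompt_length
    auto fallback = setupPromptLengths(std::nullopt, 3);               // feature unused
    return (provided[0] == 5 && fallback[0] == 0) ? 0 : 1;
}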

src/turbomind/layers/sampling_layers/BaseSamplingLayer.h

Lines changed: 1 addition & 0 deletions
@@ -43,6 +43,7 @@ class BaseSamplingLayer: public DynamicDecodeBaseLayer {
     int*  min_lengths_buf_    = nullptr;
     bool* skip_decode_buf_    = nullptr;
     T*    runtime_logits_buf_ = nullptr;
+    int*  prompt_lengths_buf_ = nullptr;

     float* temperature_        = nullptr;
     float* repetition_penalty_ = nullptr;

src/turbomind/models/llama/LlamaBatch.cc

Lines changed: 3 additions & 5 deletions
@@ -1066,11 +1066,9 @@ void LlamaBatch<T>::InitializeSampling(const GenerationState& g)
         }
     }

-    // MinLengthPenalty
-    if (inputs.isExist("min_length")) {
-        inputs.insert({"prompt_length", {MEMORY_CPU, TYPE_INT32, {(size_t)batch_size}, state_->h_prompt_length}});
-        inputs.insert({"context_length", {MEMORY_CPU, TYPE_INT32, {(size_t)batch_size}, state_->h_context_length}});
-    }
+    // MinLengthPenalty & RepetitionPenalty
+    inputs.insert({"context_length", {MEMORY_CPU, TYPE_INT32, {(size_t)batch_size}, state_->h_context_length}});
+    inputs.insert({"prompt_length", {MEMORY_CPU, TYPE_INT32, {(size_t)batch_size}, state_->h_prompt_length}});

     // init for eos
     std::fill_n(h_end_ids_buf_, batch_size, model_->end_id_);
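LlamaBatch now publishes both length tensors unconditionally, since the repetition penalty needs prompt_length even when no min_length was requested. A small sketch of the invariant the host buffers are assumed to satisfy (the checker below is hypothetical, not part of LlamaBatch):

#include <cassert>

// Hypothetical check: for every sequence the prompt must be a prefix of
// the context processed so far, so that "skip the first prompt_length
// positions" is well defined.
void checkLengthInvariant(const int* h_prompt_length, const int* h_context_length, int batch_size)
{
    for (int i = 0; i < batch_size; ++i) {
        assert(0 <= h_prompt_length[i]);
        assert(h_prompt_length[i] <= h_context_length[i]);
    }
}

int main()
{
    const int prompt[]  = {4, 0};
    const int context[] = {9, 2};
    checkLengthInvariant(prompt, context, 2);
    return 0;
}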

tests/csrc/unittests/test_penalty_kernels.cu

Lines changed: 2 additions & 0 deletions
@@ -509,6 +509,7 @@ public:
                                           batch_size_,
                                           batch_size_,
                                           vocab_size_padded_,
+                                          nullptr,
                                           d_input_lengths_,
                                           max_input_length_,
                                           step_,
@@ -568,6 +569,7 @@ public:
                                           batch_size_,
                                           batch_size_,
                                           vocab_size_padded_,
+                                          nullptr,
                                           d_input_lengths_,
                                           max_input_length_,
                                           step_,
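The updated tests pass nullptr for the new parameter, which exercises the backward-compatible default: a null prompt_lengths resolves to a prompt length of zero inside the kernel, so nothing is skipped and pre-commit results are unchanged. A tiny sketch of that fallback (the helper name is hypothetical; the ternary matches the kernel):

#include <cassert>

// Minimal sketch of the kernel's nullptr fallback for prompt_lengths.
static int effectivePromptLength(const int* prompt_lengths, int batch_idx)
{
    return prompt_lengths != nullptr ? prompt_lengths[batch_idx] : 0;
}

int main()
{
    const int lengths[] = {3, 7};
    assert(effectivePromptLength(lengths, 1) == 7);  // per-sequence prompt length
    assert(effectivePromptLength(nullptr, 0) == 0);  // tests pass nullptr: nothing skipped
    return 0;
}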
