From 908881734b1e08f7f466f8b80d697f00d8395800 Mon Sep 17 00:00:00 2001 From: irexyc Date: Sat, 27 Jan 2024 14:43:19 +0000 Subject: [PATCH 1/5] support lora --- .../turbomind/deploy/target_model/base.py | 1 + src/turbomind/models/llama/LlamaBatch.cc | 10 ++-- src/turbomind/models/llama/LlamaBatch.h | 1 + .../models/llama/LlamaDecoderLayerWeight.cc | 25 +++++++--- .../models/llama/LlamaDecoderLayerWeight.h | 2 + src/turbomind/models/llama/LlamaDenseWeight.h | 1 + src/turbomind/models/llama/LlamaFfnLayer.cc | 15 ++++-- src/turbomind/models/llama/LlamaLinear.h | 40 +++++++++++++-- src/turbomind/models/llama/LlamaV2.cc | 49 +++++++++++++++++-- src/turbomind/models/llama/LlamaV2.h | 13 ++++- src/turbomind/models/llama/LlamaWeight.cc | 4 +- src/turbomind/models/llama/LlamaWeight.h | 1 + src/turbomind/models/llama/SequenceManager.cc | 3 +- .../models/llama/llama_decoder_kernels.cu | 43 ++++++++++++++++ .../models/llama/llama_decoder_kernels.h | 4 ++ .../models/llama/unified_attention_layer.cc | 16 ++++-- src/turbomind/models/llama/unified_decoder.cc | 5 +- .../triton_backend/llama/LlamaTritonModel.cc | 3 ++ .../triton_backend/llama/LlamaTritonModel.h | 1 + 19 files changed, 204 insertions(+), 33 deletions(-) diff --git a/lmdeploy/turbomind/deploy/target_model/base.py b/lmdeploy/turbomind/deploy/target_model/base.py index 4f4ac3e4c5..9c600b4ff1 100644 --- a/lmdeploy/turbomind/deploy/target_model/base.py +++ b/lmdeploy/turbomind/deploy/target_model/base.py @@ -66,6 +66,7 @@ class TurbomindModelConfig: max_position_embeddings: int = 0 rope_scaling_factor: float = 0.0 use_logn_attn: int = 0 + lora_policy: int = 0 @classmethod def from_dict(cls, env, allow_none=False): diff --git a/src/turbomind/models/llama/LlamaBatch.cc b/src/turbomind/models/llama/LlamaBatch.cc index bdcd5b8519..89e0b45bf6 100644 --- a/src/turbomind/models/llama/LlamaBatch.cc +++ b/src/turbomind/models/llama/LlamaBatch.cc @@ -722,10 +722,12 @@ void LlamaBatch::AllocateBuffer(size_t batch_size, size_t session_len) context_decoder_input_buf_ = (T*)allocator_->reMalloc(context_decoder_input_buf_, sizeof(T) * max_context_token_num_ * hidden_units, false); - context_decoder_output_buf_ = - (T*)allocator_->reMalloc(context_decoder_output_buf_, sizeof(T) * max_context_token_num_ * hidden_units, false); + // double buffer for lora + context_decoder_output_buf_ = (T*)allocator_->reMalloc( + context_decoder_output_buf_, sizeof(T) * max_context_token_num_ * hidden_units * 2, false); context_decoder_ids_buf_ = (int*)allocator_->reMalloc(context_decoder_ids_buf_, sizeof(int) * max_context_token_num_, false); + lora_mask_buf_ = (int*)allocator_->reMalloc(lora_mask_buf_, sizeof(int) * max_context_token_num_, false); tmp_k_cache_buf_ = (T*)allocator_->reMalloc( tmp_k_cache_buf_, sizeof(T) * max_context_token_num_ * local_kv_head_num * head_dim, false); @@ -850,6 +852,7 @@ void LlamaBatch::FreeBuffer() allocator_->free((void**)&context_decoder_input_buf_); allocator_->free((void**)&context_decoder_output_buf_); allocator_->free((void**)&context_decoder_ids_buf_); + allocator_->free((void**)&lora_mask_buf_); allocator_->free((void**)&tmp_k_cache_buf_); allocator_->free((void**)&tmp_v_cache_buf_); @@ -1586,7 +1589,8 @@ bool LlamaBatch::Forward(GenerationState& g, int iter) max_context_cnts[p], max_context_cnts[p], h_input_length_buf_ + first, - sequences.data()); + sequences.data(), + lora_mask_buf_); if (iter == 0) { // compute logits of inputs if requested diff --git a/src/turbomind/models/llama/LlamaBatch.h b/src/turbomind/models/llama/LlamaBatch.h 
index 9af3b7522f..01caaefb37 100644 --- a/src/turbomind/models/llama/LlamaBatch.h +++ b/src/turbomind/models/llama/LlamaBatch.h @@ -225,6 +225,7 @@ class LlamaBatch { T* decoder_output_buf_{}; int* sequence_lengths_{}; // current sequence length int* init_ctx_lens_{}; + int* lora_mask_buf_{}; // lora float* logits_buf_{}; // combined logits float* local_logits_buf_{}; // tensor parallel local logits diff --git a/src/turbomind/models/llama/LlamaDecoderLayerWeight.cc b/src/turbomind/models/llama/LlamaDecoderLayerWeight.cc index 34c0abf86d..2f6c964eac 100644 --- a/src/turbomind/models/llama/LlamaDecoderLayerWeight.cc +++ b/src/turbomind/models/llama/LlamaDecoderLayerWeight.cc @@ -34,6 +34,7 @@ LlamaDecoderLayerWeight::LlamaDecoderLayerWeight(size_t head_num, WeightType weight_type, int group_size, bool attn_bias, + int lora_policy, size_t tensor_para_size, size_t tensor_para_rank): head_num_(head_num), @@ -43,6 +44,7 @@ LlamaDecoderLayerWeight::LlamaDecoderLayerWeight(size_t head_num, inter_size_(inter_size), weight_type_(weight_type), attn_bias_(attn_bias), + lora_policy_(lora_policy), tensor_para_size_(tensor_para_size), tensor_para_rank_(tensor_para_rank) { @@ -91,7 +93,7 @@ void freeWeights(LlamaDenseWeight& weights) } template -void mallocWeights(LlamaDenseWeight& weights, bool bias) +void mallocWeights(LlamaDenseWeight& weights, bool bias, int lora_policy) { if (bias) { deviceMalloc((T**)&weights.bias, weights.output_dims); @@ -99,6 +101,9 @@ void mallocWeights(LlamaDenseWeight& weights, bool bias) const size_t bit_size = getBitSize(weights.type); if (bit_size >= 16) { // fp16, fp32 deviceMalloc((T**)&weights.kernel, weights.input_dims * weights.output_dims); + if (lora_policy) { + deviceMalloc((T**)&weights.lora_kernel, weights.input_dims * weights.output_dims); + } } else { // int8, int4 const int factor = sizeof(float) * 8 / bit_size; @@ -244,6 +249,12 @@ void loadWeights(LlamaDenseWeight& w, } } loadWeightFromBin((T*)w.kernel, {dim0, dim1}, prefix + ".weight", type, weight_slices); + if (w.lora_kernel) { + auto dot_pos = prefix.rfind("."); + auto lora_weight_file = prefix.substr(0, dot_pos) + ".lora" + prefix.substr(dot_pos) + ".weight"; + TM_LOG_INFO("loading %s", lora_weight_file.c_str()); + loadWeightFromBin((T*)w.lora_kernel, {dim0, dim1}, lora_weight_file, type, weight_slices); + } } else { // int8, int4 const int factor = sizeof(float) * 8 / bit_size; @@ -265,19 +276,19 @@ void LlamaDecoderLayerWeight::mallocWeights() deviceMalloc((T**)&self_attn_norm_weights, hidden_units_); deviceMalloc((T**)&ffn_norm_weights, hidden_units_); - turbomind::mallocWeights(self_attn_weights.qkv, attn_bias_); - turbomind::mallocWeights(self_attn_weights.output, attn_bias_); + turbomind::mallocWeights(self_attn_weights.qkv, attn_bias_, lora_policy_); + turbomind::mallocWeights(self_attn_weights.output, attn_bias_, lora_policy_); self_attn_weights.past_kv_scale = {1.f, 0.f, 1.f, 0.f}; if (weight_type_ == WeightType::kINT4) { - turbomind::mallocWeights(ffn_weights.fused_gating_intermediate, false); + turbomind::mallocWeights(ffn_weights.fused_gating_intermediate, false, lora_policy_); } else { - turbomind::mallocWeights(ffn_weights.gating, false); - turbomind::mallocWeights(ffn_weights.intermediate, false); + turbomind::mallocWeights(ffn_weights.gating, false, lora_policy_); + turbomind::mallocWeights(ffn_weights.intermediate, false, lora_policy_); } - turbomind::mallocWeights(ffn_weights.output, false); + turbomind::mallocWeights(ffn_weights.output, false, lora_policy_); } template diff --git 
a/src/turbomind/models/llama/LlamaDecoderLayerWeight.h b/src/turbomind/models/llama/LlamaDecoderLayerWeight.h index 169a3aa9e6..0c36b7f601 100644 --- a/src/turbomind/models/llama/LlamaDecoderLayerWeight.h +++ b/src/turbomind/models/llama/LlamaDecoderLayerWeight.h @@ -36,6 +36,7 @@ struct LlamaDecoderLayerWeight { WeightType weight_type, int group_size, bool attn_bias, + int lora_policy, size_t tensor_para_size, size_t tensor_para_rank); ~LlamaDecoderLayerWeight(); @@ -60,6 +61,7 @@ struct LlamaDecoderLayerWeight { WeightType weight_type_; size_t bit_size_; bool attn_bias_; + int lora_policy_; size_t tensor_para_size_; size_t tensor_para_rank_; bool is_maintain_buffer_ = false; diff --git a/src/turbomind/models/llama/LlamaDenseWeight.h b/src/turbomind/models/llama/LlamaDenseWeight.h index 369f26c736..05408f40fe 100644 --- a/src/turbomind/models/llama/LlamaDenseWeight.h +++ b/src/turbomind/models/llama/LlamaDenseWeight.h @@ -59,6 +59,7 @@ struct LlamaDenseWeight { size_t input_dims; size_t output_dims; void* kernel; + void* lora_kernel; WeightType type; T* bias; T* scales_and_zeros; diff --git a/src/turbomind/models/llama/LlamaFfnLayer.cc b/src/turbomind/models/llama/LlamaFfnLayer.cc index 42575af665..fb55abb715 100644 --- a/src/turbomind/models/llama/LlamaFfnLayer.cc +++ b/src/turbomind/models/llama/LlamaFfnLayer.cc @@ -88,6 +88,14 @@ void LlamaFfnLayer::forward(TensorMap* output_tensors, const T* ffn_input_data = input_tensors->at("ffn_input").getPtr(); T* ffn_output_data = output_tensors->at("ffn_output").getPtr(); + // lora + int* lora_mask = nullptr; + if (input_tensors->isExist("lora_mask")) { + lora_mask = input_tensors->at("lora_mask").getPtr(); + inter_buf_ = (T*)allocator_->reMalloc(inter_buf_, sizeof(T) * num_token * inter_size_ * 2, false); + gating_buf_ = (T*)allocator_->reMalloc(gating_buf_, sizeof(T) * num_token * inter_size_ * 2, false); + } + if (weights->fused_gating_intermediate.kernel) { NvtxScope scope("fused_silu_ffn"); linear_.forward( @@ -96,11 +104,12 @@ void LlamaFfnLayer::forward(TensorMap* output_tensors, else { { // w1(x) NvtxScope scope("w1"); - linear_.forward(gating_buf_, ffn_input_data, num_token, weights->gating); + linear_.forward(gating_buf_, ffn_input_data, num_token, weights->gating, LlamaLinear::kGemm, lora_mask); } { // w3(x) NvtxScope scope("w3"); - linear_.forward(inter_buf_, ffn_input_data, num_token, weights->intermediate); + linear_.forward( + inter_buf_, ffn_input_data, num_token, weights->intermediate, LlamaLinear::kGemm, lora_mask); } // silu(w1(x)) * w3(x) activation(num_token); @@ -108,7 +117,7 @@ void LlamaFfnLayer::forward(TensorMap* output_tensors, { // w2(x) NvtxScope scope("w2"); - linear_.forward(ffn_output_data, gating_buf_, num_token, weights->output); + linear_.forward(ffn_output_data, gating_buf_, num_token, weights->output, LlamaLinear::kGemm, lora_mask); } if (tensor_para_.world_size_ > 1) { diff --git a/src/turbomind/models/llama/LlamaLinear.h b/src/turbomind/models/llama/LlamaLinear.h index a3717b2a90..02ac80d9de 100644 --- a/src/turbomind/models/llama/LlamaLinear.h +++ b/src/turbomind/models/llama/LlamaLinear.h @@ -4,6 +4,7 @@ #include "src/turbomind/kernels/gemm_s_f16/gemm_s4_f16.h" #include "src/turbomind/models/llama/LlamaDenseWeight.h" +#include "src/turbomind/models/llama/llama_decoder_kernels.h" #include "src/turbomind/models/llama/llama_kernels.h" #include "src/turbomind/utils/cublasMMWrapper.h" #include "src/turbomind/utils/cuda_utils.h" @@ -25,14 +26,18 @@ class LlamaLinear { { } - void - forward(T* output_data, const 
T* input_data, int batch_size, const LlamaDenseWeight& weight, Type type = kGemm) + void forward(T* output_data, + const T* input_data, + int batch_size, + const LlamaDenseWeight& weight, + Type type = kGemm, + int* lora_mask = nullptr) { switch (weight.type) { case WeightType::kFP16: case WeightType::kFP32: case WeightType::kBF16: - forwardFp(output_data, input_data, batch_size, weight, type); + forwardFp(output_data, input_data, batch_size, weight, type, lora_mask); break; case WeightType::kINT4: forwardInt4(output_data, input_data, batch_size, weight, type); @@ -43,7 +48,12 @@ class LlamaLinear { } private: - void forwardFp(T* output_data, const T* input_data, int batch_size, const LlamaDenseWeight& weight, Type type) + void forwardFp(T* output_data, + const T* input_data, + int batch_size, + const LlamaDenseWeight& weight, + Type type, + int* lora_mask) { FT_CHECK(type == kGemm); cublas_wrapper_->Gemm(CUBLAS_OP_N, @@ -58,6 +68,28 @@ class LlamaLinear { output_data, weight.output_dims); sync_check_cuda_error(); + + if (lora_mask && weight.lora_kernel) { + cublas_wrapper_->Gemm(CUBLAS_OP_N, + CUBLAS_OP_N, + weight.output_dims, + batch_size, + weight.input_dims, + (const T*)weight.lora_kernel, + weight.output_dims, + input_data, + weight.input_dims, + output_data + batch_size * weight.output_dims, + weight.output_dims); + + invokeMaskAddTwoLinearOutput(output_data, + output_data + batch_size * weight.output_dims, + lora_mask, + batch_size, + weight.output_dims, + stream_); + sync_check_cuda_error(); + } } void forwardInt4(T* output_data, const T* input_data, int batch_size, const LlamaDenseWeight& weight, Type type) diff --git a/src/turbomind/models/llama/LlamaV2.cc b/src/turbomind/models/llama/LlamaV2.cc index 87772c3e3d..4b1767ab81 100644 --- a/src/turbomind/models/llama/LlamaV2.cc +++ b/src/turbomind/models/llama/LlamaV2.cc @@ -63,6 +63,7 @@ LlamaV2::LlamaV2(size_t head_num, cublasMMWrapper* cublas_wrapper, IAllocator* allocator, bool is_free_buffer_after_forward, + int lora_policy, cudaDeviceProp* cuda_device_prop): head_num_(head_num), size_per_head_(size_per_head), @@ -84,6 +85,7 @@ LlamaV2::LlamaV2(size_t head_num, allocator_(allocator), is_free_buffer_after_forward_(is_free_buffer_after_forward), cuda_device_prop_(cuda_device_prop), + lora_policy_(lora_policy), debug_(isDebug()), shared_state_(shared_state) @@ -166,10 +168,20 @@ void LlamaV2::embeddingLookup(T* embeddings, const int* token_ids_buf, int ba } template -void LlamaV2::updateEmbedding(T* decoder_input, const int bsz, const int* h_input_length, const Sequence** sequences) +void LlamaV2::updateEmbedding(T* decoder_input, + const int bsz, + const int* h_input_length, + const Sequence** sequences, + int token_num, + int* lora_mask, + bool* have_embeddings) { TM_LOG_DEBUG(__PRETTY_FUNCTION__); + std::vector mask(token_num); + int* mask_ptr = mask.data(); + *have_embeddings = false; + for (int i = 0; i < bsz; i++) { const auto& seq = *sequences[i]; const auto& embeddings = seq.input_embeddings; @@ -177,18 +189,33 @@ void LlamaV2::updateEmbedding(T* decoder_input, const int bsz, const int* h_i for (int j = embeddings.size() - 1; j >= 0; j--) { int begin = ranges[j].first; int end = ranges[j].second; + if (seq.cache_len + h_input_length[i] - 1 < begin) { + continue; + } if (end <= seq.cache_len) { break; } - int off_dst = std::max(0, begin - seq.cache_len); - int off_src = std::max(0, seq.cache_len - begin); + int off_dst = std::max(0, begin - seq.cache_len); + int off_src = std::max(0, seq.cache_len - begin); + // calculate 
intersection of [begin, end) and [seq.cache_len, seq.cache_len + h_input_length[i]) + begin = std::max(begin, seq.cache_len); + end = std::min(end, seq.cache_len + h_input_length[i]); size_t byte_size = (end - begin) * hidden_units_ * sizeof(T); T* dst_ptr = decoder_input + off_dst * hidden_units_; auto src_ptr = embeddings[j].data() + off_src * hidden_units_ * sizeof(T); cudaMemcpyAsync(dst_ptr, src_ptr, byte_size, cudaMemcpyDefault, stream_); + std::fill_n(mask_ptr + off_dst, (end - begin), 1); + *have_embeddings = true; } decoder_input += h_input_length[i] * hidden_units_; + mask_ptr += h_input_length[i]; + } + + if (lora_policy_ && *have_embeddings) { + cudaMemcpyAsync(lora_mask, mask.data(), sizeof(int) * token_num, cudaMemcpyDefault, stream_); + cudaStreamSynchronize(stream_); } + sync_check_cuda_error(); } @@ -216,7 +243,8 @@ void LlamaV2::forwardUnified(T* out, int pf_max_context_len, int pf_session_len, const int* h_input_length, - const Sequence** sequences) + const Sequence** sequences, + int* lora_mask) { TM_LOG_DEBUG(__PRETTY_FUNCTION__); @@ -233,7 +261,14 @@ void LlamaV2::forwardUnified(T* out, hidden_units_, stream_); - updateEmbedding(decoder_input, dc_batch_size + pf_batch_size, h_input_length, sequences); + bool have_embeddings = false; + updateEmbedding(decoder_input, + dc_batch_size + pf_batch_size, + h_input_length, + sequences, + token_num, + lora_mask, + &have_embeddings); sync_check_cuda_error(); @@ -262,6 +297,10 @@ void LlamaV2::forwardUnified(T* out, {"tmp_v", {MEMORY_GPU, TYPE_UINT64, {bsz}, pf_tmp_v_ptrs}}, {"last_token_hidden_units", {MEMORY_GPU, dtype, {bsz, hidden_units_}, out}}}; + if (lora_policy_ && have_embeddings && lora_mask) { + inputs.insert({"lora_mask", {MEMORY_GPU, TYPE_INT32, {token_num}, lora_mask}}); + } + unified_decoder_->forward(&outputs, &inputs, &weights_->decoder_layer_weights); } diff --git a/src/turbomind/models/llama/LlamaV2.h b/src/turbomind/models/llama/LlamaV2.h index 551b7cb121..6354223a97 100644 --- a/src/turbomind/models/llama/LlamaV2.h +++ b/src/turbomind/models/llama/LlamaV2.h @@ -73,6 +73,7 @@ class LlamaV2 { cublasMMWrapper* cublas_wrapper, IAllocator* allocator, bool is_free_buffer_after_forward, + int lora_policy, cudaDeviceProp* cuda_device_prop); struct Control { @@ -107,7 +108,13 @@ class LlamaV2 { void embeddingLookup(T* embeddings, const int* token_ids_buf, int batch_size, int step); - void updateEmbedding(T* decoder_input, const int bsz, const int* h_input_length, const Sequence** sequences); + void updateEmbedding(T* decoder_input, + const int bsz, + const int* h_input_length, + const Sequence** sequences, + int token_num, + int* lora_mask, + bool* have_embeddings); void forwardUnified(T* out, T* decoder_output, @@ -132,7 +139,8 @@ class LlamaV2 { int pf_max_context_len, int pf_session_len, const int* h_input_length, - const Sequence** sequences); + const Sequence** sequences, + int* lora_mask); void postDecodeEmbedding(float* logits, float* local_logits, const T* decoder_output, int batch_size); @@ -163,6 +171,7 @@ class LlamaV2 { const size_t vocab_size_; size_t vocab_size_padded_; float rmsnorm_eps_ = 1e-6f; + const int lora_policy_{}; const LlamaAttentionParams attn_params_; diff --git a/src/turbomind/models/llama/LlamaWeight.cc b/src/turbomind/models/llama/LlamaWeight.cc index 6e62eaf420..eff8fa822b 100644 --- a/src/turbomind/models/llama/LlamaWeight.cc +++ b/src/turbomind/models/llama/LlamaWeight.cc @@ -32,6 +32,7 @@ LlamaWeight::LlamaWeight(size_t head_num, bool attn_bias, WeightType weight_type, int group_size, +
int lora_policy, size_t tensor_para_size, size_t tensor_para_rank): hidden_units_(head_num * size_per_head), @@ -56,6 +57,7 @@ LlamaWeight::LlamaWeight(size_t head_num, weight_type_, group_size, attn_bias, + lora_policy, tensor_para_size_, tensor_para_rank_)); } @@ -90,7 +92,7 @@ template void LlamaWeight::loadModel(std::string dir_path) { FtCudaDataType model_file_type = FtCudaDataType::FP16; - if(weight_type_ == WeightType::kBF16){ + if (weight_type_ == WeightType::kBF16) { model_file_type = FtCudaDataType::BF16; } dir_path += '/'; diff --git a/src/turbomind/models/llama/LlamaWeight.h b/src/turbomind/models/llama/LlamaWeight.h index a896a87a09..abbb91f241 100644 --- a/src/turbomind/models/llama/LlamaWeight.h +++ b/src/turbomind/models/llama/LlamaWeight.h @@ -37,6 +37,7 @@ struct LlamaWeight { bool attn_bias, WeightType weight_type, int group_size, + int lora_policy, size_t tensor_para_size, size_t tensor_para_rank); diff --git a/src/turbomind/models/llama/SequenceManager.cc b/src/turbomind/models/llama/SequenceManager.cc index 9765b6e02e..dc34a0562e 100644 --- a/src/turbomind/models/llama/SequenceManager.cc +++ b/src/turbomind/models/llama/SequenceManager.cc @@ -22,8 +22,7 @@ SequenceManager::SequenceManager(size_t layer_num, size_t elem_bits, int rank, IAllocator* allocator): - block_seq_len_(block_seq_len), - rank_(rank) + block_seq_len_(block_seq_len), rank_(rank) { constexpr int kBitsPerByte = 8; diff --git a/src/turbomind/models/llama/llama_decoder_kernels.cu b/src/turbomind/models/llama/llama_decoder_kernels.cu index 6bdfa2c5e6..6945d61608 100644 --- a/src/turbomind/models/llama/llama_decoder_kernels.cu +++ b/src/turbomind/models/llama/llama_decoder_kernels.cu @@ -188,11 +188,54 @@ void invokeFusedAddBiasResidualRMSNorm( residual, in_out, bias, scale, eps, batch_size, n_dims); } +template +__global__ void +maskAddTwoLinearOutput(T* __restrict__ output1, T* __restrict__ output2, const int* __restrict__ mask, int dim) +{ + auto block = cg::this_thread_block(); + auto grid = cg::this_grid(); + + const auto batch_idx = block.group_index().x; + if (!mask[batch_idx]) { + return; + } + + uint4* __restrict__ out1_ptr = reinterpret_cast(output1 + batch_idx * dim); + uint4* __restrict__ out2_ptr = reinterpret_cast(output2 + batch_idx * dim); + + res_norm_t ops; + constexpr int PACK_DIM = sizeof(uint4) / sizeof(T); + float thread_sum{}; + for (auto i = block.thread_rank(); i < dim / PACK_DIM; i += block.size()) { + auto o1 = out1_ptr[i]; + auto o2 = out2_ptr[i]; + uint4 b = uint4{}; + o1 = ops.addvec(o1, o2, b, thread_sum); + out1_ptr[i] = o1; + } +} + +template +void invokeMaskAddTwoLinearOutput(T* output1, T* output2, const int* mask, int batch_size, int dim, cudaStream_t stream) +{ + constexpr int PACK_DIM = sizeof(uint4) / sizeof(T); + FT_CHECK(dim % PACK_DIM == 0); + const int n_pack = dim / PACK_DIM; + const int n_iter = ((n_pack + 1023) / 1024); // iterations when block size == 1024 + int n_threads = (n_pack + n_iter - 1) / n_iter; // adjust block size to avoid tail effect + n_threads = (n_threads + 31) / 32 * 32; // round up to the nearest multiple of warp size + maskAddTwoLinearOutput<<>>(output1, output2, mask, dim); +} + +template void invokeMaskAddTwoLinearOutput(float*, float*, const int*, int, int, cudaStream_t); +template void invokeMaskAddTwoLinearOutput(half*, half*, const int*, int, int, cudaStream_t); + template void invokeFusedAddBiasResidualRMSNorm(float*, float*, const float*, const float*, float, int, int, cudaStream_t); template void 
invokeFusedAddBiasResidualRMSNorm(half*, half*, const half*, const half*, float, int, int, cudaStream_t); #ifdef ENABLE_BF16 template void invokeFusedAddBiasResidualRMSNorm( __nv_bfloat16*, __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, float, int, int, cudaStream_t); +template void invokeMaskAddTwoLinearOutput(__nv_bfloat16*, __nv_bfloat16*, const int*, int, int, cudaStream_t); #endif } // namespace turbomind diff --git a/src/turbomind/models/llama/llama_decoder_kernels.h b/src/turbomind/models/llama/llama_decoder_kernels.h index ade0dc053c..5e4593cc95 100644 --- a/src/turbomind/models/llama/llama_decoder_kernels.h +++ b/src/turbomind/models/llama/llama_decoder_kernels.h @@ -8,4 +8,8 @@ template void invokeFusedAddBiasResidualRMSNorm( T* residual, T* in_out, const T* bias, const T* scale, float eps, int batch_size, int n_dims, cudaStream_t stream); +template +void invokeMaskAddTwoLinearOutput( + T* output1, T* output2, const int* mask, int batch_size, int dim, cudaStream_t stream); + } // namespace turbomind diff --git a/src/turbomind/models/llama/unified_attention_layer.cc b/src/turbomind/models/llama/unified_attention_layer.cc index aeb8c5db48..839360cb6b 100644 --- a/src/turbomind/models/llama/unified_attention_layer.cc +++ b/src/turbomind/models/llama/unified_attention_layer.cc @@ -49,8 +49,9 @@ void UnifiedAttentionLayer::allocateBuffer(size_t num_token, const int local_q_kv_head_num = local_head_num_ + 2 * local_kv_head_num_; - // no padding - qkv_buf_ = (T*)allocator_->reMalloc(qkv_buf_, sizeof(T) * num_token * local_q_kv_head_num * size_per_head_, false); + // no padding, double buffer for lora + qkv_buf_ = + (T*)allocator_->reMalloc(qkv_buf_, sizeof(T) * num_token * local_q_kv_head_num * size_per_head_ * 2, false); // qkv_buf_3_ padding is removed qkv_buf_3_ = (T*)allocator_->reMalloc(qkv_buf_3_, sizeof(T) * num_token * local_head_num_ * size_per_head_, false); @@ -179,6 +180,11 @@ inline void UnifiedAttentionLayer::forward(TensorMap* outputs, const TensorMa T* attention_out = outputs->getPtr("hidden_features"); + int* lora_mask = nullptr; + if (inputs->isExist("lora_mask")) { + lora_mask = inputs->at("lora_mask").getPtr(); + } + ///////////////////////////////////////////// /// allocate buffers allocateBuffer(num_token, // @@ -194,7 +200,7 @@ inline void UnifiedAttentionLayer::forward(TensorMap* outputs, const TensorMa ////////////////////////////////////////////// /// qkv gemm // [token_num, hidden_dim] -> [token_num, 3, local_hidden_dim] - linear_.forward(qkv_buf_, attention_input, num_token, weights->qkv); + linear_.forward(qkv_buf_, attention_input, num_token, weights->qkv, LlamaLinear::kGemm, lora_mask); if (pf_batch_size) { const int offset = dc_batch_size; @@ -240,7 +246,7 @@ inline void UnifiedAttentionLayer::forward(TensorMap* outputs, const TensorMa ////////////////////////////////////////////// /// output gemm -> - linear_.forward(attention_out, qkv_buf_3_, num_token, weights->output); + linear_.forward(attention_out, qkv_buf_3_, num_token, weights->output, LlamaLinear::kGemm, lora_mask); if (tensor_para_.world_size_ > 1) { NcclGuard nccl_guard(tensor_para_, stream_); @@ -628,6 +634,6 @@ template class UnifiedAttentionLayer; template class UnifiedAttentionLayer; #ifdef ENABLE_BF16 template class UnifiedAttentionLayer<__nv_bfloat16>; -#endif // ENABLE_BF16 +#endif // ENABLE_BF16 } // namespace turbomind diff --git a/src/turbomind/models/llama/unified_decoder.cc b/src/turbomind/models/llama/unified_decoder.cc index 8617738d2f..bccea1fdde 100644 --- 
a/src/turbomind/models/llama/unified_decoder.cc +++ b/src/turbomind/models/llama/unified_decoder.cc @@ -219,6 +219,9 @@ void UnifiedDecoder::forward(TensorMap* outputs, const TensorMap* inputs, con /// feed-forward network TensorMap ffn_inputs{{"ffn_input", {MEMORY_GPU, dtype_, {token_num, hidden_units_}, decoder_output}}}; TensorMap ffn_outputs{{"ffn_output", {MEMORY_GPU, dtype_, {token_num, hidden_units_}, decoder_output}}}; + if (inputs->isExist("lora_mask")) { + ffn_inputs.insert({"lora_mask", inputs->at("lora_mask")}); + } ffn_layer_->forward(&ffn_outputs, &ffn_inputs, &weights->at(layer)->ffn_weights); const bool is_last_layer = layer == num_layer_ - 1; @@ -263,6 +266,6 @@ template class UnifiedDecoder; template class UnifiedDecoder; #ifdef ENABLE_BF16 template class UnifiedDecoder<__nv_bfloat16>; -#endif // ENABLE_BF16 +#endif // ENABLE_BF16 } // namespace turbomind diff --git a/src/turbomind/triton_backend/llama/LlamaTritonModel.cc b/src/turbomind/triton_backend/llama/LlamaTritonModel.cc index 77f6b19833..e45e9fe679 100644 --- a/src/turbomind/triton_backend/llama/LlamaTritonModel.cc +++ b/src/turbomind/triton_backend/llama/LlamaTritonModel.cc @@ -183,6 +183,7 @@ LlamaTritonModel::LlamaTritonModel(size_t tensor_para_size, attn_bias_ = reader.GetInteger("llama", "attn_bias", 0); quant_policy_ = reader.GetInteger("llama", "quant_policy", 0); group_size_ = reader.GetInteger("llama", "group_size", 0); + lora_policy_ = reader.GetInteger("llama", "lora_policy", 0); // rotary embedding parameters attn_params_.rotary_embedding_dim = reader.GetInteger("llama", "rotary_embedding"); @@ -308,6 +309,7 @@ std::unique_ptr> LlamaTritonModel::createSh cublas_wrapper.get(), allocator.get(), false, // is_free_buffer_after_forward, + lora_policy_, cuda_device_prop_ptr.get()); return std::make_unique>( @@ -367,6 +369,7 @@ void LlamaTritonModel::createSharedWeights(int device_id, int rank) attn_bias_, weight_type_, group_size_, + lora_policy_, tensor_para_size_, tensor_para_rank); // model inited with model_dir diff --git a/src/turbomind/triton_backend/llama/LlamaTritonModel.h b/src/turbomind/triton_backend/llama/LlamaTritonModel.h index ff086a9099..c057eb6216 100644 --- a/src/turbomind/triton_backend/llama/LlamaTritonModel.h +++ b/src/turbomind/triton_backend/llama/LlamaTritonModel.h @@ -101,6 +101,7 @@ struct LlamaTritonModel: public AbstractTransformerModel { bool attn_bias_; int quant_policy_; int group_size_; + int lora_policy_; // shared weights for each device std::vector>> shared_weights_; From 3f1c691ce98b42db691cba49185951b841ea7002 Mon Sep 17 00:00:00 2001 From: Chen Xin Date: Fri, 26 Jan 2024 19:00:43 +0800 Subject: [PATCH 2/5] support `min_length` for turbomind backend (#961) * support min_length * fix lint * disable by default * fix step * use min_new_tokens --- lmdeploy/messages.py | 5 ++- lmdeploy/turbomind/turbomind.py | 4 +++ .../kernels/sampling_penalty_kernels.cu | 5 ++- .../sampling_layers/BaseSamplingLayer.cc | 33 ++++++++++++------- .../sampling_layers/BaseSamplingLayer.h | 1 + src/turbomind/models/llama/LlamaBatch.cc | 12 +++++++ src/turbomind/models/llama/LlamaBatch.h | 2 ++ 7 files changed, 47 insertions(+), 15 deletions(-) diff --git a/lmdeploy/messages.py b/lmdeploy/messages.py index cf5ff7062d..66c4d76810 100644 --- a/lmdeploy/messages.py +++ b/lmdeploy/messages.py @@ -30,6 +30,8 @@ class GenerationConfig: random_seed (int): Seed used when sampling a token stop_words (List[str]): Words that stop generating further tokens bad_words (List[str]): Words that the engine will never 
generate + min_new_tokens (int): The minimum number of tokens to generate, + ignoring the number of tokens in the prompt. """ n: int = 1 @@ -42,6 +44,7 @@ class GenerationConfig: random_seed: int = None stop_words: List[str] = None bad_words: List[str] = None + min_new_tokens: int = None @dataclass @@ -65,7 +68,7 @@ def From(gen_config: GenerationConfig, tokenizer: Tokenizer): >>> tokenizer = Tokenizer('internlm/internlm-chat-7b') >>> gen_config = GenerationConfig(stop_words=['']) >>> gen_config = EngineGenerationConfig.From(gen_config, tokenizer) - """ # noqa E501 + """ # noqa E501 def special_word_token_ids(words): if words is not None: diff --git a/lmdeploy/turbomind/turbomind.py b/lmdeploy/turbomind/turbomind.py index 4a4dc91577..9febedefe9 100644 --- a/lmdeploy/turbomind/turbomind.py +++ b/lmdeploy/turbomind/turbomind.py @@ -648,6 +648,10 @@ def _broadcast_np(data, dtype, shape=(batch_size, )): inputs['input_embeddings'] = input_embeddings inputs['input_embedding_ranges'] = input_embedding_ranges + if gen_config.min_new_tokens is not None: + inputs['min_length'] = _broadcast_np(gen_config.min_new_tokens, + np.int32) + bad_words = [] if gen_config.bad_words is not None: bad_words.extend(gen_config.bad_words) diff --git a/src/turbomind/kernels/sampling_penalty_kernels.cu b/src/turbomind/kernels/sampling_penalty_kernels.cu index f7ebfeff03..28bf43aac9 100644 --- a/src/turbomind/kernels/sampling_penalty_kernels.cu +++ b/src/turbomind/kernels/sampling_penalty_kernels.cu @@ -497,9 +497,8 @@ __global__ void batchApplyMinLengthPenalty(T* logits, const int vocab_size_padded) { int bid = threadIdx.x + blockIdx.x * blockDim.x; // batch index - // We need +1 because sequence_lengths = max_input_length + num_gen_tokens - 1, - // which is equal to the length of k/v caches. - if (sequence_lengths[bid] + 1 - max_input_length < min_lengths[bid]) { + // In the decoder, sequence_lengths is the length of the sequence whose kv cache has already been computed + if (sequence_lengths[bid] + 1 < min_lengths[bid]) { T mask_val = (std::is_same::value) ?
-65504.0f : -FLT_MAX; logits[bid * vocab_size_padded + end_ids[bid]] = mask_val; } diff --git a/src/turbomind/layers/sampling_layers/BaseSamplingLayer.cc b/src/turbomind/layers/sampling_layers/BaseSamplingLayer.cc index 1c9ae099d9..91b6809f3f 100644 --- a/src/turbomind/layers/sampling_layers/BaseSamplingLayer.cc +++ b/src/turbomind/layers/sampling_layers/BaseSamplingLayer.cc @@ -45,6 +45,7 @@ void BaseSamplingLayer::allocateBuffer(size_t batch_size, Tensor top_k, Tenso repetition_penalty_ = (float*)std::realloc((void*)repetition_penalty_, batch_size * sizeof(float)); min_lengths_ = (int*)std::realloc((void*)min_lengths_, batch_size * sizeof(int)); skip_decode_ = (bool*)std::realloc((void*)skip_decode_, batch_size * sizeof(bool)); + context_length_ = (int*)std::realloc((void*)context_length_, batch_size * sizeof(int)); is_allocate_buffer_ = true; } @@ -63,6 +64,7 @@ void BaseSamplingLayer::freeBuffer() std::free(repetition_penalty_); std::free(min_lengths_); std::free(skip_decode_); + std::free(context_length_); is_allocate_buffer_ = false; } } @@ -161,16 +163,23 @@ void BaseSamplingLayer::setup(const size_t batch_size, const size_t beam_widt repetition_penalty_type_ = RepetitionPenaltyType::None; } - const int default_min_length = 0; - Tensor min_lengths = runtime_args->at("min_length", Tensor(MEMORY_CPU, TYPE_INT32, {1}, &default_min_length)); - if (min_lengths.size() == 1) { - int minlen = min_lengths.getVal(); - deviceFill(min_lengths_buf_, batch_size, minlen, stream_); - std::fill_n(min_lengths_, batch_size, minlen); + // min_length + if (runtime_args->isExist("min_length")) { + Tensor min_lengths = runtime_args->at("min_length"); + Tensor context_lengths = runtime_args->at("context_length"); + Tensor prompt_lengths = runtime_args->at("prompt_length"); + auto p1 = min_lengths.getPtr(); + auto p2 = prompt_lengths.getPtr(); + for (int i = 0; i < batch_size; i++) { + min_lengths_[i] = p1[i] + p2[i]; + } + cudaAutoCpy(min_lengths_buf_, min_lengths_, batch_size, stream_); + std::copy_n(context_lengths.getPtr(), batch_size, context_length_); } else { - cudaAutoCpy(min_lengths_buf_, min_lengths.getPtr(), batch_size, stream_); - std::copy_n(min_lengths.getPtr(), batch_size, min_lengths_); + std::fill_n(min_lengths_, batch_size, 0); + deviceFill(min_lengths_buf_, batch_size, 0, stream_); + std::fill_n(context_length_, batch_size, 0); } } @@ -300,10 +309,12 @@ void BaseSamplingLayer::forward(TensorMap* output_tensors, TensorMap* input_t } } - const int num_generated_tokens = step - max_input_length; - const int* min_lengths = min_lengths_ + ite * local_batch_size; + const int num_generated_tokens = step - max_input_length; + const int* min_lengths = min_lengths_ + ite * local_batch_size; + std::vector index(local_batch_size); + std::iota(index.begin(), index.end(), 0); const bool invoke_min_length_penalty = std::any_of( - min_lengths, min_lengths + local_batch_size, [&](int min_length) { return min_length > num_generated_tokens; }); + index.begin(), index.end(), [&](int i) { return min_lengths[i] > context_length_[i] + num_generated_tokens; }); if (invoke_min_length_penalty) { FT_CHECK_WITH_INFO(input_tensors->isExist("end_id"), "Need end_id to apply min length penlaty"); invokeMinLengthPenalty(logits, diff --git a/src/turbomind/layers/sampling_layers/BaseSamplingLayer.h b/src/turbomind/layers/sampling_layers/BaseSamplingLayer.h index 29462e16a2..68cf79c871 100644 --- a/src/turbomind/layers/sampling_layers/BaseSamplingLayer.h +++ b/src/turbomind/layers/sampling_layers/BaseSamplingLayer.h @@ 
-47,6 +47,7 @@ class BaseSamplingLayer: public DynamicDecodeBaseLayer { int* min_lengths_ = nullptr; bool* skip_decode_ = nullptr; bool skip_any_ = false; + int* context_length_ = nullptr; RepetitionPenaltyType repetition_penalty_type_ = RepetitionPenaltyType::None; diff --git a/src/turbomind/models/llama/LlamaBatch.cc b/src/turbomind/models/llama/LlamaBatch.cc index 89e0b45bf6..c6bfc1502e 100644 --- a/src/turbomind/models/llama/LlamaBatch.cc +++ b/src/turbomind/models/llama/LlamaBatch.cc @@ -328,6 +328,7 @@ void LlamaBatch::ProcessInferRequests(const Requests& requests) } // total context length (history + input) + state.h_prompt_length[idx] = output_ids - output_ids_base; state.h_context_length[idx] = output_ids - output_ids_base; state.h_finished[idx] = false; @@ -698,6 +699,7 @@ void LlamaBatch::CopyState(const std::vectorh_prompt_length[di] = s->h_prompt_length[si]; d->h_context_length[di] = s->h_context_length[si]; d->h_finished[di] = s->h_finished[si]; d->h_rope_theta[di] = s->h_rope_theta[si]; @@ -774,6 +776,7 @@ void LlamaBatch::AllocatePersistantBuffer(size_t max_batch_size) h_bad_words_ = (int*)allocator_->reMalloc(h_bad_words_, sizeof(int) * max_batch_size * kMaxStopBadWordsLen, true, true); + h_min_length_ = (int*)allocator_->reMalloc(h_min_length_, sizeof(int) * max_batch_size, true, true); h_runtime_top_k_ = (int*)allocator_->reMalloc(h_runtime_top_k_, sizeof(int) * max_batch_size, true, true); h_runtime_top_p_ = (float*)allocator_->reMalloc(h_runtime_top_p_, sizeof(float) * max_batch_size, true, true); h_temperature_ = (float*)allocator_->reMalloc(h_temperature_, sizeof(float) * max_batch_size, true, true); @@ -796,6 +799,7 @@ void LlamaBatch::AllocatePersistantBuffer(size_t max_batch_size) sampling_params_ = { {"stop_words_list", (std::byte*)h_stop_words_, (std::byte*)d_stop_words_}, {"bad_words_list", (std::byte*)h_bad_words_, (std::byte*)d_bad_words_}, + {"min_length", (std::byte*)h_min_length_, nullptr}, {"runtime_top_k", (std::byte*)h_runtime_top_k_, nullptr}, {"runtime_top_p", (std::byte*)h_runtime_top_p_, nullptr}, {"temperature", (std::byte*)h_temperature_, nullptr}, @@ -828,6 +832,8 @@ void LlamaBatch::AllocatePersistantBuffer(size_t max_batch_size) (uintptr_t*)allocator_->reMalloc(h_v_block_ptrs_, sizeof(uintptr_t) * max_block_count, false, true); for (auto& s : states_) { + s.h_prompt_length = + (int*)allocator_->reMalloc(s.h_prompt_length, sizeof(int) * max_batch_size, false, true); s.h_context_length = (int*)allocator_->reMalloc(s.h_context_length, sizeof(int) * max_batch_size, false, true); s.h_finished = (bool*)allocator_->reMalloc(s.h_finished, sizeof(bool) * max_batch_size * 2, false, true); @@ -1060,6 +1066,12 @@ void LlamaBatch::InitializeSampling(const GenerationState& g) } } + // MinLengthPenalty + if (inputs.isExist("min_length")) { + inputs.insert({"prompt_length", {MEMORY_CPU, TYPE_INT32, {(size_t)batch_size}, state_->h_prompt_length}}); + inputs.insert({"context_length", {MEMORY_CPU, TYPE_INT32, {(size_t)batch_size}, state_->h_context_length}}); + } + // init for eos std::fill_n(h_end_ids_buf_, batch_size, model_->end_id_); Copy(h_end_ids_buf_, batch_size, d_end_ids_buf_); diff --git a/src/turbomind/models/llama/LlamaBatch.h b/src/turbomind/models/llama/LlamaBatch.h index 01caaefb37..f040448301 100644 --- a/src/turbomind/models/llama/LlamaBatch.h +++ b/src/turbomind/models/llama/LlamaBatch.h @@ -20,6 +20,7 @@ namespace turbomind { struct BatchState { + int* h_prompt_length; // history + input, ignore generated int* h_context_length; bool* 
h_finished; @@ -249,6 +250,7 @@ class LlamaBatch { uintptr_t* h_k_block_ptrs_{}; uintptr_t* h_v_block_ptrs_{}; + int* h_min_length_{}; int* h_runtime_top_k_{}; float* h_runtime_top_p_{}; float* h_temperature_{}; From 955fd24e6359c895ae3a09ca2ebc89c6df98c180 Mon Sep 17 00:00:00 2001 From: irexyc Date: Thu, 25 Jan 2024 01:07:25 +0800 Subject: [PATCH 3/5] repetition penalty for long context --- .../kernels/sampling_penalty_kernels.cu | 54 +++++++++++-------- .../kernels/sampling_penalty_kernels.h | 1 + .../sampling_layers/BaseSamplingLayer.cc | 4 ++ .../sampling_layers/BaseSamplingLayer.h | 2 + tests/csrc/unittests/test_penalty_kernels.cu | 11 ++-- 5 files changed, 47 insertions(+), 25 deletions(-) diff --git a/src/turbomind/kernels/sampling_penalty_kernels.cu b/src/turbomind/kernels/sampling_penalty_kernels.cu index 28bf43aac9..dd1288c56f 100644 --- a/src/turbomind/kernels/sampling_penalty_kernels.cu +++ b/src/turbomind/kernels/sampling_penalty_kernels.cu @@ -367,6 +367,7 @@ template void invokeApplyRepetitionPenalty(half* logits, template __global__ void batchApplyRepetitionPenalty(T* logits, const float* penalties, + int* penalty_workspace, const int* output_ids, const int batch_size, const int vocab_size, @@ -374,11 +375,13 @@ __global__ void batchApplyRepetitionPenalty(T* logits, const int max_input_length, const int step) { - extern __shared__ float penalty_logits[]; - int* penalty_indices = (int*)(penalty_logits + step); - const int batch_idx = blockIdx.x; - const float penalty = penalties[batch_idx]; - const int input_length = input_lengths != nullptr ? input_lengths[batch_idx] : max_input_length; + const int batch_idx = blockIdx.x; + const float penalty = penalties[batch_idx]; + const int input_length = input_lengths != nullptr ? input_lengths[batch_idx] : max_input_length; + + penalty_workspace += batch_idx * step * 2; + float* penalty_logits = (float*)penalty_workspace; + int* penalty_indices = (int*)(penalty_workspace + step); logits += batch_idx * vocab_size; @@ -409,10 +412,6 @@ __global__ void batchApplyRepetitionPenalty(T* logits, } } - if (blockDim.x > 32) { - __syncthreads(); - } - // Phase 2. Replace a logit value by the penalized one. for (int index = threadIdx.x; index < step; index += blockDim.x) { // Skip the padding tokens in input sequences. @@ -426,6 +425,7 @@ __global__ void batchApplyRepetitionPenalty(T* logits, template void invokeBatchApplyRepetitionPenalty(T* logits, const float* penalties, + int* penalty_workspace, const int* output_ids, const int batch_size, const int local_batch_size, @@ -442,22 +442,30 @@ void invokeBatchApplyRepetitionPenalty(T* logits, // output_ids [step, batch_size] : output token ids (with offset ite * local_batch_size). // input_lengths [local_batch_size], input lengths (optional). // Padding tokens at [input_length, max_input_length) of input will not be penalized. 
- dim3 block(min(step, 1024)); - dim3 grid(local_batch_size); - size_t smem_size = step * (sizeof(float) + sizeof(int)); + dim3 block(min(step, 1024)); + dim3 grid(local_batch_size); if (penalty_type == RepetitionPenaltyType::Additive) { - check_cuda_error(cudaFuncSetAttribute(batchApplyRepetitionPenalty, - cudaFuncAttributeMaxDynamicSharedMemorySize, - smem_size)); - batchApplyRepetitionPenalty<<>>( - logits, penalties, output_ids, batch_size, vocab_size, input_lengths, max_input_length, step); + batchApplyRepetitionPenalty<<>>(logits, + penalties, + penalty_workspace, + output_ids, + batch_size, + vocab_size, + input_lengths, + max_input_length, + step); } else if (penalty_type == RepetitionPenaltyType::Multiplicative) { - check_cuda_error(cudaFuncSetAttribute(batchApplyRepetitionPenalty, - cudaFuncAttributeMaxDynamicSharedMemorySize, - smem_size)); - batchApplyRepetitionPenalty<<>>( - logits, penalties, output_ids, batch_size, vocab_size, input_lengths, max_input_length, step); + batchApplyRepetitionPenalty + <<>>(logits, + penalties, + penalty_workspace, + output_ids, + batch_size, + vocab_size, + input_lengths, + max_input_length, + step); } else if (penalty_type == RepetitionPenaltyType::None) { // do nothing @@ -466,6 +474,7 @@ void invokeBatchApplyRepetitionPenalty(T* logits, template void invokeBatchApplyRepetitionPenalty(float* logits, const float* penalties, + int* penalty_workspace, const int* output_ids, const int batch_size, const int local_batch_size, @@ -478,6 +487,7 @@ template void invokeBatchApplyRepetitionPenalty(float* logits, template void invokeBatchApplyRepetitionPenalty(half* logits, const float* penalties, + int* penalty_workspace, const int* output_ids, const int batch_size, const int local_batch_size, diff --git a/src/turbomind/kernels/sampling_penalty_kernels.h b/src/turbomind/kernels/sampling_penalty_kernels.h index 3c54cc15bf..e12698cdf7 100644 --- a/src/turbomind/kernels/sampling_penalty_kernels.h +++ b/src/turbomind/kernels/sampling_penalty_kernels.h @@ -40,6 +40,7 @@ void invokeApplyRepetitionPenalty(T* logits, template void invokeBatchApplyRepetitionPenalty(T* logits, const float* penalties, + int* penalty_workspace, const int* output_ids, const int batch_size, const int local_batch_size, diff --git a/src/turbomind/layers/sampling_layers/BaseSamplingLayer.cc b/src/turbomind/layers/sampling_layers/BaseSamplingLayer.cc index 91b6809f3f..3e23cbd616 100644 --- a/src/turbomind/layers/sampling_layers/BaseSamplingLayer.cc +++ b/src/turbomind/layers/sampling_layers/BaseSamplingLayer.cc @@ -55,6 +55,7 @@ void BaseSamplingLayer::freeBuffer() { TM_LOG_DEBUG(__PRETTY_FUNCTION__); if (is_allocate_buffer_) { + allocator_->free((void**)(&repetition_penalty_workspace_)); allocator_->free((void**)(&temperature_buf_)); allocator_->free((void**)(&repetition_penalty_buf_)); allocator_->free((void**)(&min_lengths_buf_)); @@ -293,9 +294,12 @@ void BaseSamplingLayer::forward(TensorMap* output_tensors, TensorMap* input_t if (step > 1 && repetition_penalty_type_ != RepetitionPenaltyType::None) { float default_value = getDefaultPenaltyValue(repetition_penalty_type_); if (!ALL_OF(repetition_penalty_ + ite * local_batch_size, local_batch_size, float, default_value)) { + repetition_penalty_workspace_ = reinterpret_cast(allocator_->reMalloc( + repetition_penalty_workspace_, batch_size * step * (sizeof(int) + sizeof(float)), false)); invokeBatchApplyRepetitionPenalty( logits, repetition_penalty_buf_ + ite * local_batch_size, + repetition_penalty_workspace_ + ite * local_batch_size, 
output_tensors->at("output_ids").getPtrWithOffset(ite * local_batch_size), batch_size, local_batch_size, diff --git a/src/turbomind/layers/sampling_layers/BaseSamplingLayer.h b/src/turbomind/layers/sampling_layers/BaseSamplingLayer.h index 68cf79c871..83d2c40f24 100644 --- a/src/turbomind/layers/sampling_layers/BaseSamplingLayer.h +++ b/src/turbomind/layers/sampling_layers/BaseSamplingLayer.h @@ -33,6 +33,8 @@ class BaseSamplingLayer: public DynamicDecodeBaseLayer { size_t vocab_size_; size_t vocab_size_padded_; + int* repetition_penalty_workspace_; + size_t sampling_workspace_size_; void* sampling_workspace_ = nullptr; diff --git a/tests/csrc/unittests/test_penalty_kernels.cu b/tests/csrc/unittests/test_penalty_kernels.cu index 86d23f44e6..301b79aa9f 100644 --- a/tests/csrc/unittests/test_penalty_kernels.cu +++ b/tests/csrc/unittests/test_penalty_kernels.cu @@ -18,10 +18,10 @@ #include // snprintf #include // expf, log #include -#include // rand -#include // std::string +#include // rand +#include // std::string #include -#include // std::vector +#include // std::vector #include #include @@ -386,6 +386,7 @@ protected: T* d_bias_; int* d_output_ids_; int* d_input_lengths_; + int* d_penalty_workspace_; float* d_repetition_penalties_; @@ -410,6 +411,8 @@ protected: d_bias_ = reinterpret_cast(allocator->malloc(sizeof(T) * vocab_size_padded_)); d_output_ids_ = reinterpret_cast(allocator->malloc(sizeof(int) * sequence_length_ * batch_size_)); d_input_lengths_ = reinterpret_cast(allocator->malloc(sizeof(int) * batch_size_)); + d_penalty_workspace_ = + reinterpret_cast(allocator->malloc((sizeof(int) + sizeof(float)) * batch_size_ * step_)); cudaAutoCpy(d_logits_, h_logits_, batch_size_ * vocab_size_padded_, stream); cudaAutoCpy(d_bias_, h_bias_, vocab_size_padded_, stream); @@ -501,6 +504,7 @@ public: else { invokeBatchApplyRepetitionPenalty(d_logits_, d_repetition_penalties_, + d_penalty_workspace_, d_output_ids_, batch_size_, batch_size_, @@ -559,6 +563,7 @@ public: cudaAutoCpy(d_logits_batch, h_logits_, batch_size_ * vocab_size_padded_, stream); invokeBatchApplyRepetitionPenalty(d_logits_batch, d_repetition_penalties_, + d_penalty_workspace_, d_output_ids_, batch_size_, batch_size_, From 62f42243b2c6aa7ccee9edf0113a4d952a2a20b5 Mon Sep 17 00:00:00 2001 From: irexyc Date: Sun, 28 Jan 2024 22:16:46 +0800 Subject: [PATCH 4/5] repetition penalty output ids --- .../kernels/sampling_penalty_kernels.cu | 21 ++++++++++++++++--- .../kernels/sampling_penalty_kernels.h | 1 + .../sampling_layers/BaseSamplingLayer.cc | 12 +++++++++++ .../sampling_layers/BaseSamplingLayer.h | 1 + src/turbomind/models/llama/LlamaBatch.cc | 8 +++---- tests/csrc/unittests/test_penalty_kernels.cu | 2 ++ 6 files changed, 37 insertions(+), 8 deletions(-) diff --git a/src/turbomind/kernels/sampling_penalty_kernels.cu b/src/turbomind/kernels/sampling_penalty_kernels.cu index dd1288c56f..df27300293 100644 --- a/src/turbomind/kernels/sampling_penalty_kernels.cu +++ b/src/turbomind/kernels/sampling_penalty_kernels.cu @@ -371,13 +371,15 @@ __global__ void batchApplyRepetitionPenalty(T* logits, const int* output_ids, const int batch_size, const int vocab_size, + const int* prompt_lengths, const int* input_lengths, const int max_input_length, const int step) { - const int batch_idx = blockIdx.x; - const float penalty = penalties[batch_idx]; - const int input_length = input_lengths != nullptr ? 
input_lengths[batch_idx] : max_input_length; + const int batch_idx = blockIdx.x; + const float penalty = penalties[batch_idx]; + const int input_length = input_lengths != nullptr ? input_lengths[batch_idx] : max_input_length; + const int prompt_length = prompt_lengths != nullptr ? prompt_lengths[batch_idx] : 0; penalty_workspace += batch_idx * step * 2; float* penalty_logits = (float*)penalty_workspace; @@ -388,6 +390,10 @@ __global__ void batchApplyRepetitionPenalty(T* logits, // Phase 1. Find indices to penalize and keep the penalized values. // A vocab id can appear multiple times but should be penalized once. for (int index = threadIdx.x; index < step; index += blockDim.x) { + // skip prompt + if (index < prompt_length) { + continue; + } // Skip the padding tokens in input sequences. if (index >= input_length && index < max_input_length) { continue; @@ -414,6 +420,10 @@ __global__ void batchApplyRepetitionPenalty(T* logits, // Phase 2. Replace a logit value by the penalized one. for (int index = threadIdx.x; index < step; index += blockDim.x) { + // skip prompt + if (index < prompt_length) { + continue; + } // Skip the padding tokens in input sequences. if (index >= input_length && index < max_input_length) { continue; @@ -430,6 +440,7 @@ void invokeBatchApplyRepetitionPenalty(T* logits, const int batch_size, const int local_batch_size, const int vocab_size, + const int* prompt_lengths, const int* input_lengths, const int max_input_length, const int step, @@ -451,6 +462,7 @@ void invokeBatchApplyRepetitionPenalty(T* logits, output_ids, batch_size, vocab_size, + prompt_lengths, input_lengths, max_input_length, step); @@ -463,6 +475,7 @@ void invokeBatchApplyRepetitionPenalty(T* logits, output_ids, batch_size, vocab_size, + prompt_lengths, input_lengths, max_input_length, step); @@ -479,6 +492,7 @@ template void invokeBatchApplyRepetitionPenalty(float* logits, const int batch_size, const int local_batch_size, const int vocab_size, + const int* prompt_lengths, const int* input_lengths, const int max_input_length, const int step, @@ -492,6 +506,7 @@ template void invokeBatchApplyRepetitionPenalty(half* logits, const int batch_size, const int local_batch_size, const int vocab_size, + const int* prompt_lengths, const int* input_lengths, const int max_input_length, const int step, diff --git a/src/turbomind/kernels/sampling_penalty_kernels.h b/src/turbomind/kernels/sampling_penalty_kernels.h index e12698cdf7..62f35f1102 100644 --- a/src/turbomind/kernels/sampling_penalty_kernels.h +++ b/src/turbomind/kernels/sampling_penalty_kernels.h @@ -45,6 +45,7 @@ void invokeBatchApplyRepetitionPenalty(T* logits, const int batch_size, const int local_batch_size, const int vocab_size, + const int* prompt_lengths, const int* input_lengths, const int max_input_length, const int step, diff --git a/src/turbomind/layers/sampling_layers/BaseSamplingLayer.cc b/src/turbomind/layers/sampling_layers/BaseSamplingLayer.cc index 3e23cbd616..429ac7db0f 100644 --- a/src/turbomind/layers/sampling_layers/BaseSamplingLayer.cc +++ b/src/turbomind/layers/sampling_layers/BaseSamplingLayer.cc @@ -39,6 +39,8 @@ void BaseSamplingLayer::allocateBuffer(size_t batch_size, Tensor top_k, Tenso allocator_->reMalloc(runtime_logits_buf_, sizeof(T) * batch_size * vocab_size_padded_, false)); skip_decode_buf_ = reinterpret_cast(allocator_->reMalloc(skip_decode_buf_, sizeof(bool) * batch_size, false)); + prompt_lengths_buf_ = + reinterpret_cast(allocator_->reMalloc(prompt_lengths_buf_, sizeof(int) * batch_size, false)); // host buffers. 
temperature_ = (float*)std::realloc((void*)temperature_, batch_size * sizeof(float)); @@ -59,6 +61,7 @@ void BaseSamplingLayer::freeBuffer() allocator_->free((void**)(&temperature_buf_)); allocator_->free((void**)(&repetition_penalty_buf_)); allocator_->free((void**)(&min_lengths_buf_)); + allocator_->free((void**)(&prompt_lengths_buf_)); allocator_->free((void**)(&runtime_logits_buf_)); allocator_->free((void**)(&skip_decode_buf_)); std::free(temperature_); @@ -164,6 +167,14 @@ void BaseSamplingLayer::setup(const size_t batch_size, const size_t beam_widt repetition_penalty_type_ = RepetitionPenaltyType::None; } + if (runtime_args->isExist("prompt_length")) { + Tensor prompt_lengths = runtime_args->at("prompt_length"); + cudaAutoCpy(prompt_lengths_buf_, prompt_lengths.getPtr(), batch_size, stream_); + } + else { + deviceFill(prompt_lengths_buf_, batch_size, 0, stream_); + } + // min_length if (runtime_args->isExist("min_length")) { Tensor min_lengths = runtime_args->at("min_length"); @@ -304,6 +315,7 @@ void BaseSamplingLayer::forward(TensorMap* output_tensors, TensorMap* input_t batch_size, local_batch_size, vocab_size_padded_, + prompt_lengths_buf_ + ite * local_batch_size, input_tensors->at("input_lengths", Tensor{MEMORY_GPU, TYPE_INT32, {}, nullptr}).getPtr(), max_input_length, step, diff --git a/src/turbomind/layers/sampling_layers/BaseSamplingLayer.h b/src/turbomind/layers/sampling_layers/BaseSamplingLayer.h index 83d2c40f24..6645fffa7c 100644 --- a/src/turbomind/layers/sampling_layers/BaseSamplingLayer.h +++ b/src/turbomind/layers/sampling_layers/BaseSamplingLayer.h @@ -43,6 +43,7 @@ class BaseSamplingLayer: public DynamicDecodeBaseLayer { int* min_lengths_buf_ = nullptr; bool* skip_decode_buf_ = nullptr; T* runtime_logits_buf_ = nullptr; + int* prompt_lengths_buf_ = nullptr; float* temperature_ = nullptr; float* repetition_penalty_ = nullptr; diff --git a/src/turbomind/models/llama/LlamaBatch.cc b/src/turbomind/models/llama/LlamaBatch.cc index c6bfc1502e..1bc3e68333 100644 --- a/src/turbomind/models/llama/LlamaBatch.cc +++ b/src/turbomind/models/llama/LlamaBatch.cc @@ -1066,11 +1066,9 @@ void LlamaBatch::InitializeSampling(const GenerationState& g) } } - // MinLengthPenalty - if (inputs.isExist("min_length")) { - inputs.insert({"prompt_length", {MEMORY_CPU, TYPE_INT32, {(size_t)batch_size}, state_->h_prompt_length}}); - inputs.insert({"context_length", {MEMORY_CPU, TYPE_INT32, {(size_t)batch_size}, state_->h_context_length}}); - } + // MinLengthPenalty & RepetitionPenalty + inputs.insert({"context_length", {MEMORY_CPU, TYPE_INT32, {(size_t)batch_size}, state_->h_context_length}}); + inputs.insert({"prompt_length", {MEMORY_CPU, TYPE_INT32, {(size_t)batch_size}, state_->h_prompt_length}}); // init for eos std::fill_n(h_end_ids_buf_, batch_size, model_->end_id_); diff --git a/tests/csrc/unittests/test_penalty_kernels.cu b/tests/csrc/unittests/test_penalty_kernels.cu index 301b79aa9f..e774f30fcc 100644 --- a/tests/csrc/unittests/test_penalty_kernels.cu +++ b/tests/csrc/unittests/test_penalty_kernels.cu @@ -509,6 +509,7 @@ public: batch_size_, batch_size_, vocab_size_padded_, + nullptr, d_input_lengths_, max_input_length_, step_, @@ -568,6 +569,7 @@ public: batch_size_, batch_size_, vocab_size_padded_, + nullptr, d_input_lengths_, max_input_length_, step_, From c5111b6229a0506bf82b4687ff916e9d855a004c Mon Sep 17 00:00:00 2001 From: irexyc Date: Thu, 9 May 2024 02:11:42 +0800 Subject: [PATCH 5/5] fix penalty --- src/turbomind/kernels/sampling_penalty_kernels.cu | 2 ++ 1 file changed, 2 
insertions(+) diff --git a/src/turbomind/kernels/sampling_penalty_kernels.cu b/src/turbomind/kernels/sampling_penalty_kernels.cu index df27300293..55c8065436 100644 --- a/src/turbomind/kernels/sampling_penalty_kernels.cu +++ b/src/turbomind/kernels/sampling_penalty_kernels.cu @@ -418,6 +418,8 @@ __global__ void batchApplyRepetitionPenalty(T* logits, } } + __syncthreads(); + // Phase 2. Replace a logit value by the penalized one. for (int index = threadIdx.x; index < step; index += blockDim.x) { // skip prompt