8000 Min P sampler implementation [alternative to Top P/Top K] by kalomaze · Pull Request #3841 · ggml-org/llama.cpp · GitHub
[go: up one dir, main page]

Skip to content

Min P sampler implementation [alternative to Top P/Top K] #3841

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 25 commits into from
Oct 31, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
59d1232
cuda : prints wip
ggerganov Oct 25, 2023
52af782
cuda : new cublas gemm branch for multi-batch quantized src0
ggerganov Oct 25, 2023
16b60dd
cuda : add F32 sgemm branch
ggerganov Oct 25, 2023
a3c2843
cuda : fine-tune >= VOLTA params + use MMQ only for small batches
ggerganov Oct 25, 2023
4c6744b
cuda : remove duplicated cuBLAS GEMM code
ggerganov Oct 25, 2023
a4e15a3
cuda : add CUDA_USE_TE 8000 NSOR_CORES and GGML_CUDA_FORCE_MMQ macros
ggerganov Oct 25, 2023
49af767
build : add compile option to force use of MMQ kernels
ggerganov Oct 27, 2023
a9e2b74
Super hacky starting implementation of Min P
kalomaze Oct 28, 2023
a235a0d
Transform Min P into a proper CLI option
kalomaze Oct 29, 2023
838d58d
Min P disabled if set to 1.0 or 0, otherwise Top P
kalomaze Oct 29, 2023
69ef4ca
Debugging print statements removed
kalomaze Oct 29, 2023
833637b
erring on the side of caution; disable by default
kalomaze Oct 29, 2023
62fc771
Remove accidentally kept prints + min_keep support
kalomaze Oct 29, 2023
49b68e8
Standardize 0.0 disabling min_p upon feedback
kalomaze Oct 29, 2023
6f7cdec
Simplified counter by checking candidates size
kalomaze Oct 29, 2023
cb23358
minor whitespace fix
kalomaze Oct 29, 2023
fcbbfc1
Even formatting + exclusively 0.0f to disable now
kalomaze Oct 29, 2023
69e638e
cleanup
cebtenzzre Oct 29, 2023
3ddfd67
permit simultaneous use of top_p and min_p
cebtenzzre Oct 29, 2023
18c0aa7
Merge remote-tracking branch 'original/cuda-quantum-batch' into min-p…
kalomaze Oct 29, 2023
87adfad
Merge branch 'min-p-sampling' of https://github.com/kalomaze/koboldcp…
kalomaze Oct 29, 2023
9248325
Update README & set 0.05 default
kalomaze Oct 31, 2023
512cac6
added a bit more context to the README
kalomaze Oct 31, 2023
974640a
Update README for consistency
kalomaze Oct 31, 2023
3b58af2
forgot one small thing!
kalomaze Oct 31, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
cuda : add CUDA_USE_TENSOR_CORES and GGML_CUDA_FORCE_MMQ macros
  • Loading branch information
ggerganov committed Oct 26, 2023
commit a4e15a36e4cd7120945eef560e591efaeb5fbd2b
122 changes: 103 additions & 19 deletions ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,23 @@
#define CC_OFFSET_AMD 1000000
#define CC_RDNA2 (CC_OFFSET_AMD + 1030)

// define this if you want to always fallback to MMQ kernels and not use cuBLAS for matrix multiplication
// on modern hardware, using cuBLAS is recommended as it utilizes F16 tensor cores which are very performant
// for large computational tasks. the drawback is that this requires some extra amount of VRAM:
// - 7B quantum model: +100-200 MB
// - 13B quantum model: +200-400 MB
//#define GGML_CUDA_FORCE_MMQ

// TODO: improve this to be correct for more hardware
// for example, currently fails for GeForce GTX 1660 which is TURING arch (> VOLTA) but does not have tensor cores
// probably other such cases, and not sure what happens on AMD hardware
#if !defined(GGML_CUDA_FORCE_MMQ)
#define CUDA_USE_TENSOR_CORES
#endif

// max batch size to use MMQ kernels when tensor cores are available
#define MMQ_MAX_BATCH_SIZE 32

#if defined(GGML_USE_HIPBLAS)
#define __CUDA_ARCH__ 1300

Expand Down Expand Up @@ -470,7 +487,6 @@ static int g_device_count = -1;
static int g_main_device = 0;
static int g_compute_capabilities[GGML_CUDA_MAX_DEVICES];
static float g_tensor_split[GGML_CUDA_MAX_DEVICES] = {0};
static bool g_mul_mat_q = true;

static void * g_scratch_buffer = nullptr;
static size_t g_scratch_size = 0; // disabled by default
Expand Down Expand Up @@ -3554,9 +3570,15 @@ static __device__ __forceinline__ void mul_mat_q(
#define MMQ_X_Q4_0_RDNA1 64
#define MMQ_Y_Q4_0_RDNA1 64
#define NWARPS_Q4_0_RDNA1 8
#if defined(CUDA_USE_TENSOR_CORES)
#define MMQ_X_Q4_0_AMPERE 4
#define MMQ_Y_Q4_0_AMPERE 32
#define NWARPS_Q4_0_AMPERE 4
#else
#define MMQ_X_Q4_0_AMPERE 64
#define MMQ_Y_Q4_0_AMPERE 128
#define NWARPS_Q4_0_AMPERE 4
#endif
#define MMQ_X_Q4_0_PASCAL 64
#define MMQ_Y_Q4_0_PASCAL 64
#define NWARPS_Q4_0_PASCAL 8
Expand Down Expand Up @@ -3615,9 +3637,15 @@ template <bool need_check> static __global__ void
#define MMQ_X_Q4_1_RDNA1 64
#define MMQ_Y_Q4_1_RDNA1 64
#define NWARPS_Q4_1_RDNA1 8
#if defined(CUDA_USE_TENSOR_CORES)
#define MMQ_X_Q4_1_AMPERE 4
#define MMQ_Y_Q4_1_AMPERE 32
#define NWARPS_Q4_1_AMPERE 4
#else
#define MMQ_X_Q4_1_AMPERE 64
#define MMQ_Y_Q4_1_AMPERE 128
#define NWARPS_Q4_1_AMPERE 4
#endif
#define MMQ_X_Q4_1_PASCAL 64
#define MMQ_Y_Q4_1_PASCAL 64
#define NWARPS_Q4_1_PASCAL 8
Expand Down Expand Up @@ -3678,9 +3706,15 @@ template <bool need_check> static __global__ void
#define MMQ_X_Q5_0_RDNA1 64
#define MMQ_Y_Q5_0_RDNA1 64
#define NWARPS_Q5_0_RDNA1 8
#if defined(CUDA_USE_TENSOR_CORES)
#define MMQ_X_Q5_0_AMPERE 4
#define MMQ_Y_Q5_0_AMPERE 32
#define NWARPS_Q5_0_AMPERE 4
#else
#define MMQ_X_Q5_0_AMPERE 128
#define MMQ_Y_Q5_0_AMPERE 64
#define NWARPS_Q5_0_AMPERE 4
#endif
#define MMQ_X_Q5_0_PASCAL 64
#define MMQ_Y_Q5_0_PASCAL 64
#define NWARPS_Q5_0_PASCAL 8
Expand Down Expand Up @@ -3739,9 +3773,15 @@ template <bool need_check> static __global__ void
#define MMQ_X_Q5_1_RDNA1 64
#define MMQ_Y_Q5_1_RDNA1 64
#define NWARPS_Q5_1_RDNA1 8
#if defined(CUDA_USE_TENSOR_CORES)
#define MMQ_X_Q5_1_AMPERE 4
#define MMQ_Y_Q5_1_AMPERE 32
#define NWARPS_Q5_1_AMPERE 4
#else
#define MMQ_X_Q5_1_AMPERE 128
#define MMQ_Y_Q5_1_AMPERE 64
#define NWARPS_Q5_1_AMPERE 4
#endif
#define MMQ_X_Q5_1_PASCAL 64
#define MMQ_Y_Q5_1_PASCAL 64
#define NWARPS_Q5_1_PASCAL 8
Expand Down Expand Up @@ -3800,9 +3840,15 @@ mul_mat_q5_1(
#define MMQ_X_Q8_0_RDNA1 64
#define MMQ_Y_Q8_0_RDNA1 64
#define NWARPS_Q8_0_RDNA1 8
#if defined(CUDA_USE_TENSOR_CORES)
#define MMQ_X_Q8_0_AMPERE 4
#define MMQ_Y_Q8_0_AMPERE 32
#define NWARPS_Q8_0_AMPERE 4
#else
#define MMQ_X_Q8_0_AMPERE 128
#define MMQ_Y_Q8_0_AMPERE 64
#define NWARPS_Q8_0_AMPERE 4
#endif
#define MMQ_X_Q8_0_PASCAL 64
#define MMQ_Y_Q8_0_PASCAL 64
#define NWARPS_Q8_0_PASCAL 8
Expand Down Expand Up @@ -3861,9 +3907,15 @@ template <bool need_check> static __global__ void
#define MMQ_X_Q2_K_RDNA1 128
#define MMQ_Y_Q2_K_RDNA1 32
#define NWARPS_Q2_K_RDNA1 8
#if defined(CUDA_USE_TENSOR_CORES)
#define MMQ_X_Q2_K_AMPERE 4
#define MMQ_Y_Q2_K_AMPERE 32
#define NWARPS_Q2_K_AMPERE 4
#else
#define MMQ_X_Q2_K_AMPERE 64
#define MMQ_Y_Q2_K_AMPERE 128
#define NWARPS_Q2_K_AMPERE 4
#endif
#define MMQ_X_Q2_K_PASCAL 64
#define MMQ_Y_Q2_K_PASCAL 64
#define NWARPS_Q2_K_PASCAL 8
Expand Down Expand Up @@ -3922,9 +3974,15 @@ mul_mat_q2_K(
#define MMQ_X_Q3_K_RDNA1 32
#define MMQ_Y_Q3_K_RDNA1 128
#define NWARPS_Q3_K_RDNA1 8
#if defined(CUDA_USE_TENSOR_CORES)
#define MMQ_X_Q3_K_AMPERE 4
#define MMQ_Y_Q3_K_AMPERE 32
#define NWARPS_Q3_K_AMPERE 4
#else
#define MMQ_X_Q3_K_AMPERE 128
#define MMQ_Y_Q3_K_AMPERE 128
#define NWARPS_Q3_K_AMPERE 4
#endif
#define MMQ_X_Q3_K_PASCAL 64
#define MMQ_Y_Q3_K_PASCAL 64
#define NWARPS_Q3_K_PASCAL 8
Expand Down Expand Up @@ -3985,9 +4043,15 @@ template <bool need_check> static __global__ void
#define MMQ_X_Q4_K_RDNA1 32
#define MMQ_Y_Q4_K_RDNA1 64
#define NWARPS_Q4_K_RDNA1 8
8000 #if defined(CUDA_USE_TENSOR_CORES)
#define MMQ_X_Q4_K_AMPERE 4
#define MMQ_Y_Q4_K_AMPERE 32
#define NWARPS_Q4_K_AMPERE 4
#else
#define MMQ_X_Q4_K_AMPERE 64
#define MMQ_Y_Q4_K_AMPERE 128
#define NWARPS_Q4_K_AMPERE 4
#endif
#define MMQ_X_Q4_K_PASCAL 64
#define MMQ_Y_Q4_K_PASCAL 64
#define NWARPS_Q4_K_PASCAL 8
Expand Down Expand Up @@ -4048,9 +4112,15 @@ template <bool need_check> static __global__ void
#define MMQ_X_Q5_K_RDNA1 32
#define MMQ_Y_Q5_K_RDNA1 64
#define NWARPS_Q5_K_RDNA1 8
#if defined(CUDA_USE_TENSOR_CORES)
#define MMQ_X_Q5_K_AMPERE 4
#define MMQ_Y_Q5_K_AMPERE 32
#define NWARPS_Q5_K_AMPERE 4
#else
#define MMQ_X_Q5_K_AMPERE 64
#define MMQ_Y_Q5_K_AMPERE 128
#define NWARPS_Q5_K_AMPERE 4
#endif
#define MMQ_X_Q5_K_PASCAL 64
#define MMQ_Y_Q5_K_PASCAL 64
#define NWARPS_Q5_K_PASCAL 8
Expand Down Expand Up @@ -4109,9 +4179,15 @@ mul_mat_q5_K(
#define MMQ_X_Q6_K_RDNA1 32
#define MMQ_Y_Q6_K_RDNA1 64
#define NWARPS_Q6_K_RDNA1 8
#if defined(CUDA_USE_TENSOR_CORES)
#define MMQ_X_Q6_K_AMPERE 4
#define MMQ_Y_Q6_K_AMPERE 32
#define NWARPS_Q6_K_AMPERE 4
#else
#define MMQ_X_Q6_K_AMPERE 64
#define MMQ_Y_Q6_K_AMPERE 64
#define NWARPS_Q6_K_AMPERE 4
#endif
#define MMQ_X_Q6_K_PASCAL 64
#define MMQ_Y_Q6_K_PASCAL 64
#define NWARPS_Q6_K_PASCAL 8
Expand Down Expand Up @@ -5663,6 +5739,16 @@ void ggml_init_cublas() {
CUDA_CHECK(cudaGetDeviceCount(&g_device_count));
GGML_ASSERT(g_device_count <= GGML_CUDA_MAX_DEVICES);
int64_t total_vram = 0;
#if defined(GGML_CUDA_FORCE_MMQ)
fprintf(stderr, "%s: GGML_CUDA_FORCE_MMQ: yes\n", __func__);
#else
fprintf(stderr, "%s: GGML_CUDA_FORCE_MMQ: no\n", __func__);
#endif
#if defined(CUDA_USE_TENSOR_CORES)
fprintf(stderr, "%s: CUDA_USE_TENSOR_CORES: yes\n", __func__);
#else
fprintf(stderr, "%s: CUDA_USE_TENSOR_CORES: no\n", __func__);
#endif
fprintf(stderr, "%s: found %d " GGML_CUDA_NAME " devices:\n", __func__, g_device_count);
for (int id = 0; id < g_device_count; ++id) {
cudaDeviceProp prop;
Expand Down Expand Up @@ -6347,7 +6433,7 @@ inline void ggml_cuda_op_mul_mat_cublas(
cublasSgemm(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
row_diff, src1_ncols, ne10,
&alpha, src0_ddf_i, ne00,
src1_ddf_i, ne10,
src1_ddf_i, ne10,
&beta, dst_dd_i, ldc));

if (src0_as != 0) {
Expand Down Expand Up @@ -7204,18 +7290,23 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const

static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
const bool all_on_device =
(src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT) &&
(src0->backend == GGML_BACKEND_GPU) &&
(src1->backend == GGML_BACKEND_GPU) &&
( dst->backend == GGML_BACKEND_GPU);

int64_t min_compute_capability = INT_MAX;
for (int64_t id = 0; id < g_device_count; ++id) {
if (min_compute_capability > g_compute_capabilities[id]
&& g_tensor_split[id] < (id + 1 < g_device_count ? g_tensor_split[id + 1] : 1.0f)) {
if (min_compute_capability > g_compute_capabilities[id] && g_tensor_split[id] < (id + 1 < g_device_count ? g_tensor_split[id + 1] : 1.0f)) {
min_compute_capability = g_compute_capabilities[id];
}
}

#ifdef CUDA_USE_TENSOR_CORES
const bool use_tensor_cores = true;
#else
const bool use_tensor_cores = false;
#endif

// debug helpers
//printf("src0: %8d %8d %8d %8d\n", src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3]);
//printf(" %8d %8d %8d %8d\n", src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3]);
Expand All @@ -7224,20 +7315,19 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
//printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
//printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);

if (all_on_device && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
if (all_on_device && !use_tensor_cores && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
// KQ single-batch
ggml_cuda_mul_mat_vec_p021(src0, src1, dst);
} else if (all_on_device && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
} else if (all_on_device && !use_tensor_cores && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
// KQV single-batch
ggml_cuda_mul_mat_vec_nc(src0, src1, dst);
} else if (all_on_device && src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
} else if (all_on_device && src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1)) {
// KQ + KQV multi-batch
ggml_cuda_mul_mat_mat_batched_cublas(src0, src1, dst);
} else if (src0->type == GGML_TYPE_F32) {
ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, false);
} else if (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) {
if (src1->ne[1] == 1 && src0->ne[0] % GGML_CUDA_DMMV_X == 0) {

#ifdef GGML_CUDA_FORCE_DMMV
const bool use_mul_mat_vec_q = false;
#else
Expand All @@ -7250,13 +7340,11 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, false);
}
} else {
// ref: https://github.com/ggerganov/llama.cpp/pull/3776
bool use_mul_mat_q = g_mul_mat_q && min_compute_capability >= MIN_CC_DP4A && ggml_is_quantized(src0->type);
bool use_mul_mat_q = min_compute_capability >= MIN_CC_DP4A && ggml_is_quantized(src0->type);

// TODO: better way to determine availability of tensor cores
// currently fails for GeForce GTX 1660 which is TURING arch but does not have tensor cores
if (min_compute_capability >= CC_VOLTA && src1->ne[1] > 32) {
// when tensor cores are available, use them for large batch size
// when tensor cores are available, use them for large batch size
// ref: https://github.com/ggerganov/llama.cpp/pull/3776
if (use_tensor_cores && min_compute_capability >= CC_VOLTA && src1->ne[1] > MMQ_MAX_BATCH_SIZE) {
use_mul_mat_q = false;
}

Expand Down Expand Up @@ -7614,10 +7702,6 @@ void ggml_cuda_set_main_device(const int main_device) {
}
}

void ggml_cuda_set_mul_mat_q(const bool mul_mat_q) {
g_mul_mat_q = mul_mat_q;
}

void ggml_cuda_set_scratch_size(const size_t scratch_size) {
// this is a hack to not completely break llama.cpp when using multiple models or contexts simultaneously
// it still won't always work as expected, but it's better than nothing
Expand Down
2 changes: 0 additions & 2 deletions llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5959,8 +5959,6 @@ static int llama_decode_internal(
}
}

ggml_cuda_set_mul_mat_q(cparams.mul_mat_q);

// HACK: ggml-alloc may change the tensor backend when reusing a parent, so force output to be on the CPU here if needed
if (!lctx.embedding.empty()) {
embeddings->backend = GGML_BACKEND_CPU;
Expand Down
2 changes: 1 addition & 1 deletion llama.h
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ extern "C" {
float rope_freq_scale; // RoPE frequency scaling factor, 0 = from model

// Keep the booleans together to avoid misalignment during copy-by-value.
bool mul_mat_q; // if true, use experimental mul_mat_q kerne 4B03 ls
bool mul_mat_q; // if true, use experimental mul_mat_q kernels (DEPRECATED - always true)
bool f16_kv; // use fp16 for KV cache, fp32 otherwise
bool logits_all; // the llama_eval() call computes all logits, not just the last one
bool embedding; // embedding mode only
Expand Down
0