SOTA 2-bit quants by ikawrakow · Pull Request #4773 · ggml-org/llama.cpp · GitHub
[go: up one dir, main page]

Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter
Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
iq2_xxs: even faster Metal dot product
TG-128 is now 54.1 t/s.

Strangely enough, putting the signs lookup table
into shared memory has a bigger impact than the
grid values being in shared memory.
  • Loading branch information
Iwan Kawrakow committed Jan 8, 2024
commit 065cc8cb474be7945d2997047dd926c644899cc5
4 changes: 2 additions & 2 deletions ggml-metal.m
Original file line number Diff line number Diff line change
Expand Up @@ -1713,7 +1713,7 @@ bool ggml_metal_graph_compute(
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
else if (src0t == GGML_TYPE_IQ2_XXS) {
- [encoder setThreadgroupMemoryLength:256*8 atIndex:0];
+ [encoder setThreadgroupMemoryLength:(256*8+128) atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
else if (src0t == GGML_TYPE_Q4_K) {
Expand Down Expand Up @@ -1981,7 +1981,7 @@ bool ggml_metal_graph_compute(
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
else if (src0t == GGML_TYPE_IQ2_XXS) {
- [encoder setThreadgroupMemoryLength:256*8 atIndex:0];
+ [encoder setThreadgroupMemoryLength:(256*8+128) atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
else if (src2t == GGML_TYPE_Q4_K) {
Expand Down
10 changes: 7 additions & 3 deletions ggml-metal.metal
Original file line number Diff line number Diff line change
Expand Up @@ -3596,10 +3596,14 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
const int nb32 = nb * (QK_K / 32);

threadgroup uint64_t * values = (threadgroup uint64_t *)shared_values;
+ threadgroup uint8_t * shared_signs = (threadgroup uint8_t *)(values + 256);
{
- const int nval = 4;
- const int pos = (32*sgitg + tiisg)*nval;
+ int nval = 4;
+ int pos = (32*sgitg + tiisg)*nval;
for (int i = 0; i < nval; ++i) values[pos + i] = kgrid_iq2xxs[pos + i];
+ nval = 2;
+ pos = (32*sgitg + tiisg)*nval;
+ for (int i = 0; i < nval; ++i) shared_signs[pos+i] = ksigns_iq2xs[pos+i];
threadgroup_barrier(mem_flags::mem_threadgroup);
}

Expand Down Expand Up @@ -3631,7 +3635,7 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
float sum = 0;
for (int l = 0; l < 4; ++l) {
const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + aux8[l]);
- const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*l) & 127];
+ const uint8_t signs = shared_signs[(aux32 >> 7*l) & 127];
for (int j = 0; j < 8; ++j) {
sum += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
}
Expand Down
0