Arm AArch64: optimized GEMV and GEMM kernels for q4_0_q8_0, and q8_0_q8_0 quantization by Dibakar · Pull Request #5780 · ggml-org/llama.cpp · GitHub
[go: up one dir, main page]

Skip to content

Arm AArch64: optimized GEMV and GEMM kernels for q4_0_q8_0, and q8_0_q8_0 quantization #5780

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 28 commits into from
Jul 10, 2024
Merged
Changes from 1 commit
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
002e36e
Arm AArch64: optimized GEMV and GEMM kernels for q4_0_q8_0, and q8_0_…
Dibakar Feb 28, 2024
340ef07
Arm AArch64: add optimized GEMV and GEMM asm kernels for q4_0_q8_0 qu…
Dibakar Apr 22, 2024
81215ff
Arm AArch64: add optimized GEMV and GEMM asm kernels for q4_0_q8_0 qu…
Dibakar Apr 23, 2024
6c8d826
Arm AArch64: add optimized GEMV and GEMM asm kernels for q4_0_q8_0 qu…
Dibakar Apr 25, 2024
43e1297
Arm AArch64: add optimized GEMV and GEMM asm kernels for q4_0_q8_0 qu…
Dibakar Apr 29, 2024
441ab64
Arm AArch64: add copyright claim only to ggml-aarch64.cpp and ggml-aa…
Dibakar Apr 29, 2024
8ee6779
Arm AArch64: minor code refactoring for rebase
Dibakar May 1, 2024
a657246
Arm AArch64: minor code refactoring for resolving a build issue with …
Dibakar May 16, 2024
746b57f
Arm AArch64: minor code refactoring to split the Q4_0_AARC64 type int…
Dibakar May 21, 2024
5d10c21
Arm AArch64: minor code change for resolving a build issue with serve…
Dibakar May 31, 2024
7ac03e5
retrigger checks
Dibakar May 31, 2024
e2c1c47
Arm AArch64: minor code changes for rebase
Dibakar Jun 5, 2024
79b6cdf
Arm AArch64: minor changes to skip the pr#7433 vec_dot code for arm c…
Dibakar Jun 14, 2024
3c1ad5f
Arm AArch64: remove stale LLAMA_QKK_64 from CMakeLists.txt and delete…
Dibakar Jun 14, 2024
a7055b7
Arm AArch64: add reference scalar gemm and gemv, and avoid dynamic me…
Dibakar Jun 18, 2024
cce236b
Arm AArch64: add multithreaded quantization support for the new types…
Dibakar Jun 19, 2024
7a70606
Arm AArch64: minor code refactoring
Dibakar Jun 19, 2024
ffbfabb
Arm AArch64: simplify logic for calling gemm and gemv functions in gg…
Dibakar Jun 23, 2024
cbbfd69
Arm AArch64: minimize changes in ggml_compute_forward_mul_mat
Dibakar Jun 26, 2024
3564644
Arm AArch64: minor code refactoring, and add reference scalar code to…
Dibakar Jul 3, 2024
110d143
Arm AArch64: minor code refactoring
Dibakar Jul 3, 2024
4ff0b22
Arm AArch64: minor code refactoring
Dibakar Jul 6, 2024
42724b4
Arm AArch64: minor code refactoring
Dibakar Jul 8, 2024
e5f4713
rebase on the latest master commit 3fd62a6 and adapt to the new direc…
Dibakar Jul 8, 2024
c2595d0
Arm AArch64: remove a redundant comment
Dibakar Jul 9, 2024
a7abb78
Arm AArch64: add pragma in ggml-aarch64.c to turn -Woverlength-string…
Dibakar Jul 9, 2024
0e84ef1
Arm AArch64: use __aarch64__ check to guard 64-bit neon kernels
Dibakar Jul 9, 2024
c653eb1
Arm AArch64: update docs/build.md README to include compile time flag…
Dibakar Jul 9, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Arm AArch64: minor code refactoring
  • Loading branch information
Dibakar committed Jul 8, 2024
commit 4ff0b223c3d85b6fb0319302dcd71d2fdcdd94e1
34 changes: 32 additions & 2 deletions ggml/src/ggml.c
Original file line number Diff line number Diff line change
Expand Up @@ -12374,12 +12374,12 @@ UseGgmlGemm2:;
// printf("MUL_MAT = [%d, %d, %d, %d] x [%d, %d, %d, %d] = %d x %d = %d. Fp Ops/Ch %d\n", ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, nchunk0, nchunk1, nchunk0 * nchunk1, ne00 * nr0 * nr1 / nchunk0 / nchunk1);

if ((ggml_n_dims(src0) == 2) && gemv) {
const void * src1_wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
const void * src1_wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
const size_t src1_col_stride = ggml_is_contiguous(src1) || src1->type != vec_dot_type ? ggml_row_size(vec_dot_type, ne10) : nb11;
int64_t src0_start = (ith * ne01) / nth;
int64_t src0_end = ((ith + 1) * ne01) / nth;
src0_start = (src0_start % matmul_num_cols) ? src0_start + matmul_num_cols - (src0_start % matmul_num_cols): src0_start;
src0_end = (src0_end % matmul_num_cols) ? src0_end + matmul_num_cols - (src0_end % matmul_num_cols): src0_end;
src0_end = (src0_end % matmul_num_cols) ? src0_end + matmul_num_cols - (src0_end % matmul_num_cols): src0_end;
if (src0_start >= src0_end) return;

// If there are more than three rows in src1, use gemm; otherwise, use gemv.
Expand Down Expand Up @@ -12438,6 +12438,8 @@ static void ggml_compute_forward_mul_mat_id(
ggml_vec_dot_t const vec_dot = type_traits[type].vec_dot;
enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type;
ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float;
int64_t const matmul_num_cols = type_traits[type].ncols;
ggml_gemv_t const gemv = type_traits[type].gemv;

// we don't support permuted src0 or src1
GGML_ASSERT(nb00 == ggml_type_size(type));
Expand Down Expand Up @@ -12523,6 +12525,34 @@ static void ggml_compute_forward_mul_mat_id(
const int64_t nr0 = ne01; // src0 rows
const int64_t nr1 = cne1; // src1 rows

if (((ggml_n_dims(src0) - 1) == 2) && gemv) {
int64_t src0_cur_start = (ith * ne01) / nth;
int64_t src0_cur_end = ((ith + 1) * ne01) / nth;
src0_cur_start = (src0_cur_start % matmul_num_cols) ? src0_cur_start + matmul_num_cols - (src0_cur_start % matmul_num_cols): src0_cur_start;
src0_cur_end = (src0_cur_end % matmul_num_cols) ? src0_cur_end + matmul_num_cols - (src0_cur_end % matmul_num_cols): src0_cur_end;
if (src0_cur_start >= src0_cur_end) return;

for (int ir1 = 0; ir1 < nr1; ir1++) {
struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, ir1);
const int id = row_mapping.i1; // selected expert index

const int64_t i11 = id % ne11;
const int64_t i12 = row_mapping.i2; // row index in src1

const int64_t i1 = id; // selected expert index
const int64_t i2 = i12; // row

const char * src1_col = (const char *) wdata +
(src1_cont || src1->type != vec_dot_type
? (i11 + i12 * ne11) * row_size
: (i11 * nb11 + i12 * nb12));

gemv(ne00, (float *)((char *) dst->data + (i1 * nb1 + i2 * nb2)) + src0_cur_start, ne01,
(const char *) src0_cur + src0_cur_start * nb01, src1_col, 1, src0_cur_end - src0_cur_start);
}
continue;
}

// distribute the thread work across the inner or outer loop based on which one is larger

const int64_t nth0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows
Expand Down
0