Introduce New Lookup-Table(LUT)-Based Matrix Multiplication Method (TMAC) by QingtaoLi1 · Pull Request #13206 · ggml-org/llama.cpp
Open · wants to merge 21 commits into master
Changes from 1 commit
Use multi-thread to accelerate tensor transformation. ~70% time reduction.
QingtaoLi1 committed May 15, 2025
commit ea2876fac802a7528e9436a836b4317a7cf7aece
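
This commit parallelizes the two row-major packing loops in ggml_tmac_transform_tensor by splitting the m / bits row range into contiguous chunks, one per hardware thread; each thread writes a disjoint slice of the output buffer, so a plain join is the only synchronization needed. A minimal sketch of that chunking pattern, assuming a caller-supplied worker(start, end) — the helper name run_chunked is illustrative, not part of the PR:

#include <cstddef>
#include <thread>
#include <vector>

// Split [0, n_rows) into one contiguous chunk per hardware thread and run
// worker(start, end) on each chunk; the last thread absorbs the remainder.
template <typename Worker>
static void run_chunked(std::size_t n_rows, Worker worker) {
    std::size_t n_threads = std::thread::hardware_concurrency();
    if (n_threads == 0) n_threads = 1;  // hardware_concurrency() may report 0

    std::vector<std::thread> threads;
    threads.reserve(n_threads);

    const std::size_t chunk_size = n_rows / n_threads;
    std::size_t start_index = 0;
    for (std::size_t i = 0; i < n_threads; ++i) {
        const std::size_t end_index =
            (i == n_threads - 1) ? n_rows : start_index + chunk_size;
        threads.emplace_back(worker, start_index, end_index);
        start_index = end_index;
    }
    for (std::thread & t : threads) {
        t.join();
    }
}

Because each chunk covers distinct rows, the workers never touch the same bytes of buf2 or the packed-weight buffer, which is what makes the join-only synchronization in the diff below sufficient.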
50 changes: 48 additions & 2 deletions ggml/src/ggml-cpu/tmac/lut_mul_mat.cpp
@@ -2,6 +2,7 @@
#include <fstream>
#include <stdexcept>
#include <string>
#include <thread>
#include <unordered_map>

#define GGML_COMMON_IMPL_CPP
@@ -773,6 +774,9 @@ static inline void ggml_tmac_transform_tensor(struct ggml_tensor * tensor, const
// for fast testing
// #define TMAC_EMPTY_WEIGHTS
#ifndef TMAC_EMPTY_WEIGHTS
std::vector<std::thread> threads;
const int n_threads = std::thread::hardware_concurrency();

// TODO: optimize to accelerate weights loading
uint8_t * buf2 = new uint8_t[m * k / g];
memset(buf2, 0, m * k / g);
@@ -782,7 +786,9 @@ static inline void ggml_tmac_transform_tensor(struct ggml_tensor * tensor, const
// # (M // bits, K, bits) -> (M // bits, bits, K) -> (M // bits, bits, K // g, g) -> (M // bits, bits, K // g)
// w = w.transpose(0, 2, 1).reshape(M // bits, bits, K // g, g)
// w = sum([(w[:, :, :, ig] << ig) for ig in range(g)])
for (int im = 0; im < m / bits; im++) {
threads.reserve(n_threads);
auto parallel_worker_buf2 = [&](size_t start_index, size_t end_index) {
for (int im = start_index; im < end_index; im++) {
for (int ik = 0; ik < k; ik++) {
uint8_t v;
if (tensor->type == GGML_TYPE_Q4_0) {
@@ -808,6 +814,25 @@ static inline void ggml_tmac_transform_tensor(struct ggml_tensor * tensor, const
}
}
}
};

size_t start_index = 0;
size_t chunk_size = m / bits / n_threads;
for (size_t i = 0; i < n_threads; ++i) {
size_t end_index = (i == n_threads - 1) ? m / bits : start_index + chunk_size;

// Create and launch a thread for rows [start_index, end_index)
threads.emplace_back(parallel_worker_buf2, start_index, end_index);

start_index = end_index;
}
// Wait for all threads to complete
for (std::thread& t : threads) {
t.join();
}
threads.clear();

// # 0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 7, 23, 8, 24, 9, 25, 10, 26, 11, 27, 12, 28, 13, 29, 14, 30, 15, 31
// # for bits=3
@@ -843,7 +868,9 @@ static inline void ggml_tmac_transform_tensor(struct ggml_tensor * tensor, const
int c2_fac1 = bm / mgroup * c2_fac2;
int c2_fac0 = k / g / kfactor * c2_fac1;

for (int im = 0; im < m / bits; im++) {
threads.reserve(n_threads);
auto parallel_worker_qweights = [&](size_t start_index, size_t end_index) {
for (int im = start_index; im < end_index; im++) {
for (int ib = 0; ib < bits; ib++) {
for (int ik = 0; ik < k / g; ik++) {
// w = w.reshape(M // bits // simd_n_out, simd_n_out, bits, K // g).transpose(0, 2, 1, 3)
@@ -881,6 +908,25 @@ static inline void ggml_tmac_transform_tensor(struct ggml_tensor * tensor, const
}
}
}
};

start_index = 0;
chunk_size = m / bits / n_threads;
for (size_t i = 0; i < n_threads; ++i) {
size_t end_index = (i == n_threads - 1) ? m / bits : start_index + chunk_size;

// Create and launch a thread for rows [start_index, end_index)
threads.emplace_back(parallel_worker_qweights, start_index, end_index);

start_index = end_index;
}
// Wait for all threads to complete
for (std::thread& t : threads) {
t.join();
}
threads.clear();

const float * int_n_scales = (const float * ) ((const uint8_t *) origin_data + k * m / 8);
const float * int_n_zero_points = int_n_scales + scales_size / 2;
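
For reference, the transformation both parallelized loops implement follows the Python pseudocode in the diff comments: after the transpose/reshape, every group of g consecutive per-bit values along K is folded into one byte, each value shifted by its index within the group (w = sum([(w[:, :, :, ig] << ig) for ig in range(g)])). A standalone sketch of just that fold, assuming one bit value per input byte — the function pack_groups is illustrative, not part of the PR:

#include <cstddef>
#include <cstdint>
#include <vector>

// Fold each group of g single-bit values along K into one byte:
//   packed[i] = sum(bits[i * g + ig] << ig for ig in range(g))
static std::vector<std::uint8_t> pack_groups(const std::vector<std::uint8_t> & bits, std::size_t g) {
    std::vector<std::uint8_t> packed(bits.size() / g, 0);
    for (std::size_t i = 0; i < packed.size(); ++i) {
        for (std::size_t ig = 0; ig < g; ++ig) {
            packed[i] |= bits[i * g + ig] << ig;
        }
    }
    return packed;
}

// Example: with g = 4, the bit sequence {1, 0, 1, 1} packs to 0b1101 = 13.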