8000 Update the heuristic for AArch64 bmm/baddbmm (#149122) · pytorch/pytorch@d759a51 · GitHub
[go: up one dir, main page]

Skip to content

Commit d759a51

Browse files
michalowski-armpytorchmergebot
authored andcommitted
Update the heuristic for AArch64 bmm/baddbmm (#149122)
Updates heuristic for bmm/baddbmm and consolidates all heuristic logic in a single location

- The goal of the consolidation is to improve maintainability and readability of the heuristic logic. Instead of different parts scattered across two files, this patch centralizes everything inside `Matmul.cpp`, where there already exists heuristic-based selection for mkldnn.
- The logic of the check itself doesn't change (existing code is reused where possible) but a separate heuristic threshold for bmm/baddbmm is introduced based on newer benchmarking data.

Use the script below to see the performance improvement for bmm from the new heuristic:

```
import torch
import time

# Set below to True to use cases selected by only one of the heuristics.
USE_ONLY_DIVERGENT_TEST_CASES = True
BATCH_SIZES = [ 1, 8, 32, 64, 128, 256 ]
M_DIMS = [ 4, 8, 16, 32, 64, 256, 512 ]
N_DIMS = [ 4, 8, 16, 32, 64, 256, 512 ]
K_DIMS = [ 4, 8, 16, 32, 64, 256, 512 ]
ITERS = 50

def old_heuristic(m, n, k):
    is_above_min_dims = m > 8 and n > 8 and k > 8
    is_above_min_size = m*n*k > 8_192
    return is_above_min_dims and is_above_min_size

def new_heuristic(b, m, n, k):
    return b*b*m*n*k >= 4_194_304

def generate_test_cases():
    test_cases = []
    for b in BATCH_SIZES:
        for m in M_DIMS:
            for n in N_DIMS:
                for k in K_DIMS:
                    if USE_ONLY_DIVERGENT_TEST_CASES:
                        if old_heuristic(m, n, k) != new_heuristic(b, m, n, k):
                            test_cases.append([b, m, n, k])
                    else:
                        test_cases.append([b, m, n, k])
    return test_cases

def test(x, y):
    for _ in range(5):
        torch.bmm(x, y)
    perf = 0.0
    for _ in range(ITERS):
        start = time.time()
        torch.bmm(x, y)
        end = time.time()
        perf += (end - start) / ITERS
    return perf

def main():
    print(f"{'b':<10}{'m':<10}{'n':<10}{'k':<10}{'time (s)':10}")
    cumulative_mean_time = 0.0
    for b, m, n, k in generate_test_cases():
        mean_time = test(torch.rand(b, m, n), torch.rand(b, n, k))
        cumulative_mean_time += mean_time
        print(f"{b:<10}{m:<10}{n:<10}{k:<10}{mean_time:10.3e}")
    print(f"Cumulative mean time = {cumulative_mean_time:.4f} s")

if __name__ == "__main__":
    main()
```

From the script we see that the cumulative mean time from all test cases (at 16 threads) is:
- 1.6195 s for the old heuristic
- 0.7012 s for the new heuristic

Pull Request resolved: #149122
Approved by: https://github.com/fadara01, https://github.com/aditew01, https://github.com/malfet
1 parent e8662e8 commit d759a51

File tree

2 files changed

+54
-39
lines changed

2 files changed

+54
-39
lines changed

aten/src/ATen/native/LinearAlgebra.cpp

Lines changed: 2 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1360,41 +1360,6 @@ Tensor outer(const Tensor& self, const Tensor& vec2) {
13601360
#endif
13611361

13621362

1363-
static inline int64_t get_mkldnn_matmul_min_dim() {
1364-
static auto value = [&] {
1365-
const int64_t default_min_dim = [&] {
1366-
// Minimum dimension requirement for MKLDNN; derived based on experiments.
1367-
//it's enabled on all Neoverse cpus.
1368-
return is_arm_neoverse() ? 8 : 0;
1369-
}();
1370-
const auto value = c10::utils::get_env("TORCH_MKLDNN_MATMUL_MIN_DIM");
1371-
return value.has_value() ? std::stoi(value.value()) : default_min_dim;
1372-
}();
1373-
return value;
1374-
}
1375-
1376-
1377-
static inline int64_t get_mkldnn_matmul_min_size() {
1378-
static auto value = [&] {
1379-
const int64_t default_min_size = [&] {
1380-
// Minimum size requirement for MKLDNN; derived based on experiments.
1381-
// it's enabled on all Neoverse cpus.
1382-
return is_arm_neoverse() ? 8 * 1024 : 0;
1383-
}();
1384-
const auto value = c10::utils::get_env("TORCH_MKLDNN_MATMUL_MIN_SIZE");
1385-
return value.has_value() ? std::stoi(value.value()) : default_min_size;
1386-
}();
1387-
return value;
1388-
}
1389-
1390-
1391-
static inline bool apply_mkldnn_matmul_heur(int64_t m, int64_t k, int64_t n) {
1392-
const int64_t min_dim = get_mkldnn_matmul_min_dim();
1393-
const int64_t min_size = get_mkldnn_matmul_min_size();
1394-
return at::globalContext().userEnabledMkldnn() && m > min_dim && k > min_dim && n > min_dim && m * k * n > min_size;
1395-
}
1396-
1397-
13981363
static void addmm_impl_cpu_(
13991364
Tensor &result, const Tensor &self, Tensor m1, Tensor m2, const Scalar& beta, const Scalar& alpha) {
14001365
TORCH_INTERNAL_ASSERT(self.dim() == 2 && m1.dim() == 2 && m2.dim() == 2);
@@ -1514,8 +1479,7 @@ static void addmm_impl_cpu_(
15141479
// that will call then into Arm® Compute Library (ACL) GEMM kernel and also
15151480
// additionally have support for running kernel with BF16 instructions
15161481
if (transpose_c) {
1517-
bool apply_heur = apply_mkldnn_matmul_heur(b.sizes()[0], b.sizes()[1], a.sizes()[1]);
1518-
if (apply_heur && transpose_a && !transpose_b && result.scalar_type() == at::ScalarType::Float) {
1482+
if (use_mkldnn_matmul(b, a, c) && transpose_a && !transpose_b && result.scalar_type() == at::ScalarType::Float) {
15191483
try {
15201484
mkldnn_matmul(b, a, c, beta.to<float>(), alpha.to<float>());
15211485
// We have dispatched to ACL GEMM for single precision float
@@ -1771,8 +1735,7 @@ static inline void bmm_out_or_baddbmm_(const Tensor& self_or_result_, const Tens
17711735
(strides[1] == 1 && (sizes[2] == 1 || strides[2] >= sizes[1]));
17721736
};
17731737

1774-
bool apply_heur = apply_mkldnn_matmul_heur(batch1.sizes()[1], batch1.sizes()[2], batch2.sizes()[2]);
1775-
if (apply_heur && use_mkldnn_matmul(batch1, batch2, self_or_result)) {
1738+
if (use_mkldnn_matmul(batch1, batch2, self_or_result)) {
17761739
try {
17771740
mkldnn_matmul(batch1, batch2, self_or_result, beta.to<float>(), alpha.to<float>());
17781741
return;

aten/src/ATen/native/mkldnn/Matmul.cpp

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -322,6 +322,42 @@ void mkldnn_matmul(
322322

323323
}
324324

325+
#if AT_MKLDNN_ACL_ENABLED()
326+
// Experimentally derived heuristics for MKLDNN+ACL on NEOVERSE cores
327+
static inline int64_t get_mkldnn_acl_addmm_min_dim() {
328+
static auto value = [&] {
329+
const int64_t default_min_dim = [&] {
330+
return is_arm_neoverse() ? 8 : 0;
331+
}();
332+
const char* ptr = std::getenv("TORCH_MKLDNN_ADDMM_MIN_DIM");
333+
return ptr != nullptr ? std::atoi(ptr) : default_min_dim;
334+
}();
335+
return value;
336+
}
337+
338+
static inline int64_t get_mkldnn_acl_addmm_min_size() {
339+
static auto value = [&] {
340+
const int64_t default_min_size = [&] {
341+
return is_arm_neoverse() ? 8 * 1024 : 0;
342+
}();
343+
const char* ptr = std::getenv("TORCH_MKLDNN_ADDMM_MIN_SIZE");
344+
return ptr != nullptr ? std::atoi(ptr) : default_min_size;
345+
}();
346+
return value;
347+
}
348+
349+
static inline int64_t get_mkldnn_acl_bmm_baddbmm_threshold() {
350+
static auto value = [&] {
351+
const int64_t default_threshold = [&] {
352+
return is_arm_neoverse() ? 1L << 22 : 0;
353+
}();
354+
const char* ptr = std::getenv("TORCH_MKLDNN_BMM_BADDBMM_THRESHOLD");
355+
return ptr != nullptr ? std::atoi(ptr) : default_threshold;
356+
}();
357+
return value;
358+
}
359+
#endif
360+
325361
static inline bool checksize(const Tensor& mat1, const Tensor& mat2){
326362
// if dim = 2, mat1's size = (m * n), mat2's size = (n * k)
327363
// else if dim = 3, mat1's size = (b * m * n), mat2's size = (b * n * k)
@@ -336,10 +372,26 @@ static inline bool checksize(const Tensor& mat1, const Tensor& mat2){
336372
return mat1.size(0) * mat1.size(1) > mkldnn_gemm_min_size;
337373
} else if (mat2.dim() == 2 && mat2.dim() == 2) {
338374
// aten::addmm
375+
#if AT_MKLDNN_ACL_ENABLED()
376+
const int64_t mkldnn_acl_addmm_min_dim = get_mkldnn_acl_addmm_min_dim();
377+
const int64_t mkldnn_acl_addmm_min_size = get_mkldnn_acl_addmm_min_size();
378+
// M > MIN_DIM and N > MIN_DIM and K > MIN_DIM and M*N*K > MIN_SIZE
379+
return mat1.size(0) > mkldnn_acl_addmm_min_dim
380+
&& mat1.size(1) > mkldnn_acl_addmm_min_dim
381+
&& mat2.size(1) > mkldnn_acl_addmm_min_dim
382+
&& mat1.size(0) * mat1.size(1) * mat2.size(1) > mkldnn_acl_addmm_min_size;
383+
#else
339384
return mat1.size(0) * mat1.size(1) * mat2.size(1) > mkldnn_gemm_min_size;
385+
#endif
340386
} else {
341387
// aten::bmm, aten::baddbmm
388+
#if AT_MKLDNN_ACL_ENABLED()
389+
const int64_t mkldnn_acl_bmm_baddbmm_threshold = get_mkldnn_acl_bmm_baddbmm_threshold();
390+
// BATCH_SIZE^2 * M * N * K >= THRESHOLD
391+
return mat1.size(0) * mat1.size(0) * mat1.size(1) * mat1.size(2) * mat2.size(2) >= mkldnn_acl_bmm_baddbmm_threshold;
392+
#else
342393
return mat1.size(0) * mat1.size(1) * mat1.size(2) * mat2.size(2) > mkldnn_gemm_min_size;
394+
#endif
343395
}
344396
}
345397

0 commit comments

Comments
 (0)
0