-
Notifications
You must be signed in to change notification settings - Fork 24.4k
[inductor][cpp][gemm] improve thread blocking heuristics #131024
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/131024
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit 0cb8b78 with merge base 1614891 ( This comment was automatically generated by Dr. CI and updates every 15 minutes. |
block_n_size = blocking.block_n * register_blocking.block_n | ||
best_block_m_size = best_blocking.block_m * register_blocking.block_m | ||
best_block_n_size = best_blocking.block_n * register_blocking.block_n | ||
if block_m_size + block_n_size < best_block_m_size + best_block_n_size: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does this mean the tail thread block has more similar size as the main thread blocks? We prefer this case for better utilization of tail thread.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is not about the tails but the main blocks.
) | ||
k = self.num_threads // ( | ||
self.num_threads // math.ceil(k_blocks / thread_blocking.block_k) | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why we need to do the floor div twice here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Removed.
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
The merge job was canceled or timed out. This most often happen if two merge requests were issued for the same PR, or if merge job was waiting for more than 6 hours for tests to finish. In later case, please do not hesitate to reissue the merge command |
@pytorchbot merge
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
) This PR improves the thread blocking heuristics to favor full occupancy as much as possible. Also, the "m x n" block size is made as squared as possible for better data reuse. Take the shape M=20000, N=64, K=128 as an example, the original heuristics couldn't use up all the threads when the number of threads is large, say 60: AUTOTUNE linear_unary(200000x128, 64x128, 64) _linear_pointwise 0.1010 ms 100.0% cpp_packed_gemm_0 0.8303 ms 12.2% 0722 02:26:39.220660 302553 torch/_inductor/codegen/cpp_gemm_template.py:503] [0/0] Register blocking: GemmBlocking(block_m=32, block_n=32, block_k=32) V0722 02:26:39.221042 302553 torch/_inductor/codegen/cpp_gemm_template.py:507] [0/0] Cache blocking: GemmBlocking(block_m=625, block_n=1, block_k=4) V0722 02:26:39.221118 302553 torch/_inductor/codegen/cpp_gemm_template.py:509] [0/0] Thread blocking: GemmBlocking(block_m=625, block_n=1, block_k=4) V0722 02:26:39.221252 302553 torch/_inductor/codegen/cpp_gemm_template.py:526] [0/0] Number of threads: 60, occupancy: (10, 2, 1) After this PR: AUTOTUNE linear_unary(200000x128, 64x128, 64) _linear_pointwise 0.1143 ms 100.0% cpp_packed_gemm_0 0.1228 ms 93.1% V0722 02:29:49.261794 304201 torch/_inductor/codegen/cpp_gemm_template.py:309] [0/0] Register blocking: GemmBlocking(block_m=32, block_n=32, block_k=32) V0722 02:29:49.262860 304201 torch/_inductor/codegen/cpp_gemm_template.py:313] [0/0] Cache blocking: GemmBlocking(block_m=64, block_n=1, block_k=8) V0722 02:29:49.262951 304201 torch/_inductor/codegen/cpp_gemm_template.py:315] [0/0] Thread blocking: GemmBlocking(block_m=69, block_n=79, block_k=8) V0722 02:29:49.263075 304201 torch/_inductor/codegen/cpp_gemm_template.py:332] [0/0] Number of threads: 60, occupancy: (15, 4, 1) Pull Request resolved: pytorch#131024 Approved by: https://github.com/leslie-fang-intel, https://github.com/chunyuan-w
This PR provides the initial support for k-slicing (i.e. parallel reduction along k-dim) of CPP GEMM template. Only static shapes are supported now. When k-slicing is enabled, there would be extra temporary buffers allocated to hold the intermediate results and an extra barrier after initial GEMM compute by each thread, i.e. each thread first stores the GEMM result to temporary accumulation buffers (pointed by `local_buf_ptrs` which is an array of pointers pointing to accumulation buffers), followed by a reduction along k-slices, epilogue computes and store to the final output `Y`. In each k-slicing thread group, the reduction along k-slices and epilogue computes are conducted in parallel along M-dim. The algorithm is designed to reduce the synchronization overhead as much as possible. The k-slicing is enabled when blocking on M and N is unable to occupy all threads. Since k-slicing doesn't always bring benefit, an extra configuration is added to enable it (disable by default). We need to identify a good heuristics in the future to enable k-slicing by default. Performance numbers with 64x4096x64, 64x10000x64, 64x20000x64 as examples on 60-core SPR as examples. As you can see, the perf of k-slicing is only better than non-k-slicing when K is large enough. Without k-slicing AUTOTUNE linear_unary(64x4096, 64x4096, 64) cpp_packed_gemm_0 0.0108 ms 100.0% _linear_pointwise 0.0431 ms 25.1% AUTOTUNE linear_unary(64x10000, 64x10000, 64) cpp_packed_gemm_0 0.0272 ms 100.0% _linear_pointwise 0.0892 ms 30.5% AUTOTUNE linear_unary(64x20000, 64x20000, 64) cpp_packed_gemm_0 0.0781 ms 100.0% _linear_pointwise 0.1693 ms 46.1% With k-slicing: AUTOTUNE linear_unary(64x4096, 64x4096, 64) cpp_packed_gemm_0 0.0260 ms 100.0% _linear_pointwise 0.0444 ms 58.5% AUTOTUNE linear_unary(64x10000, 64x10000, 64) cpp_packed_gemm_0 0.0275 ms 100.0% _linear_pointwise 0.0893 ms 30.8% AUTOTUNE linear_unary(64x20000, 64x20000, 64) cpp_packed_gemm_0 0.0284 ms 100.0% _linear_pointwise 0.1686 ms 16.8% Pull Request resolved: #130821 Approved by: https://github.com/leslie-fang-intel, https://github.com/jansel ghstack dependencies: #131024
Stack from ghstack (oldest at bottom):
This PR improves the thread blocking heuristics to favor full occupancy as much as possible. Also, the "m x n" block size is made as squared as possible for better data reuse.
Take the shape M=20000, N=64, K=128 as an example, the original heuristics couldn't use up all the threads when the number of threads is large, say 60:
AUTOTUNE linear_unary(200000x128, 64x128, 64)
_linear_pointwise 0.1010 ms 100.0%
cpp_packed_gemm_0 0.8303 ms 12.2%
0722 02:26:39.220660 302553 torch/_inductor/codegen/cpp_gemm_template.py:503] [0/0] Register blocking: GemmBlocking(block_m=32, block_n=32, block_k=32)
V0722 02:26:39.221042 302553 torch/_inductor/codegen/cpp_gemm_template.py:507] [0/0] Cache blocking: GemmBlocking(block_m=625, block_n=1, block_k=4)
V0722 02:26:39.221118 302553 torch/_inductor/codegen/cpp_gemm_template.py:509] [0/0] Thread blocking: GemmBlocking(block_m=625, block_n=1, block_k=4)
V0722 02:26:39.221252 302553 torch/_inductor/codegen/cpp_gemm_template.py:526] [0/0] Number of threads: 60, occupancy: (10, 2, 1)
After this PR:
AUTOTUNE linear_unary(200000x128, 64x128, 64)
_linear_pointwise 0.1143 ms 100.0%
cpp_packed_gemm_0 0.1228 ms 93.1%
V0722 02:29:49.261794 304201 torch/_inductor/codegen/cpp_gemm_template.py:309] [0/0] Register blocking: GemmBlocking(block_m=32, block_n=32, block_k=32)
V0722 02:29:49.262860 304201 torch/_inductor/codegen/cpp_gemm_template.py:313] [0/0] Cache blocking: GemmBlocking(block_m=64, block_n=1, block_k=8)
V0722 02:29:49.262951 304201 torch/_inductor/codegen/cpp_gemm_template.py:315] [0/0] Thread blocking: GemmBlocking(block_m=69, block_n=79, block_k=8)
V0722 02:29:49.263075 304201 torch/_inductor/codegen/cpp_gemm_template.py:332] [0/0] Number of threads: 60, occupancy: (15, 4, 1)
cc @voznesenskym @penguinwu @EikanWang @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang