[inductor][cpp][gemm] improve thread blocking heuristics by jgong5 · Pull Request #131024 · pytorch/pytorch · GitHub

[inductor][cpp][gemm] improve thread blocking heuristics #131024


Closed
wants to merge 6 commits

Conversation

@jgong5 (Collaborator) commented Jul 18, 2024

Stack from ghstack (oldest at bottom):

This PR improves the thread blocking heuristics to favor full occupancy as much as possible. Also, the "m x n" block size is made as square as possible for better data reuse.

Take the shape M=200000, N=64, K=128 as an example: the original heuristics couldn't use up all the threads when the number of threads is large, say 60:
AUTOTUNE linear_unary(200000x128, 64x128, 64)
_linear_pointwise 0.1010 ms 100.0%
cpp_packed_gemm_0 0.8303 ms 12.2%
V0722 02:26:39.220660 302553 torch/_inductor/codegen/cpp_gemm_template.py:503] [0/0] Register blocking: GemmBlocking(block_m=32, block_n=32, block_k=32)
V0722 02:26:39.221042 302553 torch/_inductor/codegen/cpp_gemm_template.py:507] [0/0] Cache blocking: GemmBlocking(block_m=625, block_n=1, block_k=4)
V0722 02:26:39.221118 302553 torch/_inductor/codegen/cpp_gemm_template.py:509] [0/0] Thread blocking: GemmBlocking(block_m=625, block_n=1, block_k=4)
V0722 02:26:39.221252 302553 torch/_inductor/codegen/cpp_gemm_template.py:526] [0/0] Number of threads: 60, occupancy: (10, 2, 1)

After this PR:
AUTOTUNE linear_unary(200000x128, 64x128, 64)
_linear_pointwise 0.1143 ms 100.0%
cpp_packed_gemm_0 0.1228 ms 93.1%
V0722 02:29:49.261794 304201 torch/_inductor/codegen/cpp_gemm_template.py:309] [0/0] Register blocking: GemmBlocking(block_m=32, block_n=32, block_k=32)
V0722 02:29:49.262860 304201 torch/_inductor/codegen/cpp_gemm_template.py:313] [0/0] Cache blocking: GemmBlocking(block_m=64, block_n=1, block_k=8)
V0722 02:29:49.262951 304201 torch/_inductor/codegen/cpp_gemm_template.py:315] [0/0] Thread blocking: GemmBlocking(block_m=69, block_n=79, block_k=8)
V0722 02:29:49.263075 304201 torch/_inductor/codegen/cpp_gemm_template.py:332] [0/0] Number of threads: 60, occupancy: (15, 4, 1)
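
A minimal sketch of the ranking idea described above (an illustration only, not the code in cpp_gemm_template.py; the real heuristic weighs more factors). It enumerates factorizations of the thread count over the m/n block grid, prefers candidates with higher occupancy, and among those picks the per-thread "m x n" tile closest to square (smallest block_m + block_n):

import math
from dataclasses import dataclass

@dataclass
class GemmBlocking:
    block_m: int
    block_n: int
    block_k: int

def thread_blocking_sketch(num_threads, m_blocks, n_blocks, k_blocks):
    # Illustrative only; block_m/block_n/block_k are in units of register blocks
    # (e.g. 32x32x32 above). Split the (m, n) block grid across threads, favoring
    # full occupancy first and a near-square per-thread tile second.
    best, best_key = None, None
    for n_threads in range(1, num_threads + 1):
        if num_threads % n_threads:
            continue  # only exact factorizations m_threads * n_threads == num_threads
        m_threads = num_threads // n_threads
        block_m = math.ceil(m_blocks / m_threads)
        block_n = math.ceil(n_blocks / n_threads)
        # occupancy: how many of the num_threads slots actually receive work
        occupancy = math.ceil(m_blocks / block_m) * math.ceil(n_blocks / block_n)
        # rank by higher occupancy first, then by the "squarer" m x n tile
        key = (-occupancy, block_m + block_n)
        if best_key is None or key < best_key:
            best, best_key = GemmBlocking(block_m, block_n, k_blocks), key
    return best

For the shape above expressed in 32x32 register-block units (m_blocks=6250, n_blocks=2, k_blocks=4), thread_blocking_sketch(60, 6250, 2, 4) reaches a full occupancy of 60, versus the 10 x 2 = 20 threads used by the old split shown in the first log.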

cc @voznesenskym @penguinwu @EikanWang @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang

[ghstack-poisoned]
pytorch-bot bot commented Jul 18, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/131024

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 0cb8b78 with merge base 1614891:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

Jiong Gong added 4 commits July 20, 2024 07:51
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
@jgong5 jgong5 marked this pull request as ready for review July 22, 2024 06:36
block_m_size = blocking.block_m * register_blocking.block_m
block_n_size = blocking.block_n * register_blocking.block_n
best_block_m_size = best_blocking.block_m * register_blocking.block_m
best_block_n_size = best_blocking.block_n * register_blocking.block_n
if block_m_size + block_n_size < best_block_m_size + best_block_n_size:
@leslie-fang-intel (Collaborator) commented Jul 22, 2024

Does this mean the tail thread block has a size more similar to the main thread blocks? We prefer this case for better utilization of the tail thread.

@jgong5 (Collaborator, Author)

It is not about the tails but the main blocks.
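
For context (my own gloss, not from the diff): comparing block_m_size + block_n_size picks the "squarer" of two candidates that cover the same per-thread area, because for a fixed product the sum of two factors is smallest when they are equal. A toy illustration:

# Hypothetical tie-break between equal-occupancy candidates covering the same
# m*n area per thread: the smallest block_m + block_n is the most square tile.
candidates = [(625, 1), (125, 5), (25, 25)]  # each has block_m * block_n == 625
best = min(candidates, key=lambda mn: mn[0] + mn[1])
print(best)  # (25, 25) -- the most square choice wins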

)
k = self.num_threads // (
self.num_threads // math.ceil(k_blocks / thread_blocking.block_k)
)
Collaborator

Why do we need to do the floor div twice here?

Collaborator Author

Removed.
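
For readers puzzling over the (since removed) expression: a quick numeric check of what the double floor division evaluates to, assuming num_threads = 60 purely for illustration:

import math

num_threads = 60
k = math.ceil(9 / 1)        # suppose 9 k-chunks (hypothetical numbers)
q = num_threads // k        # 60 // 9 = 6
print(num_threads // q)     # 60 // 6 = 10: the second division re-derives a
                            # k count from the quotient rather than keeping 9
print(60 // (60 // 4))      # with 4 k-chunks: 60 // 15 = 4, i.e. unchanged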

[ghstack-poisoned]
@jgong5 added the ciflow/trunk (Trigger trunk jobs on your pull request) and topic: not user facing (topic category) labels Jul 23, 2024
@jgong5 (Collaborator, Author) commented Jul 23, 2024

@pytorchbot merge

@pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status.

@pytorchmergebot
Copy link
Collaborator

The merge job was canceled or timed out. This most often happens if two merge requests were issued for the same PR, or if the merge job was waiting for more than 6 hours for tests to finish. In the latter case, please do not hesitate to reissue the merge command.
For more information see pytorch-bot wiki.

@jgong5 (Collaborator, Author) commented Jul 24, 2024

@pytorchbot merge

@pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status.

xuhancn pushed a commit to xuhancn/pytorch that referenced this pull request Jul 25, 2024

Pull Request resolved: pytorch#131024
Approved by: https://github.com/leslie-fang-intel, https://github.com/chunyuan-w
pytorchmergebot pushed a commit that referenced this pull request Jul 25, 2024
This PR provides initial support for k-slicing (i.e., parallel reduction along the k-dim) in the CPP GEMM template. Only static shapes are supported for now. When k-slicing is enabled, extra temporary buffers are allocated to hold intermediate results, and there is an extra barrier after the initial GEMM compute by each thread: each thread first stores its GEMM result into temporary accumulation buffers (pointed to by `local_buf_ptrs`, an array of pointers to accumulation buffers), followed by a reduction along the k-slices, the epilogue computation, and the store to the final output `Y`. Within each k-slicing thread group, the reduction along k-slices and the epilogue computation are done in parallel along the M-dim. The algorithm is designed to minimize synchronization overhead.

K-slicing is enabled when blocking on M and N cannot occupy all threads. Since k-slicing doesn't always bring a benefit, an extra configuration option is added to enable it (disabled by default). We need to identify a good heuristic in the future to enable k-slicing by default.
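
A simplified Python sketch of the k-slicing structure described above (an illustration with numpy and plain threads; the real template emits C++ with its own buffers and synchronization, and names such as gemm_k_sliced and local_bufs are made up here):

import numpy as np
from threading import Barrier, Thread

def gemm_k_sliced(A, B, num_k_slices):
    # Each thread computes a partial GEMM over its slice of K into a private
    # accumulation buffer (the role local_buf_ptrs plays in the template),
    # then all partials are reduced into the output in parallel along M.
    M, K = A.shape
    _, N = B.shape
    local_bufs = [np.zeros((M, N), dtype=A.dtype) for _ in range(num_k_slices)]
    barrier = Barrier(num_k_slices)
    Y = np.zeros((M, N), dtype=A.dtype)

    def worker(slice_id):
        k0 = slice_id * K // num_k_slices
        k1 = (slice_id + 1) * K // num_k_slices
        # 1) partial GEMM over this thread's k-slice
        local_bufs[slice_id][:] = A[:, k0:k1] @ B[k0:k1, :]
        # 2) barrier: every partial must exist before the reduction starts
        barrier.wait()
        # 3) reduction (and, in the real template, the epilogue) split along M:
        #    each thread sums its own chunk of rows across all k-slices
        m0 = slice_id * M // num_k_slices
        m1 = (slice_id + 1) * M // num_k_slices
        Y[m0:m1] = sum(buf[m0:m1] for buf in local_bufs)

    threads = [Thread(target=worker, args=(i,)) for i in range(num_k_slices)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return Y

The result matches A @ B up to floating-point reduction order, e.g. gemm_k_sliced(np.random.rand(64, 20000), np.random.rand(20000, 64), 4).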

Performance numbers with 64x4096x64, 64x10000x64, and 64x20000x64 as examples on a 60-core SPR. As you can see, the perf of k-slicing is only better than non-k-slicing when K is large enough.

Without k-slicing:
AUTOTUNE linear_unary(64x4096, 64x4096, 64)
  cpp_packed_gemm_0 0.0108 ms 100.0%
  _linear_pointwise 0.0431 ms 25.1%

AUTOTUNE linear_unary(64x10000, 64x10000, 64)
  cpp_packed_gemm_0 0.0272 ms 100.0%
  _linear_pointwise 0.0892 ms 30.5%

AUTOTUNE linear_unary(64x20000, 64x20000, 64)
  cpp_packed_gemm_0 0.0781 ms 100.0%
  _linear_pointwise 0.1693 ms 46.1%

With k-slicing:
AUTOTUNE linear_unary(64x4096, 64x4096, 64)
  cpp_packed_gemm_0 0.0260 ms 100.0%
  _linear_pointwise 0.0444 ms 58.5%

AUTOTUNE linear_unary(64x10000, 64x10000, 64)
  cpp_packed_gemm_0 0.0275 ms 100.0%
  _linear_pointwise 0.0893 ms 30.8%

AUTOTUNE linear_unary(64x20000, 64x20000, 64)
  cpp_packed_gemm_0 0.0284 ms 100.0%
  _linear_pointwise 0.1686 ms 16.8%

Pull Request resolved: #130821
Approved by: https://github.com/leslie-fang-intel, https://github.com/jansel
ghstack dependencies: #131024
@github-actions github-actions bot deleted the gh/jgong5/61/head branch August 23, 2024 02:00