
[inductor][cpp][gemm] improve large bs perf with better cache blocking #132729


Closed. jgong5 wants to merge 5 commits from the gh/jgong5/64/head branch.

Conversation

jgong5 (Collaborator) commented Aug 6, 2024

Stack from ghstack (oldest at bottom):

Improve the cache blocking by reducing Mc_blocks so that A resides in L2 and is reused by B as much as possible. This improves large batch-size (bs) performance in two scenarios: 1) N is large and K is of medium size; 2) K is large. Different strategies are used to handle these scenarios; see the notes in `get_cache_blocking` in the changes.
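As a rough illustration of the strategy (a sketch under assumed parameter names, not the actual `get_cache_blocking` implementation), shrinking `Mc_blocks` amounts to picking the largest M tile whose A slice plus the accumulating C tile still fits in L2, and falling back to the per-thread blocking otherwise:

```python
import math

def choose_mc_blocks(Mt_blocks, Kc_blocks, Mr, Kr, num_bytes, L2_bytes):
    # Bytes of one Kc-long row of A.
    Kc_bytes = Kc_blocks * Kr * num_bytes
    # Solve 4*Mc*Nc + Mc*Kc_bytes = L2 with Mc == Nc for good data reuse
    # (the same equation quoted in the review thread below).
    M_max = (math.sqrt(Kc_bytes * Kc_bytes + 16 * L2_bytes) - Kc_bytes) / 8
    if M_max >= Mr:
        # Cap Mc_blocks so the A tile stays resident in L2 and is reused by
        # every B column block.
        return min(Mt_blocks, max(1, int(M_max // Mr)))
    # Otherwise keep the default per-thread M blocking.
    return Mt_blocks
```

With, for example, Kc_bytes of 8 KB and a 2 MB L2, this caps the tile at a couple of hundred A rows instead of the full per-thread M range.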

Measured on a 56-core Intel(R) Xeon(R) CPU Max 9480 with jemalloc 5.1 and Intel OpenMP, bf16. Run with the B matrix (weights) held in the code cache.

Model Shapes | Before Optimization | After Optimization | Speedup | onednn linear | Speedup over onednn
-- | -- | -- | -- | -- | --
M=1024, N=12288, K=4096 (Llama2-8b) | 5.69 ms | 3.71 ms | 1.53 | 4.53 ms | 1.22
M=1024, N=4096, K=4096 (Llama2-8b) | 1.69 ms | 1.63 ms | 1.04 | 2.05 ms | 1.26
M=1024, N=22016, K=4096 (Llama2-8b) | 10.32 ms | 6.57 ms | 1.57 | 8.46 ms | 1.29
M=1024, N=4096, K=11008 (Llama2-8b) | 5.21 ms | 3.26 ms | 1.60 | 4.65 ms | 1.43
M=1024, N=5120, K=4096 (Llama3-8b) | 1.99 ms | 1.78 ms | 1.12 | 2.31 ms | 1.30
M=1024, N=28672, K=4096 (Llama3-8b) | 13.41 ms | 8.56 ms | 1.57 | 10.96 ms | 1.28
M=1024, N=4096, K=14336 (Llama3-8b) | 6.93 ms | 4.31 ms | 1.61 | 6.24 ms | 1.45
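For reference, a minimal sketch of how such a single-layer measurement could be reproduced with `torch.compile` is shown below; this is not the benchmarking harness used for the table above, and the shape is taken from its first row:

```python
import torch

M, N, K = 1024, 12288, 4096  # first row of the table above
linear = torch.nn.Linear(K, N, bias=False).to(torch.bfloat16).eval()
x = torch.randn(M, K, dtype=torch.bfloat16)

# "max-autotune" lets inductor benchmark the C++ GEMM template against ATen.
compiled = torch.compile(linear, mode="max-autotune")
with torch.no_grad():
    compiled(x)  # first call compiles and autotunes
    # Time subsequent calls and compare against the "After Optimization" column.
```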

cc @voznesenskym @penguinwu @EikanWang @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang

pytorch-bot (bot) commented Aug 6, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/132729

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit 37b7b4a with merge base f951fcd:

BROKEN TRUNK - The following job failed but was already present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

jgong5 added the `topic: not user facing` label on Aug 6, 2024
jgong5 pushed a commit that referenced this pull request Aug 13, 2024
jgong5 added the `ciflow/trunk` label (Trigger trunk jobs on your pull request) on Aug 13, 2024
jgong5 pushed a commit that referenced this pull request Aug 13, 2024
# MxNxK dims respectively. The blockings are separated by comma and the unit is
# the number of register blocks.
# For example, "4,1,10" means 4 register blocks on M, 1 on N and 10 on K respectively.
gemm_cache_blocking = os.environ.get("TORCHINDUCTOR_CPP_GEMM_CACHE_BLOCKING", None)
Contributor commented:
Please add a test that enables this config via @config.patch(gemm_cache_blocking="...")

jgong5 (Collaborator, Author) replied:
Thanks. Added.
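The added test itself is not shown in this thread; as a hedged sketch, the knob can also be exercised from user code through the environment variable read in the config snippet above (the "4,1,10" blocking below is purely illustrative):

```python
import os
# Must be set before torch._inductor.config is imported, since the value is
# read via os.environ.get at import time.
os.environ["TORCHINDUCTOR_CPP_GEMM_CACHE_BLOCKING"] = "4,1,10"

import torch

lin = torch.nn.Linear(4096, 4096, bias=False).to(torch.bfloat16).eval()
x = torch.randn(1024, 4096, dtype=torch.bfloat16)
with torch.no_grad():
    # The override only matters when autotuning selects the C++ GEMM template.
    torch.compile(lin, mode="max-autotune")(x)
```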

const int64_t n_size = n_end - n_start;
const int64_t nc_block_end = std::min(nc + Nc_blocks, n_block_end); // FIXME: maybe exceeding N?
leslie-fang-intel (Collaborator) commented Aug 14, 2024:
Since we always do padded_n, I guess it will not exceed N. Maybe we can remove this FIXME.

jgong5 (Collaborator, Author) replied:
Good point. Updated the comment.

# The following is the solution for 4*Mc*Nc + Mc*Kc_bytes = L2,
# assuming Mc == Nc for good data reuse.
M_max = (math.sqrt(Kc_bytes * Kc_bytes + 16 * L2) - Kc_bytes) / 8
if M_max < Mc_blocks * M0:
Collaborator commented:
Should it theoretically be an assert? But considering we use some approximate calculations, I guess it is fine.

jgong5 (Collaborator, Author) replied:
It is not an assertion. When the condition is false, we use the default, which is Mt_blocks.
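As a quick numeric sanity check of the closed form (the L2 size and Kc_bytes below are assumed values, not taken from the PR), plugging `M_max` back into `4*Mc*Nc + Mc*Kc_bytes` with `Mc == Nc` recovers L2:

```python
import math

L2 = 2 * 1024 * 1024   # assumed 2 MB L2 per core
Kc_bytes = 4096 * 2    # assumed Kc = 4096 bf16 elements, i.e. 8192 bytes per A row

M_max = (math.sqrt(Kc_bytes * Kc_bytes + 16 * L2) - Kc_bytes) / 8
footprint = 4 * M_max * M_max + M_max * Kc_bytes  # 4*Mc*Nc (C tile) + Mc*Kc_bytes (A tile)
assert math.isclose(footprint, L2, rel_tol=1e-6)
print(f"M_max ~= {M_max:.1f} rows")  # about 230 rows under these assumptions
```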

jgong5 (Collaborator, Author) commented Aug 16, 2024

@pytorchbot merge

pytorchmergebot (Collaborator) commented:

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

pytorchmergebot pushed a commit that referenced this pull request Aug 17, 2024
…ing etc. (#133312)

Indent the template instructions separately from the generated code, for readability. Also, rename M0,N0,K0 to Mr,Nr,Kr ("r" meaning "register") for consistent naming.

Pull Request resolved: #133312
Approved by: https://github.com/Skylion007, https://github.com/leslie-fang-intel
ghstack dependencies: #132729, #132730
pytorchmergebot pushed a commit that referenced this pull request Sep 9, 2024
…) (#135438)

Fix #134686.

PR #132729 makes GEMM template faster for one of the GEMMs in xcit_large_24_p8_224:
SingleProcess AUTOTUNE benchmarking takes 1.7088 seconds and 1.9207 seconds precompiling
AUTOTUNE linear_unary(12544x3072, 768x3072, 768)
  cpp_packed_gemm_2 2.9371 ms 100.0%
  _linear_pointwise 3.1584 ms 93.0%

But it is slower than ATen in the e2e run due to different cache behavior. The access to the input data (12544x3072) is LLC latency bound, and bottlenecks are seen due to memory synchronization (data transfers and coherence updates across processors). This PR tries to mitigate the problem by having the processors that share the input data cooperatively load different chunks of it.

Pull Request resolved: #135438
Approved by: https://github.com/leslie-fang-intel
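Conceptually, the cooperative loading described above splits the shared A rows among the threads that need them, so each thread streams a disjoint chunk into the shared cache instead of every thread pulling the whole block through the LLC. A minimal Python sketch of such a split (not the PR's actual C++ template code) is:

```python
def cooperative_chunk(m_start: int, m_end: int, num_sharers: int, rank: int):
    """Sub-range of [m_start, m_end) that the rank-th sharing thread should load."""
    total = m_end - m_start
    chunk = (total + num_sharers - 1) // num_sharers  # ceil-divide rows among sharers
    begin = min(m_start + rank * chunk, m_end)
    end = min(begin + chunk, m_end)
    return begin, end

# Example: 4 threads sharing input rows [0, 12544) each bring in 3136 rows once.
print([cooperative_chunk(0, 12544, 4, r) for r in range(4)])
```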
yushangdi pushed a commit that referenced this pull request Sep 12, 2024
malfet pushed a commit to aditew01/pytorch that referenced this pull request Sep 13, 2024
Pull Request resolved: pytorch#132729
Approved by: https://github.com/leslie-fang-intel, https://github.com/chunyuan-w, https://github.com/jansel
malfet pushed a commit to aditew01/pytorch that referenced this pull request Sep 13, 2024
malfet pushed a commit to aditew01/pytorch that referenced this pull request Sep 13, 2024
tolleybot pushed a commit to tolleybot/pytorch that referenced this pull request Sep 14, 2024
github-actions bot deleted the gh/jgong5/64/head branch on September 17, 2024 01:57
Chao1Han pushed a commit to Chao1Han/pytorch that referenced this pull request Sep 20, 2024
Chao1Han pushed a commit to Chao1Han/pytorch that referenced this pull request Sep 26, 2024
Chao1Han pushed a commit to Chao1Han/pytorch that referenced this pull request Sep 26, 2024
apakbin pushed a commit to apakbin/pytorch that referenced this pull request Feb 14, 2025