[inductor] let inplace-padding support cpp-wrapper by shunting314 · Pull Request #145325 · pytorch/pytorch · GitHub

[inductor] let inplace-padding support cpp-wrapper #145325


Closed
wants to merge 2 commits

Conversation

@shunting314 (Contributor) commented Jan 21, 2025

Stack from ghstack (oldest at bottom):

Some context: Inplace padding is an optimization that performs padding in place. E.g., suppose a tensor has size [2048, 2047] and stride [2048, 1]. When we need to pad one extra element to the end of each row (e.g. during mm padding), we can just reuse the original tensor and do the padding in place. This saves memory and bandwidth. One caveat for this optimization is that PyTorch does not allocate 2048 elements for the last row of the original tensor; it only allocates 2047 elements. So assuming the last row has enough space for 2048 elements may be wrong and can cause out-of-bounds (OOB) memory access (although I have never seen this happen, perhaps due to over-allocation in the CUDACachingAllocator, it is better to fix it).
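
To make the caveat concrete, here is a minimal eager-mode sketch (illustration only, not the generated inductor code) showing that the storage backing such a tensor is exactly one element short of a full 2048 x 2048 buffer:

```python
import torch

# Size [2048, 2047] with stride [2048, 1]: the storage only needs to cover the
# last accessible element, which is one short of a full 2048 * 2048 buffer.
t = torch.empty_strided([2048, 2047], [2048, 1])
needed = (2048 - 1) * 2048 + (2047 - 1) * 1 + 1
print(needed)                                             # 4194303 == 2048 * 2048 - 1
print(t.untyped_storage().nbytes() // t.element_size())   # same value: no room for one
                                                          # extra element in the last row
```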

The fix is that when we allocate the tensor, instead of doing something like:

```
buf0 = randn_strided([2048, 2047], [2048, 1])
```

we do a small over-allocation:

```
buf0 = randn_strided([2048, 2048], [2048, 1]).as_strided([2048, 2047], [2048, 1])
```

cpp_wrapper needs special handling since memory allocation goes through a different code path than the Python wrapper.
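
For illustration, a rough eager-mode sketch of what the over-allocation buys us (plain PyTorch standing in for the generated Python wrapper code; in the cpp_wrapper case the equivalent allocation and as_strided call are emitted through the C shim instead, as discussed in the review below):

```python
import torch

# Over-allocate the full [2048, 2048] buffer, then view it as [2048, 2047]
# with stride [2048, 1] so downstream code sees the original logical shape.
base = torch.randn(2048, 2048)
buf0 = base.as_strided([2048, 2047], [2048, 1])

# The underlying storage now holds 2048 * 2048 elements, so writing one extra
# padding element at the end of each row (including the last) stays in bounds.
assert buf0.untyped_storage().nbytes() // buf0.element_size() == 2048 * 2048
```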

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @aakhundov

pytorch-bot bot commented Jan 21, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/145325

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure, 1 Pending, 2 Unrelated Failures

As of commit 7140150 with merge base 8581163:

NEW FAILURE - The following job has failed:

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

shunting314 added a commit that referenced this pull request Jan 21, 2025
ghstack-source-id: b3a22f3
Pull Request resolved: #145325
@desertfire (Contributor) left a comment


LGTM other than the new C shim function signature.

```python
    stride_array_var,
]
self.wrapper_call.writeline(
    f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_as_strided({', '.join(args)}));"
)
```
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Depending on whether this will be triggered for internal models: if yes, there is a tricky FC (forward compatibility) issue here, since aoti_torch_as_strided is a newly introduced C shim function. I assume that is less likely the case, but let's see.

shunting314 added a commit that referenced this pull request Jan 22, 2025
ghstack-source-id: a0ca4ec
Pull Request resolved: #145325
@shunting314 (Contributor, Author) commented:

@pytorchbot merge -i

@pytorch-bot pytorch-bot bot added the ciflow/trunk label (Trigger trunk jobs on your pull request) Jan 23, 2025
@pytorchmergebot (Collaborator) commented:

Merge started

Your change will be merged while ignoring the following 3 checks: pull / linux-focal-py3.9-clang10 / test (dynamo_wrapped, 3, 3, linux.2xlarge), inductor / unit-test / cuda12.4-py3.12-gcc9-sm86 / build, inductor / unit-test / linux-jammy-cpu-py3.12-gcc11-inductor-triton-cpu / test (inductor-triton-cpu, 1, 1, linux.12xlarge)

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

pytorchmergebot pushed a commit that referenced this pull request Jan 23, 2025
We use `cpu_tensor.copy_(gpu_tensor)` to clone mutated kernel arguments for autotuning. The purpose is to avoid increasing peak memory due to the clone. But if `gpu_tensor` is not contiguous, this `copy_` needs to allocate a temporary tensor on the GPU to store a contiguous copy of `gpu_tensor`:

https://github.com/pytorch/pytorch/blob/6e53588789c48682c7da969de9cbace67a1dd9f3/aten/src/ATen/native/cuda/Copy.cu#L322-L334

Here is a standalone script to illustrate this behavior: https://gist.github.com/shunting314/812a848dc67b1d674ae42415a7a462c8 . The script reports 6GB rather than 3GB peak memory usage.
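
A minimal sketch in the same spirit (sizes are illustrative and not the gist's exact repro; assumes a CUDA device is available):

```python
import torch

gpu_tensor = torch.randn(8192, 8192, device="cuda").t()  # non-contiguous (transposed) view
cpu_tensor = torch.empty(8192, 8192)                      # contiguous CPU destination

torch.cuda.reset_peak_memory_stats()
cpu_tensor.copy_(gpu_tensor)  # cross-device copy from a non-contiguous source
# The peak includes a temporary contiguous GPU copy of gpu_tensor made inside copy_,
# roughly doubling GPU memory usage during the copy.
print(torch.cuda.max_memory_allocated() / 2**30, "GiB peak")
```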

Note that, with all the following efforts
1. donated buffer
2. inplace padding
3. this PR

we save 3GB of peak memory (18.6GB -> 15.5GB) for the GPT2 model with torch.compile.

The peak memory curve of GPT2 looks like a '...\_M\_...' shape; there are two places where we reach the peak. Donated buffer removes the first peak by computing grad_softmax in place, and inplace padding removes the second peak by not allocating an extra buffer for mm-padding.

Before all these optimizations, the peak memory is 18.6GB for GPT2 with torch.compile.
With 1 & 2, the peak memory is
1. 17.7GB with a cold cache
2. 15.5GB with a warm cache (since the autotuning overhead is skipped)

With 1 & 2 & 3, we save 3GB of peak memory (18.6GB -> 15.5GB) regardless of whether autotuning happens.

Pull Request resolved: #145410
Approved by: https://github.com/masnesral, https://github.com/jansel
ghstack dependencies: #140249, #145325
@github-actions github-actions bot deleted the gh/shunting314/193/head branch February 23, 2025 02:09