[Intel GPU] Avoid copy when the input of Matmul is broadcasted #143784
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/143784
Note: Links to docs will display an error until the docs builds have been completed.
❌ 2 New Failures, 2 Unrelated Failures as of commit b5def8a with merge base 30cbf13.
NEW FAILURES - The following jobs have failed:
FLAKY - The following jobs failed but were likely due to flakiness present on trunk:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@pytorchbot label "topic: not user facing"
```cpp
void undo_broadcast_on_batch(at::Tensor& m1, at::Tensor& m2) {
  // oneDNN supports having one of src and wei broadcasted on the batch dim.
  auto tensor_dim = m1.dim();
  TORCH_CHECK(
```
Broadcast on the batch (BS) dim only serves performance. Please do NOT treat this check as a must-have.
deleted
```cpp
  if (tensor_dim == 2)
    return;
  auto undo_broadcast = [](at::Tensor& tensor) {
    if (tensor.stride(1) == 0 || tensor.stride(2) == 0) {
```
Why could a particular stride be zero when a tensor is not a scalar tensor?
I don't think we support broadcast on the m, n, k dims, so the strides of the last two dims cannot be zero. The tensor may not be a scalar.
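For context, a minimal standalone sketch (not code from this PR) showing how a batch-dim broadcast yields a zero stride while the last two dims keep nonzero strides:

```cpp
#include <ATen/ATen.h>
#include <iostream>

int main() {
  // Batch dim has size 1, so expand() can broadcast it without copying.
  at::Tensor t = at::randn({1, 4, 5});
  at::Tensor b = t.expand({8, 4, 5});
  std::cout << b.stride(0) << "\n";                        // 0: broadcasted batch dim
  std::cout << b.stride(1) << " " << b.stride(2) << "\n";  // 5 1: m/k dims stay nonzero
  return 0;
}
```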
```cpp
  if (m1.stride(0) == 0 && m2.stride(0) == 0) {
    // oneDNN does not support both src and wei being broadcasted on the
    // batch dim, so we materialize a copy of the smaller operand.
    if (m1.size(1) < m2.size(2)) {
      m1 = m1.contiguous();
    } else {
      m2 = m2.contiguous();
    }
  }
```
Let's separate broadcast logic from the contiguous logic.
Meanwhile, please help refine `is_onednn_mat_strides` a little bit.
```cpp
dnnl::memory::dims strides = get_onednn_strides(tensor);
// Compute the span of storage actually addressed: the offset of the last
// element plus one. A span smaller than numel() implies elements alias.
int64_t storage_size = 1;
for (size_t dim = 0; dim < tensor_dim; ++dim)
  storage_size += (sizes[dim] - 1) * strides[dim];
if (storage_size < tensor.numel())
  return false;
```
The above code snippet could be refined as follows.
```cpp
if (at::has_internal_overlap(tensor) == at::MemOverlap::Yes)
  return false;
```
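A hedged sketch of how that refinement could sit inside a stride-checking helper (the helper name and surrounding checks below are placeholders, not the PR's actual `is_onednn_mat_strides` body):

```cpp
#include <ATen/ATen.h>
#include <ATen/MemoryOverlap.h>

// Placeholder helper: reject tensors whose strides make distinct indices
// alias the same storage, delegating the span arithmetic to ATen.
static bool strides_usable_by_onednn(const at::Tensor& tensor) {
  if (at::has_internal_overlap(tensor) == at::MemOverlap::Yes) {
    return false;
  }
  // ... the remaining oneDNN-specific stride checks would follow here ...
  return true;
}
```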
To add the ciflow label: this helps ensure we don't trigger CI on this PR until it is actually authorized to do so. Please ping one of the reviewers if you do not have access to approve and run workflows.
Force-pushed from 5a3c51b to 366ed3b (Compare)
UT is already covered here: https://github.com/pytorch/pytorch/blob/main/test/xpu/test_gemm.py#L313
@pytorchbot rebase
@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here
Successfully rebased
Force-pushed from fdff049 to 923f683 (Compare)
@pytorchbot rebase
@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here
Successfully rebased
@pytorchbot started a rebase job onto refs/remotes/origin/main. Check the current status here
Successfully rebased
Force-pushed from a456c1b to f4910e2 (Compare)
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Merge failed. Reason: 1 job has failed; the first few of them are: xpu / linux-jammy-xpu-2025.0-py3.9 / test (default, 4, 4, linux.idc.xpu). Details for Dev Infra team: raised by workflow job.
@pytorchbot merge -i
Merge started. Your change will be merged while ignoring the following 4 checks: xpu / linux-jammy-xpu-2025.0-py3.9 / test (default, 1, 4, linux.idc.xpu), xpu / linux-jammy-xpu-2025.0-py3.9 / test (default, 2, 4, linux.idc.xpu), xpu / linux-jammy-xpu-2025.0-py3.9 / test (default, 3, 4, linux.idc.xpu), xpu / linux-jammy-xpu-2025.0-py3.9 / test (default, 4, 4, linux.idc.xpu). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
…iguous (#144759) We should not always call contiguous on the dst of matmul; we already removed the copy of the matmul input in #143784. I also fixed an accuracy issue by using the oneDNN sum post-op instead of binary add in the in-place case, to avoid a UT failure. Pull Request resolved: #144759. Approved by: https://github.com/EikanWang
Avoid copy when the input of Matmul is 3D and broadcasted on the batch dim. oneDNN supports implicit broadcast semantics, i.e., src can be broadcast into weight if the corresponding dimension in src is 1 (and vice versa). On Max 1100, timm resmlp_12_224 amp_fp16 inference with bs=128 improves from 42 ms to 13.7 ms with torch.compile and from 57.5 ms to 32 ms in eager mode.
cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10
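For illustration, a minimal standalone sketch (an assumed example, not code from this PR) of the pattern the change optimizes: a 3-D matmul whose first operand is broadcast on the batch dim.

```cpp
#include <ATen/ATen.h>

int main() {
  // Expanded without a copy: a.stride(0) == 0 on the batch dim.
  at::Tensor a = at::randn({1, 64, 32}).expand({128, 64, 32});
  at::Tensor b = at::randn({128, 32, 16});
  // With this PR, the XPU/oneDNN path can consume `a` directly through
  // implicit broadcast instead of first materializing a contiguous copy.
  at::Tensor c = at::bmm(a, b);  // result shape: {128, 64, 16}
  return 0;
}
```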