[Intel GPU] Enable fp64 GEMM #140677
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/140677
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (3 Unrelated Failures) As of commit b358c61 with merge base 880e176.
FLAKY - The following jobs failed but were likely due to flakiness present on trunk.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
```diff
@@ -8,6 +8,7 @@
 #include <Utils.h>

 #include <oneapi/dnnl/dnnl.hpp>
+#include "c10/core/ScalarType.h"
```
Suggested change:
```diff
-#include "c10/core/ScalarType.h"
+#include <c10/core/ScalarType.h>
```
Thanks for the suggestion; the code has been changed here.
```diff
-  TORCH_CHECK(
-      false, "Double and complex datatype matmul is not supported in oneDNN");
+  if (self.is_complex()) {
+    AT_ERROR("Complex datatype matmul is not supported in oneDNN");
```
`AT_ERROR` has been deprecated.
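For context, a minimal sketch of the replacement pattern; the helper function is illustrative, not code from this PR:

```cpp
#include <ATen/core/Tensor.h>
#include <c10/util/Exception.h>  // TORCH_CHECK; AT_ERROR is deprecated

// AT_ERROR("msg") threw unconditionally; TORCH_CHECK(false, "msg") is the
// drop-in replacement, and TORCH_CHECK(cond, "msg") folds the guard in.
void check_not_complex(const at::Tensor& t) {
  TORCH_CHECK(
      !t.is_complex(),
      "Complex datatype matmul is not supported in oneDNN");
}
```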
Thanks for the suggestion; `AT_ERROR` in this file has been changed to `TORCH_CHECK`.
```diff
-  if (self.is_complex() || self.scalar_type() == ScalarType::Double) {
-    TORCH_CHECK(
-        false, "Double and complex datatype matmul is not supported in oneDNN");
+  if (self.is_complex()) {
```
`TORCH_CHECK(!self.is_complex(), "error message");`
modified
```cpp
  // complex case
  if (mat1.is_complex()) {
    AT_ERROR("Complex datatype matmul is not supported in oneDNN");
  }
```
Suggested change:
```diff
 // complex case
-if (mat1.is_complex()) {
-  AT_ERROR("Complex datatype matmul is not supported in oneDNN");
-}
+TORCH_CHECK(!mat1.is_complex(), "Complex datatype matmul is not supported in oneDNN");
```
modified
```diff
@@ -277,73 +280,6 @@ Tensor baddbmm(
   return r;
 }

-Tensor& addbmm_out(
```
Does Intel GPU not support `addbmm_out`?
We do not need to write this glue code, as CUDA/CPU/XPU share an entry in `native_functions.yaml`. They share the same implementation (the `op_stub` or `composite` cases) in `at::native::addbmm_out`; the implementation of `addbmm` is generic, as it does the job by calling `addmm`, for which we have code.
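As a hedged illustration of why the composite path suffices (this is not the actual ATen implementation; the helper name and the batch-folding trick are just one way to express it), `addbmm` reduces to a single `addmm` once the batch dimension is folded in:

```cpp
#include <ATen/ATen.h>

// Illustrative sketch: addbmm(self, b1, b2) = beta*self + alpha*sum_i b1[i] @ b2[i].
// Folding the batch dimension turns the batched reduction into one GEMM,
// so a backend that implements addmm gets addbmm "for free".
at::Tensor addbmm_via_addmm(
    const at::Tensor& self,    // [n, p]
    const at::Tensor& batch1,  // [b, n, m]
    const at::Tensor& batch2,  // [b, m, p]
    const at::Scalar& beta,
    const at::Scalar& alpha) {
  // [b, n, m] -> [n, b*m] and [b, m, p] -> [b*m, p]; contracting over the
  // fused (b*m) axis sums the per-batch products exactly as addbmm does.
  auto m1 = batch1.transpose(0, 1).reshape({batch1.size(1), -1});
  auto m2 = batch2.reshape({-1, batch2.size(2)});
  return at::addmm(self, m1, m2, beta, alpha);
}
```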
```diff
@@ -93,9 +94,13 @@ Tensor& addmm_out(
     }
   } else {
     if (alpha.to<float>() == 1.f && beta_ == 1.f) {
-      bias = self;
+      bias = is_inplace ? self.clone() : self;
```
It would be better to add some comments to elaborate on why the clone is required here.
Sure, the comment has been added here.
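The added comment itself isn't quoted in this thread; a hedged sketch of the rationale (helper and variable names illustrative):

```cpp
#include <ATen/ATen.h>

// Hedged sketch: pick the bias for the GEMM epilogue. When addmm_out runs
// in-place, `result` and `self` alias: the GEMM overwrites `result` while
// the epilogue still needs the original `self` values as the bias, so the
// bias must be snapshotted with a clone first.
at::Tensor choose_bias(const at::Tensor& self, const at::Tensor& result) {
  const bool is_inplace = result.is_same(self);
  return is_inplace ? self.clone()  // in-place: preserve original values
                    : self;         // out-of-place: aliasing is harmless
}
```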
You won't need to clone once #144759 is merged. We should use a post-op sum instead of a post-op binary in this case.
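For reference, a hedged sketch of the distinction using oneDNN's post-op API (illustrative helpers, not code from either PR):

```cpp
#include <oneapi/dnnl/dnnl.hpp>

// A post-op *sum* accumulates into the destination buffer itself
// (dst = matmul(src, wei) + scale * dst_original), so bias values that
// already live in the output need no separate tensor and no clone.
dnnl::primitive_attr make_sum_attr(float scale = 1.0f) {
  dnnl::post_ops po;
  po.append_sum(scale);  // reuse dst as the accumulator
  dnnl::primitive_attr attr;
  attr.set_post_ops(po);
  return attr;
}

// A post-op *binary* add reads an extra src1 tensor instead, which is
// where the aliasing problem above comes from when src1 aliases dst.
dnnl::primitive_attr make_binary_add_attr(const dnnl::memory::desc& bias_md) {
  dnnl::post_ops po;
  po.append_binary(dnnl::algorithm::binary_add, bias_md);
  dnnl::primitive_attr attr;
  attr.set_post_ops(po);
  return attr;
}
```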
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Merge failed. Reason: 1 job has failed, the first few of them being: xpu / linux-jammy-xpu-2025.0-py3.9 / test (default, 2, 4, linux.idc.xpu). Details for Dev Infra team: raised by workflow job.
@pytorchbot rebase -b main
@pytorchbot started a rebase job onto refs/remotes/origin/main. Check the current status here.
Successfully rebased.
@guangyey Some new skipped UTs should have been added, but this PR fixes them. I retriggered CI and am waiting for all the fixed UTs to show up in the CI results. After that, I will fix them in a single commit.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Pull Request resolved: #140677. Approved by: https://github.com/guangyey, https://github.com/EikanWang, https://github.com/desertfire
Stack from ghstack (oldest at bottom):
cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @voznesenskym @penguinwu @EikanWang @Guobing-Chen @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @aakhundov @gujinghui @fengyuan14 @guangyey
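A hedged usage sketch of what this PR enables on the XPU backend (illustrative shapes; device-availability checks omitted):

```cpp
#include <ATen/ATen.h>

int main() {
  // Before this PR, double-precision matmul on XPU raised
  // "Double and complex datatype matmul is not supported in oneDNN".
  auto a = at::randn({4, 8}, at::device(at::kXPU).dtype(at::kDouble));
  auto b = at::randn({8, 2}, at::device(at::kXPU).dtype(at::kDouble));
  auto c = at::matmul(a, b);  // fp64 GEMM now dispatches on XPU
  return 0;
}
```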