[Intel GPU] Enable fp64 GEMM #140677
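A minimal sketch of the user-visible effect of this change, assuming a PyTorch build with XPU support and an Intel GPU present. Before this PR, float64 matmuls on XPU hit the "Double and complex datatype matmul is not supported in oneDNN" check that the diff below removes:

```python
# Hypothetical usage sketch; requires a PyTorch build with XPU support.
import torch

if hasattr(torch, "xpu") and torch.xpu.is_available():
    a = torch.randn(128, 64, dtype=torch.float64, device="xpu")
    b = torch.randn(64, 32, dtype=torch.float64, device="xpu")
    bias = torch.zeros(128, 32, dtype=torch.float64, device="xpu")

    c = torch.mm(a, b)            # fp64 mm, previously rejected on XPU
    d = torch.addmm(bias, a, b)   # fp64 addmm, previously rejected on XPU
    print(c.shape, d.dtype)
```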
@@ -27,7 +27,7 @@ Tensor& addmm_out(
     const Tensor& mat2,
     const Scalar& beta,
     const Scalar& alpha,
-    at::Tensor& result) {
+    Tensor& result) {
   checkBackend("addmm_out", {result, self, mat1, mat2}, Backend::XPU);
   TORCH_CHECK(
       mat1.dim() == 2, "mat1 must be a matrix, got ", mat1.dim(), "-D tensor");

@@ -50,11 +50,11 @@ Tensor& addmm_out(
       mat1.dtype(),
       " != ",
       mat2.dtype())
-  // complex/double case
-  if (mat1.is_complex() || mat1.scalar_type() == ScalarType::Double) {
-    TORCH_CHECK(
-        false, "Double and complex datatype matmul is not supported in oneDNN");
-  }
+  // complex case
+  TORCH_CHECK(
+      !mat1.is_complex(), "Complex datatype matmul is not supported in oneDNN");

+  bool is_inplace = result.is_same(self);
+
   std::vector<int64_t> result_shape = {mat1.size(0), mat2.size(1)};
   result.resize_(result_shape);

@@ -92,10 +92,17 @@ Tensor& addmm_out(
             1.f, alpha.to<float>(), 0.f, attr.kind_with_linear);
       }
     } else {
+      // We use post_binary here for adding self matrix.
+      // To avoid wrong write, here we clone self for inplace case.
       if (alpha.to<float>() == 1.f && beta_ == 1.f) {
-        bias = self;
+        bias = is_inplace ? self.clone() : self;
       } else {
-        Tensor binary = self.dim() == 1 ? self.unsqueeze(0) : self;
+        Tensor binary;
+        // unsqueeze(0) here is to handle mv cases.
+        if (is_inplace)
+          binary = self.dim() == 1 ? self.unsqueeze(0).clone() : self.clone();
+        else
+          binary = self.dim() == 1 ? self.unsqueeze(0) : self;
         // Tensor binary = self.expand_as(result);
         // For post-binary-add, onednn needs binary scale=1.f
         // Thus we need the following transformation

Review comment on lines +102 to +105:
It would be better to add comments to describe why
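A small sketch of the aliasing case the new `is_inplace` clone guards against. The emulation below only mirrors the assumed order of operations (the GEMM writes the destination, then the post-binary operand is added); it is not the actual oneDNN kernel:

```python
# Illustrative sketch, assumed semantics only: for in-place addmm_, the output
# buffer is the same tensor as `self`, so the binary-add operand must be a
# snapshot of self taken before the GEMM overwrites that buffer.
import torch

self_ = torch.randn(4, 5, dtype=torch.float64)
mat1 = torch.randn(4, 3, dtype=torch.float64)
mat2 = torch.randn(3, 5, dtype=torch.float64)

# out-of-place reference: result and self_ are different buffers
result = torch.addmm(self_, mat1, mat2)

# in-place emulation: result IS self_, so snapshot self_ first
snapshot = self_.clone()        # what `is_inplace ? self.clone() : self` achieves
self_.copy_(mat1 @ mat2)        # GEMM writes into the shared result/self buffer
self_.add_(snapshot)            # post-binary add of the pre-GEMM self values

assert torch.allclose(result, self_)
```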
@@ -159,26 +166,13 @@ Tensor& mm_out(const Tensor& self, const Tensor& mat2, Tensor& result) {
     return result;
   }

-  if (self.is_complex() || self.scalar_type() == ScalarType::Double) {
-    TORCH_CHECK(
-        false, "Double and complex datatype matmul is not supported in oneDNN");
-  }
+  TORCH_CHECK(
+      !self.is_complex(), "Complex datatype matmul is not supported in oneDNN");

   onednn::matmul(result, self, mat2, Tensor(), true, onednn::Attr());
   return result;
 }

-Tensor mm(const Tensor& self, const Tensor& mat2) {
-  auto result = at::empty({0}, self.options());
-  xpu::mm_out(self, mat2, result);
-  return result;
-}
-
-Tensor mv(const Tensor& self, const Tensor& vec) {
-  Tensor result = at::empty({self.size(0)}, self.options());
-  return at::addmv_(result, self, vec, 0, 1);
-}
-
 // result = beta * input + alpha * (batch1 @ batch2)
 Tensor& baddbmm_out(
     const Tensor& input,

@@ -191,6 +185,8 @@ Tensor& baddbmm_out(
   TORCH_CHECK(batch1.dim() == 3, "expected 3D tensor");
   TORCH_CHECK(batch2.dim() == 3, "expected 3D tensor");

+  bool is_inplace = result.is_same(input);
+
   std::vector<int64_t> result_shape = {
       batch1.size(0), batch1.size(1), batch2.size(2)};
   result.resize_(result_shape);

@@ -212,11 +208,10 @@ Tensor& baddbmm_out(
       " but got:",
       input.sizes());

-  // complex and double case
-  if (batch1.is_complex() || batch2.scalar_type() == ScalarType::Double) {
-    TORCH_CHECK(
-        false, "Double and complex datatype matmul is not supported in oneDNN");
-  }
+  // complex case
+  TORCH_CHECK(
+      !batch1.is_complex(),
+      "Complex datatype matmul is not supported in oneDNN");

   // general case
   onednn::Attr attr;

@@ -228,7 +223,13 @@ Tensor& baddbmm_out(
           1.f, alpha.to<float>(), 0.f, attr.kind_with_linear);
     }
   } else {
-    binary = input.dim() < 3 ? input.unsqueeze(0) : input;
+    // We use post_binary here for adding self matrix.
+    // To avoid wrong write, here we clone input for inplace case.
+    if (is_inplace)
+      binary = input.dim() < 3 ? input.unsqueeze(0).clone() : input.clone();
+    else
+      binary = input.dim() < 3 ? input.unsqueeze(0) : input;
+    // If input is a 1d tensor need be broadcasted, we need unsqueeze twice.
     binary = binary.dim() < 3 ? binary.unsqueeze_(0) : binary;
     float alpha_ = alpha.to<float>() / beta_;
     if (alpha_ != 1.f)

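A short sketch of the broadcasting that the new "unsqueeze twice" comment refers to, using public PyTorch ops rather than the kernel itself: a 1-D `input` to baddbmm has to be promoted to 3-D before it can serve as a post-binary operand against the (B, M, N) batched output:

```python
# Illustrative sketch of the shape handling described in the added comments.
import torch

B, M, K, N = 2, 3, 4, 5
batch1 = torch.randn(B, M, K, dtype=torch.float64)
batch2 = torch.randn(B, K, N, dtype=torch.float64)

inp1d = torch.randn(N, dtype=torch.float64)  # 1-D bias, broadcastable over (B, M, N)
binary = inp1d.unsqueeze(0)                  # (1, N)    -- first unsqueeze (dim < 3)
binary = binary.unsqueeze(0)                 # (1, 1, N) -- second unsqueeze, now 3-D
out = torch.bmm(batch1, batch2) + binary     # broadcasts to (B, M, N)

assert torch.allclose(out, torch.baddbmm(inp1d, batch1, batch2))
```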
@@ -241,109 +242,6 @@ Tensor& baddbmm_out(
   return result;
 }

-Tensor& baddbmm_(
-    Tensor& self,
-    const Tensor& batch1,
-    const Tensor& batch2,
-    const Scalar& beta,
-    const Scalar& alpha) {
-  TORCH_CHECK(
-      self.dtype() == batch1.dtype(),
-      "Input dtypes must be the same, got: input ",
-      self.dtype(),
-      ", batch1: ",
-      batch1.dtype(),
-      ", batch2: ",
-      batch2.dtype());
-  return at::native::xpu::baddbmm_out(self, batch1, batch2, beta, alpha, self);
-}
-
-Tensor baddbmm(
-    const Tensor& input,
-    const Tensor& batch1,
-    const Tensor& batch2,
-    const Scalar& beta,
-    const Scalar& alpha) {
-  Tensor r = at::empty({0}, input.options());
-  TORCH_CHECK(
-      input.dtype() == batch1.dtype(),
-      "Input dtypes must be the same, got: input ",
-      input.dtype(),
-      ", batch1: ",
-      batch1.dtype(),
-      ", batch2: ",
-      batch2.dtype());
-  r = at::native::xpu::baddbmm_out(input, batch1, batch2, beta, alpha, r);
-  return r;
-}
-
-Tensor& addbmm_out(
-    const Tensor& self,
-    const Tensor& batch1,
-    const Tensor& batch2,
-    const Scalar& beta,
-    const Scalar& alpha,
-    Tensor& out) {
-  checkBackend("addbmm_out", {out, self, batch1, batch2}, Backend::XPU);
-  TORCH_CHECK(
-      batch1.dim() == 3 && batch2.dim() == 3,
-      "Batch tensors should be 3D, got dimensions ",
-      batch1.dim(),
-      " and ",
-      batch2.dim());
-  if (self.is_complex() || self.scalar_type() == ScalarType::Double) {
-    TORCH_CHECK(
-        false, "Double and complex datatype matmul is not supported in oneDNN");
-  }
-
-  out.resize_({batch1.size(1), batch2.size(2)});
-  if (alpha.to<float>() == 0.f || batch1.numel() == 0 || batch2.numel() == 0) {
-    out.resize_({batch1.size(1), batch2.size(2)});
-    if (out.numel() == 0)
-      return out;
-
-    if (self.defined() && beta.to<float>() != 0.f) {
-      out = at::mul_out(
-          out, self, at::native::wrapped_scalar_tensor(at::Scalar(beta)));
-    } else {
-      out.zero_();
-    }
-    return out;
-  }
-
-  Tensor b1;
-  if (batch1.size(0) > 1) {
-    b1 = batch1.transpose(0, 1).contiguous().view({batch1.size(1), -1});
-  } else {
-    b1 = batch1.contiguous().view({batch1.size(1), -1});
-  }
-  auto b2 = batch2.contiguous().view({-1, batch2.size(2)});
-  at::native::xpu::addmm_out(self, b1, b2, beta, alpha, out);
-
-  return out;
-}
-
-Tensor& addbmm_(
-    Tensor& self,
-    const Tensor& batch1,
-    const Tensor& batch2,
-    const Scalar& beta,
-    const Scalar& alpha) {
-  at::native::xpu::addbmm_out(self, batch1, batch2, beta, alpha, self);
-  return self;
-}
-
-Tensor addbmm(
-    const Tensor& self,
-    const Tensor& batch1,
-    const Tensor& batch2,
-    const Scalar& beta,
-    const Scalar& alpha) {
-  Tensor out = at::empty({0}, self.options());
-  at::native::xpu::addbmm_out(self, batch1, batch2, beta, alpha, out);
-  return out;
-}
-
 Tensor& bmm_out(const Tensor& self, const Tensor& batch2, Tensor& result) {
   checkBackend("bmm_out", {result, self, batch2}, Backend::XPU);
   TORCH_CHECK(self.dim() == 3, "expected 3D tensor");

Review comments on the removed addbmm_out glue code:
Does Intel GPU not support
We do not need to write these glue codes, as cuda/cpu/xpu share an entry at
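For reference, a sketch of the decomposition that the removed addbmm_out glue implemented (fold the batch dimension into one large GEMM and call addmm); per the author's reply above, this is now handled by the shared entry point rather than per-backend glue:

```python
# Illustrative sketch mirroring the removed glue code, not the shared entry itself.
import torch

B, M, K, N = 3, 4, 5, 6
self_ = torch.randn(M, N, dtype=torch.float64)
batch1 = torch.randn(B, M, K, dtype=torch.float64)
batch2 = torch.randn(B, K, N, dtype=torch.float64)
beta, alpha = 0.5, 2.0

# Fold the batch dimension into the reduction dimension of a single GEMM,
# exactly as the removed code did with transpose(0, 1) + view.
b1 = batch1.transpose(0, 1).contiguous().view(M, B * K)   # (M, B*K)
b2 = batch2.contiguous().view(B * K, N)                   # (B*K, N)
out = torch.addmm(self_, b1, b2, beta=beta, alpha=alpha)

ref = torch.addbmm(self_, batch1, batch2, beta=beta, alpha=alpha)
assert torch.allclose(out, ref)
```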
@@ -356,10 +254,8 @@ Tensor& bmm_out(const Tensor& self, const Tensor& batch2, Tensor& result) {
     return result;
   }

-  if (self.is_complex() || self.scalar_type() == ScalarType::Double) {
-    TORCH_CHECK(
-        false, "Double and complex datatype matmul is not supported in oneDNN");
-  }
+  TORCH_CHECK(
+      !self.is_complex(), "Complex datatype matmul is not supported in oneDNN");
   onednn::matmul(result, self, batch2, at::Tensor(), true, onednn::Attr());
   return result;
 }

Review comments:
It would be better to add some comments to elaborate on why the clone is required here.
Sure, the comments are added here.
You don't need to clone if #144759 is merged. We should use post sum instead of post binary in this case.
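A hedged illustration of the distinction the last comment points at, written with plain tensor ops rather than oneDNN post-ops: a binary-add post-op combines the GEMM result with a separate operand (which must be cloned when it aliases the destination), while a sum-style post-op accumulates into whatever the destination already holds, so no clone would be needed:

```python
# Illustrative sketch only, not the oneDNN API: both paths compute
# result = beta * input + alpha * (mat1 @ mat2).
import torch

mat1 = torch.randn(4, 3, dtype=torch.float64)
mat2 = torch.randn(3, 5, dtype=torch.float64)
inp = torch.randn(4, 5, dtype=torch.float64)
alpha, beta = 2.0, 0.5

# "post binary": the GEMM output is combined with a separate operand tensor.
binary = inp.clone()                       # clone needed if binary aliases the output
out_binary = alpha * (mat1 @ mat2) + beta * binary

# "post sum": the destination already carries beta * input and the GEMM
# accumulates into it, so no separate (possibly aliased) operand is involved.
out_sum = inp.clone()
out_sum.mul_(beta).add_(mat1 @ mat2, alpha=alpha)

assert torch.allclose(out_binary, out_sum)
```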