[Intel GPU] Enable fp64 GEMM by ZhiweiYan-96 · Pull Request #140677 · pytorch/pytorch

Closed · wants to merge 24 commits
168 changes: 32 additions & 136 deletions aten/src/ATen/native/mkldnn/xpu/Blas.cpp
@@ -27,7 +27,7 @@ Tensor& addmm_out(
const Tensor& mat2,
const Scalar& beta,
const Scalar& alpha,
at::Tensor& result) {
Tensor& result) {
checkBackend("addmm_out", {result, self, mat1, mat2}, Backend::XPU);
TORCH_CHECK(
mat1.dim() == 2, "mat1 must be a matrix, got ", mat1.dim(), "-D tensor");
@@ -50,11 +50,11 @@ Tensor& addmm_out(
mat1.dtype(),
" != ",
mat2.dtype())
// complex/double case
if (mat1.is_complex() || mat1.scalar_type() == ScalarType::Double) {
TORCH_CHECK(
false, "Double and complex datatype matmul is not supported in oneDNN");
}
// complex case
TORCH_CHECK(
!mat1.is_complex(), "Complex datatype matmul is not supported in oneDNN");

bool is_inplace = result.is_same(self);

std::vector<int64_t> result_shape = {mat1.size(0), mat2.size(1)};
result.resize_(result_shape);
@@ -92,10 +92,17 @@ Tensor& addmm_out(
1.f, alpha.to<float>(), 0.f, attr.kind_with_linear);
}
} else {
// We use post_binary here for adding self matrix.
// To avoid wrong write, here we clone self for inplace case.
if (alpha.to<float>() == 1.f && beta_ == 1.f) {
bias = self;
bias = is_inplace ? self.clone() : self;
Collaborator: It would be better to add some comments to elaborate on why the clone is required here.

Collaborator (Author): Sure, the comment has been added here.

Contributor: You don't need to clone if #144759 is merged. We should use post sum instead of post binary in this case.
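A minimal Python sketch of the aliasing hazard that the clone guards against (an illustration, not the PR's C++ path; it assumes the in-place case where the output buffer and the additive self operand share storage):

```python
import torch

m1 = torch.randn(4, 4, dtype=torch.float64)
m2 = torch.randn(4, 4, dtype=torch.float64)
self_t = torch.randn(4, 4, dtype=torch.float64)

# Out-of-place reference result: self_t + m1 @ m2.
expected = torch.addmm(self_t, m1, m2)

# In the in-place case the GEMM result is written into the same buffer that
# also serves as the additive operand. If the kernel read that buffer as the
# post-op addend only after writing the GEMM output, the addend would already
# be overwritten. Cloning the addend first keeps a stable copy.
addend = self_t.clone()
self_t.copy_(m1 @ m2)   # simulate the GEMM writing into the shared buffer
self_t.add_(addend)     # the post-op add reads the preserved copy
torch.testing.assert_close(self_t, expected)
```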

} else {
Tensor binary = self.dim() == 1 ? self.unsqueeze(0) : self;
Tensor binary;
// unsqueeze(0) here is to handle mv cases.
if (is_inplace)
binary = self.dim() == 1 ? self.unsqueeze(0).clone() : self.clone();
else
binary = self.dim() == 1 ? self.unsqueeze(0) : self;
Comment on lines +102 to +105

Collaborator: It would be better to add comments to describe why unsqueeze is required. Or add a utility function for ndim alignment.

Collaborator (Author): unsqueeze here is used to handle mv-related operators. The comments have been updated.
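A short illustration of the rank alignment the reply refers to, under the assumption that the binary post-op operand must have the same number of dimensions as the 2-D GEMM result:

```python
import torch

result = torch.randn(3, 4)      # 2-D GEMM output
bias_1d = torch.randn(4)        # 1-D addend, e.g. from an mv/addmv-style call

bias_2d = bias_1d.unsqueeze(0)  # shape (1, 4): same rank as `result`
out = result + bias_2d          # broadcasts along dim 0
torch.testing.assert_close(out, result + bias_1d)
```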

// Tensor binary = self.expand_as(result);
// For post-binary-add, onednn needs binary scale=1.f
// Thus we need the following transformation
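The comment above mentions a transformation so that the binary post-op can run with scale 1.f; a quick numerical check of the underlying identity (a sketch of the math, not the oneDNN attr code):

```python
import torch

alpha, beta = 2.5, 0.5
m1 = torch.randn(3, 4, dtype=torch.float64)
m2 = torch.randn(4, 5, dtype=torch.float64)
self_t = torch.randn(3, 5, dtype=torch.float64)

# addmm semantics: beta * self + alpha * (m1 @ m2)
lhs = beta * self_t + alpha * (m1 @ m2)
# Equivalent form: scale the GEMM by alpha/beta (linear post-op), add the
# binary operand at scale 1, then scale the whole output by beta.
rhs = beta * ((alpha / beta) * (m1 @ m2) + self_t)
torch.testing.assert_close(lhs, rhs)
```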
@@ -159,26 +166,13 @@ Tensor& mm_out(const Tensor& self, const Tensor& mat2, Tensor& result) {
return result;
}

if (self.is_complex() || self.scalar_type() == ScalarType::Double) {
TORCH_CHECK(
false, "Double and complex datatype matmul is not supported in oneDNN");
}
TORCH_CHECK(
!self.is_complex(), "Complex datatype matmul is not supported in oneDNN");

onednn::matmul(result, self, mat2, Tensor(), true, onednn::Attr());
return result;
}

Tensor mm(const Tensor& self, const Tensor& mat2) {
auto result = at::empty({0}, self.options());
xpu::mm_out(self, mat2, result);
return result;
}

Tensor mv(const Tensor& self, const Tensor& vec) {
Tensor result = at::empty({self.size(0)}, self.options());
return at::addmv_(result, self, vec, 0, 1);
}

// result = beta * input + alpha * (batch1 @ batch2)
Tensor& baddbmm_out(
const Tensor& input,
@@ -191,6 +185,8 @@ Tensor& baddbmm_out(
TORCH_CHECK(batch1.dim() == 3, "expected 3D tensor");
TORCH_CHECK(batch2.dim() == 3, "expected 3D tensor");

bool is_inplace = result.is_same(input);

std::vector<int64_t> result_shape = {
batch1.size(0), batch1.size(1), batch2.size(2)};
result.resize_(result_shape);
@@ -212,11 +208,10 @@ Tensor& baddbmm_out(
" but got:",
input.sizes());

// complex and double case
if (batch1.is_complex() || batch2.scalar_type() == ScalarType::Double) {
TORCH_CHECK(
false, "Double and complex datatype matmul is not supported in oneDNN");
}
// complex case
TORCH_CHECK(
!batch1.is_complex(),
"Complex datatype matmul is not supported in oneDNN");

// general case
onednn::Attr attr;
@@ -228,7 +223,13 @@ Tensor& baddbmm_out(
1.f, alpha.to<float>(), 0.f, attr.kind_with_linear);
}
} else {
binary = input.dim() < 3 ? input.unsqueeze(0) : input;
// We use post_binary here for adding self matrix.
// To avoid wrong write, here we clone input for inplace case.
if (is_inplace)
binary = input.dim() < 3 ? input.unsqueeze(0).clone() : input.clone();
else
binary = input.dim() < 3 ? input.unsqueeze(0) : input;
// If input is a 1d tensor need be broadcasted, we need unsqueeze twice.
binary = binary.dim() < 3 ? binary.unsqueeze_(0) : binary;
float alpha_ = alpha.to<float>() / beta_;
if (alpha_ != 1.f)
@@ -241,109 +242,6 @@
return result;
}

Tensor& baddbmm_(
Tensor& self,
const Tensor& batch1,
const Tensor& batch2,
const Scalar& beta,
const Scalar& alpha) {
TORCH_CHECK(
self.dtype() == batch1.dtype(),
"Input dtypes must be the same, got: input ",
self.dtype(),
", batch1: ",
batch1.dtype(),
", batch2: ",
batch2.dtype());
return at::native::xpu::baddbmm_out(self, batch1, batch2, beta, alpha, self);
}

Tensor baddbmm(
const Tensor& input,
const Tensor& batch1,
const Tensor& batch2,
const Scalar& beta,
const Scalar& alpha) {
Tensor r = at::empty({0}, input.options());
TORCH_CHECK(
input.dtype() == batch1.dtype(),
"Input dtypes must be the same, got: input ",
input.dtype(),
", batch1: ",
batch1.dtype(),
", batch2: ",
batch2.dtype());
r = at::native::xpu::baddbmm_out(input, batch1, batch2, beta, alpha, r);
return r;
}

Tensor& addbmm_out(
Collaborator: Does Intel GPU not support addbmm_out?

ZhiweiYan-96 (Author), Nov 14, 2024: We don't need to write this glue code: CUDA, CPU, and XPU share an entry in native_functions.yaml, so they reuse the same implementation (op stubs or composite kernels) in at::native::addbmm_out. The generic addbmm implementation does the job by calling addmm, for which we already have code.

(A decomposition sketch follows the removed addbmm code below.)

const Tensor& self,
const Tensor& batch1,
const Tensor& batch2,
const Scalar& beta,
const Scalar& alpha,
Tensor& out) {
checkBackend("addbmm_out", {out, self, batch1, batch2}, Backend::XPU);
TORCH_CHECK(
batch1.dim() == 3 && batch2.dim() == 3,
"Batch tensors should be 3D, got dimensions ",
batch1.dim(),
" and ",
batch2.dim());
if (self.is_complex() || self.scalar_type() == ScalarType::Double) {
TORCH_CHECK(
false, "Double and complex datatype matmul is not supported in oneDNN");
}

out.resize_({batch1.size(1), batch2.size(2)});
if (alpha.to<float>() == 0.f || batch1.numel() == 0 || batch2.numel() == 0) {
out.resize_({batch1.size(1), batch2.size(2)});
if (out.numel() == 0)
return out;

if (self.defined() && beta.to<float>() != 0.f) {
out = at::mul_out(
out, self, at::native::wrapped_scalar_tensor(at::Scalar(beta)));
} else {
out.zero_();
}
return out;
}

Tensor b1;
if (batch1.size(0) > 1) {
b1 = batch1.transpose(0, 1).contiguous().view({batch1.size(1), -1});
} else {
b1 = batch1.contiguous().view({batch1.size(1), -1});
}
auto b2 = batch2.contiguous().view({-1, batch2.size(2)});
at::native::xpu::addmm_out(self, b1, b2, beta, alpha, out);

return out;
}

Tensor& addbmm_(
Tensor& self,
const Tensor& batch1,
const Tensor& batch2,
const Scalar& beta,
const Scalar& alpha) {
at::native::xpu::addbmm_out(self, batch1, batch2, beta, alpha, self);
return self;
}

Tensor addbmm(
const Tensor& self,
const Tensor& batch1,
const Tensor& batch2,
const Scalar& beta,
const Scalar& alpha) {
Tensor out = at::empty({0}, self.options());
at::native::xpu::addbmm_out(self, batch1, batch2, beta, alpha, out);
return out;
}
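As a sketch of the decomposition that the removed glue code performed (and that the shared at::native composite path still performs, per the review reply above), flattening the batch dimension reduces addbmm to a single addmm; shapes here are illustrative:

```python
import torch

B, M, K, N = 2, 3, 4, 5
inp = torch.randn(M, N, dtype=torch.float64)
batch1 = torch.randn(B, M, K, dtype=torch.float64)
batch2 = torch.randn(B, K, N, dtype=torch.float64)

# addbmm: inp + sum over b of batch1[b] @ batch2[b]
expected = torch.addbmm(inp, batch1, batch2)

# Fold the batch dimension into K and issue one GEMM instead.
b1 = batch1.transpose(0, 1).contiguous().view(M, B * K)  # (M, B*K)
b2 = batch2.contiguous().view(B * K, N)                  # (B*K, N)
out = torch.addmm(inp, b1, b2)
torch.testing.assert_close(out, expected)
```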

Tensor& bmm_out(const Tensor& self, const Tensor& batch2, Tensor& result) {
checkBackend("bmm_out", {result, self, batch2}, Backend::XPU);
TORCH_CHECK(self.dim() == 3, "expected 3D tensor");
@@ -356,10 +254,8 @@ Tensor& bmm_out(const Tensor& self, const Tensor& batch2, Tensor& result) {
return result;
}

if (self.is_complex() || self.scalar_type() == ScalarType::Double) {
TORCH_CHECK(
false, "Double and complex datatype matmul is not supported in oneDNN");
}
TORCH_CHECK(
!self.is_complex(), "Complex datatype matmul is not supported in oneDNN");
onednn::matmul(result, self, batch2, at::Tensor(), true, onednn::Attr());
return result;
}
9 changes: 5 additions & 4 deletions aten/src/ATen/native/mkldnn/xpu/detail/Matmul.cpp
@@ -7,6 +7,7 @@
#include <Attr.h>
#include <Utils.h>

#include <c10/core/ScalarType.h>
#include <oneapi/dnnl/dnnl.hpp>

namespace at::native::onednn {
@@ -109,9 +110,9 @@ sycl::event matmul(
b = b.contiguous(); // avoid reorder 2 times

// xpu matmul support both ab/ba shape for m2 tensor, we don't check any more
auto m1_usr_dt = get_onednn_dtype(m1);
auto m2_usr_dt = get_onednn_dtype(m2);
auto dst_usr_dt = get_onednn_dtype(dst);
auto m1_usr_dt = get_onednn_dtype_include_double(m1);
auto m2_usr_dt = get_onednn_dtype_include_double(m2);
auto dst_usr_dt = get_onednn_dtype_include_double(dst);

auto m1_dt = m1_usr_dt;
auto m2_dt = m2_usr_dt;
@@ -165,7 +166,7 @@

if (with_bias) {
bias_dims = get_onednn_dims(b);
bias_dt = get_onednn_dtype(b);
bias_dt = get_onednn_dtype_include_double(b);
bias_strides = get_onednn_strides(b);
}

11 changes: 6 additions & 5 deletions aten/src/ATen/native/mkldnn/xpu/detail/Utils.cpp
@@ -94,10 +94,8 @@ dnnl::memory::data_type get_onednn_dtype_include_double(
}

bool is_supported_onednn_dtype(const at::Tensor& tensor) {
return get_onednn_dtype(tensor, /*allow_undef*/ true) ==
dnnl::memory::data_type::undef
? false
: true;
return get_onednn_dtype_include_double(tensor) !=
dnnl::memory::data_type::undef;
}

dnnl::memory::dims get_onednn_dims(const at::Tensor& tensor) {
@@ -116,7 +114,10 @@ dnnl::memory::dims get_onednn_strides(const at::Tensor& tensor) {

dnnl::memory::desc get_onednn_md(const at::Tensor& tensor) {
Tensor t = tensor.sizes().empty() ? tensor.unsqueeze(0) : tensor;
return {get_onednn_dims(t), get_onednn_dtype(t), get_onednn_strides(t)};
return {
get_onednn_dims(t),
get_onednn_dtype_include_double(t),
get_onednn_strides(t)};
}

bool onednn_strides_check(const Tensor& src) {
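With Double now mapped through get_onednn_dtype_include_double, fp64 GEMMs can dispatch to oneDNN on the XPU backend. A usage sketch, assuming an XPU-enabled build where torch.xpu.is_available() returns True:

```python
import torch

if torch.xpu.is_available():
    a = torch.randn(128, 64, device="xpu", dtype=torch.float64)
    b = torch.randn(64, 32, device="xpu", dtype=torch.float64)
    c = a @ b  # fp64 GEMM, previously rejected as unsupported in oneDNN
    torch.testing.assert_close(c.cpu(), a.cpu() @ b.cpu())
```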
69 changes: 0 additions & 69 deletions test/inductor/test_torchinductor_opinfo.py
@@ -273,79 +273,10 @@ def format_op(op):
"torch.ops.aten._efficient_attention_forward": {f16, f32},
"to_sparse": {f32, f64},
"linalg.eig": {f32, f64},
"linalg.eigvals": {f64},
# Double and complex datatype matmul is not supported in oneDNN
"__rmatmul__": {f64},
("addmm", "decomposed"): {f64},
"addr": {f64},
"baddbmm": {f64},
"bmm": {f64},
"byte": {f16, f32},
"cdist": {f64},
"corrcoef": {f64},
"cov": {f64},
"einsum": {f64},
"inner": {f64},
"linalg.cholesky_ex": {f64},
"linalg.cholesky": {f64},
"linalg.ldl_factor_ex": {f64},
"linalg.ldl_factor": {f64},
"linalg.ldl_solve": {f64},
"linalg.matrix_power": {f64},
"linalg.multi_dot": {f64},
"matmul": {f64},
"mm": {f64},
"mv": {f64},
"nn.functional.bilinear": {f64},
"nn.functional.linear": {f64},
"pca_lowrank": {f64},
"svd_lowrank": {f64},
"tensordot": {f64},
"triangular_solve": {f64},
"svd": {f64},
"qr": {f64},
"pinverse": {f64},
"ormqr": {f64},
("norm", "nuc"): {f64},
"lu": {f64},
"lu_solve": {f64},
"logdet": {f64},
"linalg.tensorsolve": {f64},
"linalg.tensorinv": {f64},
"linalg.svdvals": {f64},
"linalg.svd": {f64},
"linalg.solve": {f64},
"linalg.solve_triangular": {f64},
"linalg.solve_ex": {f64},
"linalg.slogdet": {f64},
"linalg.qr": {f64},
"linalg.pinv": {f64},
("linalg.pinv", "hermitian"): {f64},
("linalg.pinv", "singular"): {f64},
"linalg.norm": {f64},
("linalg.norm", "subgradients_at_zero"): {f64},
"linalg.matrix_rank": {f64},
("linalg.matrix_rank", "hermitian"): {f64},
"linalg.matrix_norm": {f64},
"linalg.lu": {f64},
"linalg.lu_solve": {f64},
"linalg.lu_factor": {f64},
"linalg.lu_factor_ex": {f64},
"linalg.lstsq": {f64},
("linalg.lstsq", "grad_oriented"): {f64},
"linalg.inv": {f64},
"linalg.inv_ex": {f64},
"linalg.householder_product": {f64},
"linalg.eigvalsh": {f64},
"linalg.eigh": {f64},
"linalg.det": {f64},
"linalg.cond": {f64},
"geqrf": {f64},
"cholesky_solve": {f64},
"cholesky_inverse": {f64},
# could not create a primitive
"addbmm": {f64},
"addmm": {f64},
"addmv": {f64},
# could not create a primitive descriptor for
# a deconvolution forward propagation primitive
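The entries removed above were skipped only because fp64 matmul was unsupported on XPU; a quick smoke check of a few of them (a sketch that assumes an XPU device is available and falls back to CPU otherwise):

```python
import torch

dev = "xpu" if torch.xpu.is_available() else "cpu"
x = torch.randn(8, 8, device=dev, dtype=torch.float64)
y = torch.randn(8, 8, device=dev, dtype=torch.float64)

print(torch.mm(x, y).shape)                    # "mm": {f64}
print(torch.einsum("ij,jk->ik", x, y).shape)   # "einsum": {f64}
z = torch.zeros(1, 8, 8, device=dev, dtype=torch.float64)
print(torch.baddbmm(z, x.unsqueeze(0), y.unsqueeze(0)).shape)  # "baddbmm": {f64}
```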