Commit ae5f7fe

ZhiweiYan-96 authored and pytorchmergebot committed
[Intel GPU] Enable fp64 GEMM (#140677)
Pull Request resolved: #140677
Approved by: https://github.com/guangyey, https://github.com/EikanWang, https://github.com/desertfire
1 parent 2b30e94 commit ae5f7fe
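
The net effect: fp64 GEMM on XPU now lowers to oneDNN instead of raising "Double and complex datatype matmul is not supported in oneDNN". A minimal smoke test of the new path (a sketch, assuming a PyTorch build with XPU support and an fp64-capable Intel GPU):

```python
import torch

# Sketch only: assumes an XPU build and a device with native fp64.
a = torch.randn(64, 32, device="xpu", dtype=torch.float64)
b = torch.randn(32, 16, device="xpu", dtype=torch.float64)

out = torch.mm(a, b)  # previously raised a RuntimeError from the oneDNN backend
torch.testing.assert_close(out.cpu(), torch.mm(a.cpu(), b.cpu()))
```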

File tree

5 files changed: +61 -252 lines

aten/src/ATen/native/mkldnn/xpu/Blas.cpp

Lines changed: 32 additions & 136 deletions

```diff
@@ -27,7 +27,7 @@ Tensor& addmm_out(
     const Tensor& mat2,
     const Scalar& beta,
     const Scalar& alpha,
-    at::Tensor& result) {
+    Tensor& result) {
   checkBackend("addmm_out", {result, self, mat1, mat2}, Backend::XPU);
   TORCH_CHECK(
       mat1.dim() == 2, "mat1 must be a matrix, got ", mat1.dim(), "-D tensor");
@@ -50,11 +50,11 @@ Tensor& addmm_out(
       mat1.dtype(),
       " != ",
       mat2.dtype())
-  // complex/double case
-  if (mat1.is_complex() || mat1.scalar_type() == ScalarType::Double) {
-    TORCH_CHECK(
-        false, "Double and complex datatype matmul is not supported in oneDNN");
-  }
+  // complex case
+  TORCH_CHECK(
+      !mat1.is_complex(), "Complex datatype matmul is not supported in oneDNN");
+
+  bool is_inplace = result.is_same(self);
 
   std::vector<int64_t> result_shape = {mat1.size(0), mat2.size(1)};
   result.resize_(result_shape);
@@ -92,10 +92,17 @@ Tensor& addmm_out(
           1.f, alpha.to<float>(), 0.f, attr.kind_with_linear);
     }
   } else {
+    // We use a binary post-op here to add the self matrix.
+    // To avoid writing over our own input, clone self in the in-place case.
     if (alpha.to<float>() == 1.f && beta_ == 1.f) {
-      bias = self;
+      bias = is_inplace ? self.clone() : self;
     } else {
-      Tensor binary = self.dim() == 1 ? self.unsqueeze(0) : self;
+      Tensor binary;
+      // unsqueeze(0) handles the mv case.
+      if (is_inplace)
+        binary = self.dim() == 1 ? self.unsqueeze(0).clone() : self.clone();
+      else
+        binary = self.dim() == 1 ? self.unsqueeze(0) : self;
       // Tensor binary = self.expand_as(result);
       // For post-binary-add, onednn needs binary scale=1.f
       // Thus we need the following transformation
@@ -159,26 +166,13 @@ Tensor& mm_out(const Tensor& self, const Tensor& mat2, Tensor& result) {
     return result;
   }
 
-  if (self.is_complex() || self.scalar_type() == ScalarType::Double) {
-    TORCH_CHECK(
-        false, "Double and complex datatype matmul is not supported in oneDNN");
-  }
+  TORCH_CHECK(
+      !self.is_complex(), "Complex datatype matmul is not supported in oneDNN");
 
   onednn::matmul(result, self, mat2, Tensor(), true, onednn::Attr());
   return result;
 }
 
-Tensor mm(const Tensor& self, const Tensor& mat2) {
-  auto result = at::empty({0}, self.options());
-  xpu::mm_out(self, mat2, result);
-  return result;
-}
-
-Tensor mv(const Tensor& self, const Tensor& vec) {
-  Tensor result = at::empty({self.size(0)}, self.options());
-  return at::addmv_(result, self, vec, 0, 1);
-}
-
 // result = beta * input + alpha * (batch1 @ batch2)
 Tensor& baddbmm_out(
     const Tensor& input,
@@ -191,6 +185,8 @@ Tensor& baddbmm_out(
   TORCH_CHECK(batch1.dim() == 3, "expected 3D tensor");
   TORCH_CHECK(batch2.dim() == 3, "expected 3D tensor");
 
+  bool is_inplace = result.is_same(input);
+
   std::vector<int64_t> result_shape = {
       batch1.size(0), batch1.size(1), batch2.size(2)};
   result.resize_(result_shape);
@@ -212,11 +208,10 @@ Tensor& baddbmm_out(
       " but got:",
       input.sizes());
 
-  // complex and double case
-  if (batch1.is_complex() || batch2.scalar_type() == ScalarType::Double) {
-    TORCH_CHECK(
-        false, "Double and complex datatype matmul is not supported in oneDNN");
-  }
+  // complex case
+  TORCH_CHECK(
+      !batch1.is_complex(),
+      "Complex datatype matmul is not supported in oneDNN");
 
   // general case
   onednn::Attr attr;
@@ -228,7 +223,13 @@ Tensor& baddbmm_out(
           1.f, alpha.to<float>(), 0.f, attr.kind_with_linear);
     }
   } else {
-    binary = input.dim() < 3 ? input.unsqueeze(0) : input;
+    // We use a binary post-op here to add the input matrix.
+    // To avoid writing over our own input, clone it in the in-place case.
+    if (is_inplace)
+      binary = input.dim() < 3 ? input.unsqueeze(0).clone() : input.clone();
+    else
+      binary = input.dim() < 3 ? input.unsqueeze(0) : input;
+    // If input is a 1-D tensor that needs broadcasting, unsqueeze twice.
     binary = binary.dim() < 3 ? binary.unsqueeze_(0) : binary;
     float alpha_ = alpha.to<float>() / beta_;
     if (alpha_ != 1.f)
@@ -241,109 +242,6 @@ Tensor& baddbmm_out(
   return result;
 }
 
-Tensor& baddbmm_(
-    Tensor& self,
-    const Tensor& batch1,
-    const Tensor& batch2,
-    const Scalar& beta,
-    const Scalar& alpha) {
-  TORCH_CHECK(
-      self.dtype() == batch1.dtype(),
-      "Input dtypes must be the same, got: input ",
-      self.dtype(),
-      ", batch1: ",
-      batch1.dtype(),
-      ", batch2: ",
-      batch2.dtype());
-  return at::native::xpu::baddbmm_out(self, batch1, batch2, beta, alpha, self);
-}
-
-Tensor baddbmm(
-    const Tensor& input,
-    const Tensor& batch1,
-    const Tensor& batch2,
-    const Scalar& beta,
-    const Scalar& alpha) {
-  Tensor r = at::empty({0}, input.options());
-  TORCH_CHECK(
-      input.dtype() == batch1.dtype(),
-      "Input dtypes must be the same, got: input ",
-      input.dtype(),
-      ", batch1: ",
-      batch1.dtype(),
-      ", batch2: ",
-      batch2.dtype());
-  r = at::native::xpu::baddbmm_out(input, batch1, batch2, beta, alpha, r);
-  return r;
-}
-
-Tensor& addbmm_out(
-    const Tensor& self,
-    const Tensor& batch1,
-    const Tensor& batch2,
-    const Scalar& beta,
-    const Scalar& alpha,
-    Tensor& out) {
-  checkBackend("addbmm_out", {out, self, batch1, batch2}, Backend::XPU);
-  TORCH_CHECK(
-      batch1.dim() == 3 && batch2.dim() == 3,
-      "Batch tensors should be 3D, got dimensions ",
-      batch1.dim(),
-      " and ",
-      batch2.dim());
-  if (self.is_complex() || self.scalar_type() == ScalarType::Double) {
-    TORCH_CHECK(
-        false, "Double and complex datatype matmul is not supported in oneDNN");
-  }
-
-  out.resize_({batch1.size(1), batch2.size(2)});
-  if (alpha.to<float>() == 0.f || batch1.numel() == 0 || batch2.numel() == 0) {
-    out.resize_({batch1.size(1), batch2.size(2)});
-    if (out.numel() == 0)
-      return out;
-
-    if (self.defined() && beta.to<float>() != 0.f) {
-      out = at::mul_out(
-          out, self, at::native::wrapped_scalar_tensor(at::Scalar(beta)));
-    } else {
-      out.zero_();
-    }
-    return out;
-  }
-
-  Tensor b1;
-  if (batch1.size(0) > 1) {
-    b1 = batch1.transpose(0, 1).contiguous().view({batch1.size(1), -1});
-  } else {
-    b1 = batch1.contiguous().view({batch1.size(1), -1});
-  }
-  auto b2 = batch2.contiguous().view({-1, batch2.size(2)});
-  at::native::xpu::addmm_out(self, b1, b2, beta, alpha, out);
-
-  return out;
-}
-
-Tensor& addbmm_(
-    Tensor& self,
-    const Tensor& batch1,
-    const Tensor& batch2,
-    const Scalar& beta,
-    const Scalar& alpha) {
-  at::native::xpu::addbmm_out(self, batch1, batch2, beta, alpha, self);
-  return self;
-}
-
-Tensor addbmm(
-    const Tensor& self,
-    const Tensor& batch1,
-    const Tensor& batch2,
-    const Scalar& beta,
-    const Scalar& alpha) {
-  Tensor out = at::empty({0}, self.options());
-  at::native::xpu::addbmm_out(self, batch1, batch2, beta, alpha, out);
-  return out;
-}
-
 Tensor& bmm_out(const Tensor& self, const Tensor& batch2, Tensor& result) {
   checkBackend("bmm_out", {result, self, batch2}, Backend::XPU);
   TORCH_CHECK(self.dim() == 3, "expected 3D tensor");
@@ -356,10 +254,8 @@ Tensor& bmm_out(const Tensor& self, const Tensor& batch2, Tensor& result) {
     return result;
   }
 
-  if (self.is_complex() || self.scalar_type() == ScalarType::Double) {
-    TORCH_CHECK(
-        false, "Double and complex datatype matmul is not supported in oneDNN");
-  }
+  TORCH_CHECK(
+      !self.is_complex(), "Complex datatype matmul is not supported in oneDNN");
   onednn::matmul(result, self, batch2, at::Tensor(), true, onednn::Attr());
   return result;
 }
```
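
Why the new `is_inplace` clones: in the in-place paths (`addmm_`, `baddbmm_`) the output tensor aliases `self`/`input`, so handing that buffer to oneDNN as the binary post-op source would let the GEMM overwrite values the post-op still needs to read. A Python-level correctness check of the in-place path (illustrative sketch, assuming an XPU build):

```python
import torch

# In m.addmm_(a, b) the output buffer *is* m, which is exactly the
# aliasing case the clone guards against on the XPU backend.
m = torch.randn(4, 4, dtype=torch.float64)
a = torch.randn(4, 5, dtype=torch.float64)
b = torch.randn(5, 4, dtype=torch.float64)
ref = torch.addmm(m, a, b, beta=0.5, alpha=2.0)  # CPU reference

m_xpu = m.to("xpu")
m_xpu.addmm_(a.to("xpu"), b.to("xpu"), beta=0.5, alpha=2.0)  # in-place
torch.testing.assert_close(m_xpu.cpu(), ref)
```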

aten/src/ATen/native/mkldnn/xpu/detail/Matmul.cpp

Lines changed: 5 additions & 4 deletions

```diff
@@ -7,6 +7,7 @@
 #include <Attr.h>
 #include <Utils.h>
 
+#include <c10/core/ScalarType.h>
 #include <oneapi/dnnl/dnnl.hpp>
 
 namespace at::native::onednn {
@@ -109,9 +110,9 @@ sycl::event matmul(
     b = b.contiguous(); // avoid reorder 2 times
 
   // xpu matmul support both ab/ba shape for m2 tensor, we don't check any more
-  auto m1_usr_dt = get_onednn_dtype(m1);
-  auto m2_usr_dt = get_onednn_dtype(m2);
-  auto dst_usr_dt = get_onednn_dtype(dst);
+  auto m1_usr_dt = get_onednn_dtype_include_double(m1);
+  auto m2_usr_dt = get_onednn_dtype_include_double(m2);
+  auto dst_usr_dt = get_onednn_dtype_include_double(dst);
 
   auto m1_dt = m1_usr_dt;
   auto m2_dt = m2_usr_dt;
@@ -165,7 +166,7 @@ sycl::event matmul(
 
   if (with_bias) {
     bias_dims = get_onednn_dims(b);
-    bias_dt = get_onednn_dtype(b);
+    bias_dt = get_onednn_dtype_include_double(b);
     bias_strides = get_onednn_strides(b);
   }
```

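The dtype translation now routes through `get_onednn_dtype_include_double`, which maps `torch.float64` to oneDNN's `f64` instead of rejecting it. A rough Python sketch of the idea (the mapping and names below are illustrative, not the actual table in Utils.cpp):

```python
import torch

# Illustrative mapping only; mirrors the intent of
# get_onednn_dtype_include_double, not the real C++ implementation.
_BASE_ONEDNN_DTYPE = {
    torch.uint8: "u8",
    torch.int8: "s8",
    torch.int32: "s32",
    torch.float16: "f16",
    torch.bfloat16: "bf16",
    torch.float32: "f32",
}

def onednn_dtype_include_double(t: torch.Tensor) -> str:
    # The *_include_double variant adds the fp64 mapping instead of
    # reporting torch.float64 as unsupported ("undef").
    if t.dtype is torch.float64:
        return "f64"
    return _BASE_ONEDNN_DTYPE.get(t.dtype, "undef")
```
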
aten/src/ATen/native/mkldnn/xpu/detail/Utils.cpp

Lines changed: 6 additions & 5 deletions

```diff
@@ -97,10 +97,8 @@ dnnl::memory::data_type get_onednn_dtype_include_double(
 }
 
 bool is_supported_onednn_dtype(const at::Tensor& tensor) {
-  return get_onednn_dtype(tensor, /*allow_undef*/ true) ==
-          dnnl::memory::data_type::undef
-      ? false
-      : true;
+  return get_onednn_dtype_include_double(tensor) !=
+      dnnl::memory::data_type::undef;
 }
 
 dnnl::memory::dims get_onednn_dims(const at::Tensor& tensor) {
@@ -119,7 +117,10 @@ dnnl::memory::dims get_onednn_strides(const at::Tensor& tensor) {
 
 dnnl::memory::desc get_onednn_md(const at::Tensor& tensor) {
   Tensor t = tensor.sizes().empty() ? tensor.unsqueeze(0) : tensor;
-  return {get_onednn_dims(t), get_onednn_dtype(t), get_onednn_strides(t)};
+  return {
+      get_onednn_dims(t),
+      get_onednn_dtype_include_double(t),
+      get_onednn_strides(t)};
 }
 
 bool onednn_strides_check(const Tensor& src) {
```
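
With that helper, the simplified `is_supported_onednn_dtype` reduces to a single comparison; in the Python sketch above it would read:

```python
def is_supported_onednn_dtype(t: torch.Tensor) -> bool:
    # fp64 tensors now report as supported, since the include-double
    # variant no longer maps them to "undef".
    return onednn_dtype_include_double(t) != "undef"
```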

test/inductor/test_torchinductor_opinfo.py

Lines changed: 0 additions & 69 deletions

```diff
@@ -273,79 +273,10 @@ def format_op(op):
     "torch.ops.aten._efficient_attention_forward": {f16, f32},
     "to_sparse": {f32, f64},
     "linalg.eig": {f32, f64},
-    "linalg.eigvals": {f64},
     # Double and complex datatype matmul is not supported in oneDNN
-    "__rmatmul__": {f64},
-    ("addmm", "decomposed"): {f64},
-    "addr": {f64},
-    "baddbmm": {f64},
-    "bmm": {f64},
     "byte": {f16, f32},
-    "cdist": {f64},
-    "corrcoef": {f64},
-    "cov": {f64},
-    "einsum": {f64},
-    "inner": {f64},
-    "linalg.cholesky_ex": {f64},
-    "linalg.cholesky": {f64},
-    "linalg.ldl_factor_ex": {f64},
-    "linalg.ldl_factor": {f64},
-    "linalg.ldl_solve": {f64},
-    "linalg.matrix_power": {f64},
-    "linalg.multi_dot": {f64},
-    "matmul": {f64},
-    "mm": {f64},
-    "mv": {f64},
-    "nn.functional.bilinear": {f64},
-    "nn.functional.linear": {f64},
-    "pca_lowrank": {f64},
-    "svd_lowrank": {f64},
-    "tensordot": {f64},
-    "triangular_solve": {f64},
-    "svd": {f64},
-    "qr": {f64},
-    "pinverse": {f64},
-    "ormqr": {f64},
-    ("norm", "nuc"): {f64},
-    "lu": {f64},
-    "lu_solve": {f64},
-    "logdet": {f64},
-    "linalg.tensorsolve": {f64},
-    "linalg.tensorinv": {f64},
-    "linalg.svdvals": {f64},
-    "linalg.svd": {f64},
-    "linalg.solve": {f64},
-    "linalg.solve_triangular": {f64},
-    "linalg.solve_ex": {f64},
-    "linalg.slogdet": {f64},
-    "linalg.qr": {f64},
-    "linalg.pinv": {f64},
-    ("linalg.pinv", "hermitian"): {f64},
     ("linalg.pinv", "singular"): {f64},
-    "linalg.norm": {f64},
-    ("linalg.norm", "subgradients_at_zero"): {f64},
-    "linalg.matrix_rank": {f64},
-    ("linalg.matrix_rank", "hermitian"): {f64},
-    "linalg.matrix_norm": {f64},
-    "linalg.lu": {f64},
-    "linalg.lu_solve": {f64},
-    "linalg.lu_factor": {f64},
-    "linalg.lu_factor_ex": {f64},
-    "linalg.lstsq": {f64},
-    ("linalg.lstsq", "grad_oriented"): {f64},
-    "linalg.inv": {f64},
-    "linalg.inv_ex": {f64},
-    "linalg.householder_product": {f64},
-    "linalg.eigvalsh": {f64},
-    "linalg.eigh": {f64},
-    "linalg.det": {f64},
-    "linalg.cond": {f64},
-    "geqrf": {f64},
-    "cholesky_solve": {f64},
-    "cholesky_inverse": {f64},
     # could not create a primitive
-    "addbmm": {f64},
-    "addmm": {f64},
     "addmv": {f64},
     # could not create a primitive descriptor for
     # a deconvolution forward propagation primitive
```
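
With these skip entries gone, the inductor opinfo suite now exercises the listed ops at f64 on XPU. A quick way to hit the newly covered path outside the suite (a sketch, assuming an XPU build):

```python
import torch

@torch.compile
def f(bias, a, b):
    # One of the ops removed from the f64 skip list above.
    return torch.addmm(bias, a, b)

a = torch.randn(8, 16, device="xpu", dtype=torch.float64)
b = torch.randn(16, 8, device="xpu", dtype=torch.float64)
bias = torch.randn(8, 8, device="xpu", dtype=torch.float64)

out = f(bias, a, b)
torch.testing.assert_close(out.cpu(), torch.addmm(bias.cpu(), a.cpu(), b.cpu()))
```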
