[Intel GPU] Avoid unnecessary copy when the dst of Matmul is non-contiguous · pytorch/pytorch@91e7c79 · GitHub

Commit 91e7c79

jianyizh authored and pytorchmergebot committed
[Intel GPU] Avoid unnecessary copy when the dst of Matmul is non-contiguous (#144759)
We should not always call contiguous() on the dst of matmul; the copy of the matmul inputs was already removed in #143784. This change also fixes an accuracy issue by using the oneDNN sum post-op instead of binary add in the in-place case, which avoids a UT failure.

Pull Request resolved: #144759
Approved by: https://github.com/EikanWang
1 parent 8ee84aa commit 91e7c79
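
Why the sum post-op matters for the in-place case: when `result` aliases `self`, a binary-add post-op would re-read the output buffer after the matmul has already overwritten it. A minimal single-element sketch of the hazard (plain C++ with hypothetical variable names; this models one output element, not oneDNN itself):

```cpp
#include <cassert>

int main() {
  float beta = 1.f;
  float mm = 3.f;       // stands in for one element of matmul(mat1, mat2)
  float self_val = 7.f; // one element of self, which aliases result in-place

  // Binary add as a separate read of "self": the shared buffer is already
  // clobbered by the matmul store, so the post-op adds the wrong value.
  float buf = self_val;
  buf = mm;                            // matmul result stored into the buffer
  float binary_add = buf + beta * buf; // 6.f -- wrong

  // oneDNN's sum post-op accumulates inside the same primitive:
  // dst = matmul + beta * dst_old, with dst_old read before the store.
  float sum_post_op = mm + beta * self_val; // 10.f -- correct
  assert(binary_add != sum_post_op);
  return 0;
}
```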

File tree: 6 files changed, +78 -59 lines

aten/src/ATen/native/mkldnn/xpu/Blas.cpp

Lines changed: 38 additions & 37 deletions
```diff
@@ -54,8 +54,6 @@ Tensor& addmm_out(
   TORCH_CHECK(
       !mat1.is_complex(), "Complex datatype matmul is not supported in oneDNN");
 
-  bool is_inplace = result.is_same(self);
-
   std::vector<int64_t> result_shape = {mat1.size(0), mat2.size(1)};
   result.resize_(result_shape);
 
@@ -86,34 +84,36 @@ Tensor& addmm_out(
   Tensor bias = Tensor();
   onednn::Attr attr;
   float beta_ = beta.to<float>();
+  float alpha_ = beta_ == 0.f ? alpha.to<float>() : alpha.to<float>() / beta_;
   if (beta_ == 0.f) {
-    if (alpha.to<float>() != 1.f) {
+    attr.append_post_eltwise(1.f, alpha_, 0.f, attr.kind_with_linear);
+  } else if (alpha_ == 1.f && beta_ == 1.f && !result.is_same(self)) {
+    // if result and self are the same tensor, we use post op sum.
+    bias = self;
+  } else {
+    Tensor binary = self.dim() == 1 ? self.unsqueeze(0) : self;
+    bool inplace = binary.is_same(result);
+    if (inplace) {
       attr.append_post_eltwise(
           1.f, alpha.to<float>(), 0.f, attr.kind_with_linear);
-    }
-  } else {
-    // We use post_binary here for adding self matrix.
-    // To avoid wrong write, here we clone self for inplace case.
-    if (alpha.to<float>() == 1.f && beta_ == 1.f) {
-      bias = is_inplace ? self.clone() : self;
+      attr.append_post_sum(beta_);
     } else {
-      Tensor binary;
-      // unsqueeze(0) here is to handle mv cases.
-      if (is_inplace)
-        binary = self.dim() == 1 ? self.unsqueeze(0).clone() : self.clone();
-      else
-        binary = self.dim() == 1 ? self.unsqueeze(0) : self;
+      if (at::native::onednn::is_broadcast(binary)) {
+        at::native::onednn::undo_broadcast(binary);
+      }
+      // in test_addmv_rowmajor_colmajor_incx_incy_lda, binary is a tensor with
+      // shape (5, 1) but stride(2, 2)
+      binary = at::native::onednn::is_onednn_matmul_strides(binary)
+          ? binary
+          : binary.contiguous();
       // Tensor binary = self.expand_as(result);
       // For post-binary-add, onednn needs binary scale=1.f
       // Thus we need the following transformation
       // alpha * matmul(mat1, mat2) + beta * binary
       // beta * (alpha/beta * matmul(src, wei) + binary)
-      float alpha_ = alpha.to<float>() / beta_;
-      if (alpha_ != 1.f)
-        attr.append_post_eltwise(1.f, alpha_, 0.f, attr.kind_with_linear);
-      attr.append_post_binary(attr.kind_with_binary_add, binary);
-      if (beta_ != 1.f)
-        attr.append_post_eltwise(1.f, beta_, 0.f, attr.kind_with_linear);
+      attr.append_post_eltwise(1.f, alpha_, 0.f, attr.kind_with_linear);
+      attr.append_post_binary<true>(attr.kind_with_binary_add, binary);
+      attr.append_post_eltwise(1.f, beta_, 0.f, attr.kind_with_linear);
     }
   }
   onednn::matmul(result, mat1, mat2, bias, true, attr);
@@ -185,8 +185,6 @@ Tensor& baddbmm_out(
   TORCH_CHECK(batch1.dim() == 3, "expected 3D tensor");
   TORCH_CHECK(batch2.dim() == 3, "expected 3D tensor");
 
-  bool is_inplace = result.is_same(input);
-
   std::vector<int64_t> result_shape = {
       batch1.size(0), batch1.size(1), batch2.size(2)};
   result.resize_(result_shape);
@@ -216,27 +214,30 @@ Tensor& baddbmm_out(
   // general case
   onednn::Attr attr;
   float beta_ = beta.to<float>();
+  float alpha_ = beta_ == 0.f ? alpha.to<float>() : alpha.to<float>() / beta_;
   Tensor binary;
   if (beta_ == 0.f) {
-    if (alpha.to<float>() != 1.f) {
-      attr.append_post_eltwise(
-          1.f, alpha.to<float>(), 0.f, attr.kind_with_linear);
-    }
+    attr.append_post_eltwise(1.f, alpha_, 0.f, attr.kind_with_linear);
   } else {
-    // We use post_binary here for adding self matrix.
-    // To avoid wrong write, here we clone input for inplace case.
-    if (is_inplace)
-      binary = input.dim() < 3 ? input.unsqueeze(0).clone() : input.clone();
-    else
-      binary = input.dim() < 3 ? input.unsqueeze(0) : input;
+    binary = input.dim() < 3 ? input.unsqueeze(0) : input;
     // If input is a 1d tensor need be broadcasted, we need unsqueeze twice.
     binary = binary.dim() < 3 ? binary.unsqueeze_(0) : binary;
-    float alpha_ = alpha.to<float>() / beta_;
-    if (alpha_ != 1.f)
+    bool inplace = binary.is_same(result);
+    if (inplace) {
+      attr.append_post_eltwise(
+          1.f, alpha.to<float>(), 0.f, attr.kind_with_linear);
+      attr.append_post_sum(beta_);
+    } else {
+      if (at::native::onednn::is_broadcast(binary)) {
+        at::native::onednn::undo_broadcast(binary);
+      }
+      binary = at::native::onednn::is_onednn_matmul_strides(binary)
+          ? binary
+          : binary.contiguous();
       attr.append_post_eltwise(1.f, alpha_, 0.f, attr.kind_with_linear);
-    attr.append_post_binary(attr.kind_with_binary_add, binary);
-    if (beta_ != 1.f)
+      attr.append_post_binary<true>(attr.kind_with_binary_add, binary);
       attr.append_post_eltwise(1.f, beta_, 0.f, attr.kind_with_linear);
+    }
   }
   onednn::matmul(result, batch1, batch2, at::Tensor(), true, attr);
   return result;
```
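
The linear post-op folding in the hunks above can be sanity-checked with scalar arithmetic; this sketch only restates the identity from the diff comments, `alpha * matmul + beta * binary == beta * (alpha/beta * matmul + binary)`:

```cpp
#include <cassert>
#include <cmath>

int main() {
  float alpha = 2.5f, beta = 0.5f;
  float mm = 3.0f;  // one element of matmul(mat1, mat2)
  float bin = 7.0f; // one element of the binary (self) operand

  float direct = alpha * mm + beta * bin;
  float alpha_ = alpha / beta;              // folded into kind_with_linear
  float fused = beta * (alpha_ * mm + bin); // linear -> binary_add -> linear
  assert(std::fabs(direct - fused) < 1e-5f);
  return 0;
}
```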

aten/src/ATen/native/mkldnn/xpu/detail/Attr.h

Lines changed: 11 additions & 3 deletions
```diff
@@ -193,18 +193,26 @@ class Attr {
   }
 
   // append binary post op
+  template <bool is_matmul = false>
   Attr& append_post_binary(dnnl::algorithm algo, const at::Tensor& binary) {
     auto binary_ = binary.is_quantized() ? at::dequantize(binary) : binary;
     bool binary_is_channels_last =
         (binary_.suggest_memory_format() == at::MemoryFormat::ChannelsLast ||
          binary_.suggest_memory_format() == at::MemoryFormat::ChannelsLast3d);
 
-    binary_ = binary_is_channels_last ? binary_ : binary_.contiguous();
+    if constexpr (!is_matmul) {
+      binary_ = binary_is_channels_last ? binary_ : binary_.contiguous();
+    }
     dnnl::memory::desc md = get_onednn_md(binary_);
     auto expected_md = dnnl::memory::desc(
         md.get_dims(), md.get_data_type(), dnnl::memory::format_tag::any);
-    ops_params_.push_back(
-        PostOpParam(binary_, md, expected_md, algo, kind_t::binary));
+    if constexpr (is_matmul) {
+      ops_params_.push_back(PostOpParam(binary_, md, md, algo, kind_t::binary));
+    } else {
+      ops_params_.push_back(
+          PostOpParam(binary_, md, expected_md, algo, kind_t::binary));
+    }
+
     return *this;
   }
 
```
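The effect of the new `is_matmul` branch is that the binary operand's real (possibly strided) descriptor is passed through instead of a `format_tag::any` placeholder, so oneDNN is not given license to pick a layout that would force a copy. A standalone sketch of the two descriptor kinds (assumes oneDNN headers are available; illustrative, not the PyTorch build):

```cpp
#include <oneapi/dnnl/dnnl.hpp>

int main() {
  using namespace dnnl;
  memory::dims dims = {5, 1};

  // Strided descriptor: records the tensor's actual strides, e.g. the
  // shape-(5, 1)/stride-(2, 2) case mentioned in Blas.cpp above.
  memory::desc real_md(dims, memory::data_type::f32, memory::dims{2, 2});

  // "any" descriptor: leaves the layout to oneDNN (the expected_md used on
  // the non-matmul path), which may not match the user-visible strides.
  memory::desc any_md(dims, memory::data_type::f32, memory::format_tag::any);
  return 0;
}
```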
aten/src/ATen/native/mkldnn/xpu/detail/Matmul.cpp

Lines changed: 1 addition & 1 deletion
```diff
@@ -42,7 +42,7 @@ sycl::event matmul(
   m1 = is_onednn_matmul_strides(m1) ? m1 : m1.contiguous();
   m2 = is_onednn_matmul_strides(m2) ? m2 : m2.contiguous();
   at::Tensor dst =
-      is_onednn_matmul_strides(result, true) ? result : result.contiguous();
+      is_onednn_matmul_strides(result) ? result : result.contiguous();
 
   int64_t m = dst.size(-2);
   int64_t n = dst.size(-1);
```
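
With `is_onednn_matmul_strides` now applied to the destination as well, a transposed `out` tensor whose strides oneDNN can express is written directly rather than through a `contiguous()` round-trip. A hedged usage sketch (assumes a PyTorch build with XPU support; the copy behavior, not the numerical result, is what changes):

```cpp
#include <ATen/ATen.h>

int main() {
  auto opts = at::TensorOptions().device(at::kXPU);
  at::Tensor a = at::randn({64, 32}, opts);
  at::Tensor b = at::randn({32, 48}, opts);

  // Column-major destination: sizes (64, 48), strides (1, 64).
  at::Tensor out = at::empty({48, 64}, opts).t();

  // Previously the kernel always materialized a contiguous dst and copied
  // back; now the non-contiguous dst is handed to oneDNN directly.
  at::mm_out(out, a, b);
  return 0;
}
```

The identical one-line change in QMatmul.cpp below applies the same reasoning to the quantized path.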

aten/src/ATen/native/mkldnn/xpu/detail/QMatmul.cpp

Lines changed: 1 addition & 1 deletion
```diff
@@ -132,7 +132,7 @@ void quantized_matmul(
   at::Tensor m1 = is_onednn_matmul_strides(mat1) ? mat1 : mat1.contiguous();
   at::Tensor m2 = is_onednn_matmul_strides(mat2) ? mat2 : mat2.contiguous();
   at::Tensor dst =
-      is_onednn_matmul_strides(result, true) ? result : result.contiguous();
+      is_onednn_matmul_strides(result) ? result : result.contiguous();
 
   int64_t m = dst.size(-2);
   int64_t n = dst.size(-1);
```

aten/src/ATen/native/mkldnn/xpu/detail/Utils.cpp

Lines changed: 25 additions & 16 deletions
```diff
@@ -257,14 +257,30 @@ void undo_broadcast_on_batch(at::Tensor& m1, at::Tensor& m2) {
             {tensor.stride(dim_m), tensor.stride(dim_n)})
             .unsqueeze(dim_b);
   }
+}
+
+void undo_broadcast(at::Tensor& tensor) {
+  // pytorch use stride = 0 for the dim to be broadcasted, but oneDNN only
+  // support shape(dim) = 1 to implicitly indicate the broadcast dim.
+  std::vector<int64_t> new_shape;
+  std::vector<int64_t> new_strides;
+  std::vector<int64_t> unsqueeze_dims;
+  for (int i = 0; i < tensor.dim(); i++) {
+    if (tensor.stride(i) == 0) {
+      unsqueeze_dims.push_back(i);
+    } else {
+      new_shape.push_back(tensor.size(i));
+      new_strides.push_back(tensor.stride(i));
+    }
+  }
+  tensor = tensor.as_strided(new_shape, new_strides);
+  for (size_t i = 0; i < unsqueeze_dims.size(); i++) {
+    tensor = tensor.unsqueeze(unsqueeze_dims[i]);
+  }
   return;
 }
 
-bool is_onednn_matmul_strides(const at::Tensor& tensor, bool is_dst) {
-  // TODO: We always call contiguous on dst.
-  // delete it after fix the case that dst is transposed on batch and m dim.
-  if (is_dst)
-    return false;
+bool is_onednn_matmul_strides(const at::Tensor& tensor) {
   // https://oneapi-src.github.io/oneDNN/dev_guide_matmul.html
   // oneDNN matmul only support 2-dim and 3-dim
   // 2D src(Mxk), wei(KxN), dst(MxN)
@@ -289,17 +305,10 @@ bool is_onednn_matmul_strides(const at::Tensor& tensor, bool is_dst) {
   if (is_broadcast(tensor)) {
     return false;
   }
-  if (is_dst) {
-    // The memory format of the destination tensor should always be plain
-    // with n axis contiguous
-    if (strides[tensor_dim - 1] != 1)
-      return false;
-  } else {
-    // the src and weight must have at least one of the axes
-    // m or k and n or k contiguous (i.e., stride=1) respectively.
-    if (strides[tensor_dim - 1] != 1 && strides[tensor_dim - 2] != 1)
-      return false;
-  }
+  // the src and weight must have at least one of the axes
+  // m or k and n or k contiguous (i.e., stride=1) respectively.
+  if (strides[tensor_dim - 1] != 1 && strides[tensor_dim - 2] != 1)
+    return false;
 
   if (!onednn_strides_check(tensor))
     return false;
```
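
`undo_broadcast` rewrites PyTorch's stride-0 broadcast dims into size-1 dims, which is the form oneDNN understands. A small CPU-side illustration of the same transformation (standalone sketch, not the library code path):

```cpp
#include <ATen/ATen.h>
#include <iostream>

int main() {
  // expand() marks the broadcast dim with stride 0: sizes (5, 4), strides (1, 0).
  at::Tensor t = at::randn({5, 1}).expand({5, 4});

  // undo_broadcast drops the stride-0 dim, then re-inserts it with size 1,
  // yielding well-defined strides that oneDNN can broadcast implicitly.
  at::Tensor u = t.as_strided({5}, {1}).unsqueeze(1);

  std::cout << u.sizes() << " / " << u.strides() << std::endl; // [5, 1] / [1, 1]
  return 0;
}
```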

aten/src/ATen/native/mkldnn/xpu/detail/Utils.h

Lines changed: 2 additions & 1 deletion
```diff
@@ -42,8 +42,9 @@ dnnl::memory::desc get_onednn_md(const at::Tensor& tensor);
 bool onednn_strides_check(const at::Tensor& src);
 bool is_broadcast(const at::Tensor& t);
 void undo_broadcast_on_batch(at::Tensor& m1, at::Tensor& m2);
+void undo_broadcast(at::Tensor& tensor);
 
-bool is_onednn_matmul_strides(const at::Tensor& tensor, bool is_dst = false);
+bool is_onednn_matmul_strides(const at::Tensor& tensor);
 
 bool is_broadcast_from_other_to_self(
     const at::Tensor& self,
```

0 commit comments