@@ -54,8 +54,6 @@ Tensor& addmm_out(
54
54
TORCH_CHECK (
55
55
!mat1.is_complex (), " Complex datatype matmul is not supported in oneDNN" );
56
56
57
- bool is_inplace = result.is_same (self);
58
-
59
57
std::vector<int64_t > result_shape = {mat1.size (0 ), mat2.size (1 )};
60
58
result.resize_ (result_shape);
61
59
@@ -86,34 +84,36 @@ Tensor& addmm_out(
86
84
Tensor bias = Tensor ();
87
85
onednn::Attr attr;
88
86
float beta_ = beta.to <float >();
87
+ float alpha_ = beta_ == 0 .f ? alpha.to <float >() : alpha.to <float >() / beta_;
89
88
if (beta_ == 0 .f ) {
90
- if (alpha.to <float >() != 1 .f ) {
89
+ attr.append_post_eltwise (1 .f , alpha_, 0 .f , attr.kind_with_linear );
90
+ } else if (alpha_ == 1 .f && beta_ == 1 .f && !result.is_same (self)) {
91
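+ // self can be passed as bias only when it is not the result tensor; when result and self are the same tensor, the in-place branch below uses post op sum instead.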
92
+ bias = self;
93
+ } else {
94
+ Tensor binary = self.dim () == 1 ? self.unsqueeze (0 ) : self;
95
+ bool inplace = binary.is_same (result);
96
+ if (inplace) {
91
97
attr.append_post_eltwise (
92
98
1 .f , alpha.to <float >(), 0 .f , attr.kind_with_linear );
93
- }
94
- } else {
95
- // We use post_binary here for adding self matrix.
96
- // To avoid wrong write, here we clone self for inplace case.
97
- if (alpha.to <float >() == 1 .f && beta_ == 1 .f ) {
98
- bias = is_inplace ? self.clone () : self;
99
+ attr.append_post_sum (beta_);
99
100
} else {
100
- Tensor binary;
101
- // unsqueeze(0) here is to handle mv cases.
102
- if (is_inplace)
103
- binary = self.dim () == 1 ? self.unsqueeze (0 ).clone () : self.clone ();
104
- else
105
- binary = self.dim () == 1 ? self.unsqueeze (0 ) : self;
101
+ if (at::native::onednn::is_broadcast (binary)) {
102
+ at::native::onednn::undo_broadcast (binary);
103
+ }
104
+ // in test_addmv_rowmajor_colmajor_incx_incy_lda, binary is a tensor with
105
+ // shape (5, 1) but stride(2, 2)
106
+ binary = at::native::onednn::is_onednn_matmul_strides (binary)
107
+ ? binary
108
+ : binary.contiguous ();
106
109
// Tensor binary = self.expand_as(result);
107
110
// For post-binary-add, onednn needs binary scale=1.f
108
111
// Thus we need the following transformation
109
112
// alpha * matmul(mat1, mat2) + beta * binary
110
113
// beta * (alpha/beta * matmul(src, wei) + binary)
111
- float alpha_ = alpha.to <float >() / beta_;
112
- if (alpha_ != 1 .f )
113
- attr.append_post_eltwise (1 .f , alpha_, 0 .f , attr.kind_with_linear );
114
- attr.append_post_binary (attr.kind_with_binary_add , binary);
115
- if (beta_ != 1 .f )
116
- attr.append_post_eltwise (1 .f , beta_, 0 .f , attr.kind_with_linear );
114
+ attr.append_post_eltwise (1 .f , alpha_, 0 .f , attr.kind_with_linear );
115
+ attr.append_post_binary <true >(attr.kind_with_binary_add , binary);
116
+ attr.append_post_eltwise (1 .f , beta_, 0 .f , attr.kind_with_linear );
117
117
}
118
118
}
119
119
onednn::matmul (result, mat1, mat2, bias, true , attr);
@@ -185,8 +185,6 @@ Tensor& baddbmm_out(
185
185
TORCH_CHECK (batch1.dim () == 3 , " expected 3D tensor" );
186
186
TORCH_CHECK (batch2.dim () == 3 , " expected 3D tensor" );
187
187
188
- bool is_inplace = result.is_same (input);
189
-
190
188
std::vector<int64_t > result_shape = {
191
189
batch1.size (0 ), batch1.size (1 ), batch2.size (2 )};
192
190
result.resize_ (result_shape);
@@ -216,27 +214,30 @@ Tensor& baddbmm_out(
216
214
// general case
217
215
onednn::Attr attr;
218
216
float beta_ = beta.to <float >();
217
+ float alpha_ = beta_ == 0 .f ? alpha.to <float >() : alpha.to <float >() / beta_;
219
218
Tensor binary;
220
219
if (beta_ == 0 .f ) {
221
- if (alpha.to <float >() != 1 .f ) {
222
- attr.append_post_eltwise (
223
- 1 .f , alpha.to <float >(), 0 .f , attr.kind_with_linear );
224
- }
220
+ attr.append_post_eltwise (1 .f , alpha_, 0 .f , attr.kind_with_linear );
225
221
} else {
226
- // We use post_binary here for adding self matrix.
227
- // To avoid wrong write, here we clone input for inplace case.
228
- if (is_inplace)
229
- binary = input.dim () < 3 ? input.unsqueeze (0 ).clone () : input.clone ();
230
- else
231
- binary = input.dim () < 3 ? input.unsqueeze (0 ) : input;
222
+ binary = input.dim () < 3 ? input.unsqueeze (0 ) : input;
232
223
// If input is a 1-D tensor that needs to be broadcast, we need to unsqueeze twice.
233
224
binary = binary.dim () < 3 ? binary.unsqueeze_ (0 ) : binary;
234
- float alpha_ = alpha.to <float >() / beta_;
235
- if (alpha_ != 1 .f )
225
+ bool inplace = binary.is_same (result);
226
+ if (inplace) {
227
+ attr.append_post_eltwise (
228
+ 1 .f , alpha.to <float >(), 0 .f , attr.kind_with_linear );
229
+ attr.append_post_sum (beta_);
230
+ } else {
231
+ if (at::native::onednn::is_broadcast (binary)) {
232
+ at::native::onednn::undo_broadcast (binary);
233
+ }
234
+ binary = at::native::onednn::is_onednn_matmul_strides (binary)
235
+ ? binary
236
+ : binary.contiguous ();
236
237
attr.append_post_eltwise (1 .f , alpha_, 0 .f , attr.kind_with_linear );
237
- attr.append_post_binary (attr.kind_with_binary_add , binary);
238
- if (beta_ != 1 .f )
238
+ attr.append_post_binary <true >(attr.kind_with_binary_add , binary);
239
239
attr.append_post_eltwise (1 .f , beta_, 0 .f , attr.kind_with_linear );
240
+ }
240
241
}
241
242
onednn::matmul (result, batch1, batch2, at::Tensor (), true , attr);
242
243
return result;
0 commit comments