@@ -27,7 +27,7 @@ Tensor& addmm_out(
27
27
const Tensor& mat2,
28
28
const Scalar& beta,
29
29
const Scalar& alpha,
30
- at:: Tensor& result) {
30
+ Tensor& result) {
31
31
checkBackend (" addmm_out" , {result, self, mat1, mat2}, Backend::XPU);
32
32
TORCH_CHECK (
33
33
mat1.dim () == 2 , " mat1 must be a matrix, got " , mat1.dim (), " -D tensor" );
@@ -50,11 +50,11 @@ Tensor& addmm_out(
50
50
mat1.dtype (),
51
51
" != " ,
52
52
mat2.dtype ())
53
- // complex/double case
54
- if (mat1. is_complex () || mat1. scalar_type () == ScalarType::Double) {
55
- TORCH_CHECK (
56
- false , " Double and complex datatype matmul is not supported in oneDNN " );
57
- }
53
+ // complex case
54
+ TORCH_CHECK (
55
+ !mat1. is_complex (), " Complex datatype matmul is not supported in oneDNN " );
56
+
57
+ bool is_inplace = result. is_same (self);
58
58
59
59
std::vector<int64_t > result_shape = {mat1.size (0 ), mat2.size (1 )};
60
60
result.resize_ (result_shape);
@@ -92,10 +92,17 @@ Tensor& addmm_out(
92
92
1 .f , alpha.to <float >(), 0 .f , attr.kind_with_linear );
93
93
}
94
94
} else {
95
+ // We use post_binary here for adding self matrix.
96
+ // To avoid wrong write, here we clone self for inplace case.
95
97
if (alpha.to <float >() == 1 .f && beta_ == 1 .f ) {
96
- bias = self;
98
+ bias = is_inplace ? self. clone () : self;
97
99
} else {
98
- Tensor binary = self.dim () == 1 ? self.unsqueeze (0 ) : self;
100
+ Tensor binary;
101
+ // unsqueeze(0) here is to handle mv cases.
102
+ if (is_inplace)
103
+ binary = self.dim () == 1 ? self.unsqueeze (0 ).clone () : self.clone ();
104
+ else
105
+ binary = self.dim () == 1 ? self.unsqueeze (0 ) : self;
99
106
// Tensor binary = self.expand_as(result);
100
107
// For post-binary-add, onednn needs binary scale=1.f
101
108
// Thus we need the following transformation
@@ -159,26 +166,13 @@ Tensor& mm_out(const Tensor& self, const Tensor& mat2, Tensor& result) {
159
166
return result;
160
167
}
161
168
162
- if (self.is_complex () || self.scalar_type () == ScalarType::Double) {
163
- TORCH_CHECK (
164
- false , " Double and complex datatype matmul is not supported in oneDNN" );
165
- }
169
+ TORCH_CHECK (
170
+ !self.is_complex (), " Complex datatype matmul is not supported in oneDNN" );
166
171
167
172
onednn::matmul (result, self, mat2, Tensor (), true , onednn::Attr ());
168
173
return result;
169
174
}
170
175
171
- Tensor mm (const Tensor& self, const Tensor& mat2) {
172
- auto result = at::empty ({0 }, self.options ());
173
- xpu::mm_out (self, mat2, result);
174
- return result;
175
- }
176
-
177
- Tensor mv (const Tensor& self, const Tensor& vec) {
178
- Tensor result = at::empty ({self.size (0 )}, self.options ());
179
- return at::addmv_ (result, self, vec, 0 , 1 );
180
- }
181
-
182
176
// result = beta * input + alpha * (batch1 @ batch2)
183
177
Tensor& baddbmm_out (
184
178
const Tensor& input,
@@ -191,6 +185,8 @@ Tensor& baddbmm_out(
191
185
TORCH_CHECK (batch1.dim () == 3 , " expected 3D tensor" );
192
186
TORCH_CHECK (batch2.dim () == 3 , " expected 3D tensor" );
193
187
188
+ bool is_inplace = result.is_same (input);
189
+
194
190
std::vector<int64_t > result_shape = {
195
191
batch1.size (0 ), batch1.size (1 ), batch2.size (2 )};
196
192
result.resize_ (result_shape);
@@ -212,11 +208,10 @@ Tensor& baddbmm_out(
212
208
" but got:" ,
213
209
input.sizes ());
214
210
215
- // complex and double case
216
- if (batch1.is_complex () || batch2.scalar_type () == ScalarType::Double) {
217
- TORCH_CHECK (
218
- false , " Double and complex datatype matmul is not supported in oneDNN" );
219
- }
211
+ // complex case
212
+ TORCH_CHECK (
213
+ !batch1.is_complex (),
214
+ " Complex datatype matmul is not supported in oneDNN" );
220
215
221
216
// general case
222
217
onednn::Attr attr;
@@ -228,7 +223,13 @@ Tensor& baddbmm_out(
228
223
1 .f , alpha.to <float >(), 0 .f , attr.kind_with_linear );
229
224
}
230
225
} else {
231
- binary = input.dim () < 3 ? input.unsqueeze (0 ) : input;
226
+ // We use post_binary here for adding self matrix.
227
+ // To avoid wrong write, here we clone input for inplace case.
228
+ if (is_inplace)
229
+ binary = input.dim () < 3 ? input.unsqueeze (0 ).clone () : input.clone ();
230
+ else
231
+ binary = input.dim () < 3 ? input.unsqueeze (0 ) : input;
232
+ // If input is a 1d tensor need be broadcasted, we need unsqueeze twice.
232
233
binary = binary.dim () < 3 ? binary.unsqueeze_ (0 ) : binary;
233
234
float alpha_ = alpha.to <float >() / beta_;
234
235
if (alpha_ != 1 .f )
@@ -241,109 +242,6 @@ Tensor& baddbmm_out(
241
242
return result;
242
243
}
243
244
244
- Tensor& baddbmm_ (
245
- Tensor& self,
246
- const Tensor& batch1,
247
- const Tensor& batch2,
248
- const Scalar& beta,
249
- const Scalar& alpha) {
250
- TORCH_CHECK (
251
- self.dtype () == batch1.dtype (),
252
- " Input dtypes must be the same, got: input " ,
253
- self.dtype (),
254
- " , batch1: " ,
255
- batch1.dtype (),
256
- " , batch2: " ,
257
- batch2.dtype ());
258
- return at::native::xpu::baddbmm_out (self, batch1, batch2, beta, alpha, self);
259
- }
260
-
261
- Tensor baddbmm (
262
- const Tensor& input,
263
- const Tensor& batch1,
264
- const Tensor& batch2,
265
- const Scalar& beta,
266
- const Scalar& alpha) {
267
- Tensor r = at::empty ({0 }, input.options ());
268
- TORCH_CHECK (
269
- input.dtype () == batch1.dtype (),
270
- " Input dtypes must be the same, got: input " ,
271
- input.dtype (),
272
- " , batch1: " ,
273
- batch1.dtype (),
274
- " , batch2: " ,
275
- batch2.dtype ());
276
- r = at::native::xpu::baddbmm_out (input, batch1, batch2, beta, alpha, r);
277
- return r;
278
- }
279
-
280
- Tensor& addbmm_out (
281
- const Tensor& self,
282
- const Tensor& batch1,
283
- const Tensor& batch2,
284
- const Scalar& beta,
285
- const Scalar& alpha,
286
- Tensor& out) {
287
- checkBackend (" addbmm_out" , {out, self, batch1, batch2}, Backend::XPU);
288
- TORCH_CHECK (
289
- batch1.dim () == 3 && batch2.dim () == 3 ,
290
- " Batch tensors should be 3D, got dimensions " ,
291
- batch1.dim (),
292
- " and " ,
293
- batch2.dim ());
294
- if (self.is_complex () || self.scalar_type () == ScalarType::Double) {
295
- TORCH_CHECK (
296
- false , " Double and complex datatype matmul is not supported in oneDNN" );
297
- }
298
-
299
- out.resize_ ({batch1.size (1 ), batch2.size (2 )});
300
- if (alpha.to <float >() == 0 .f || batch1.numel () == 0 || batch2.numel () == 0 ) {
301
- out.resize_ ({batch1.size (1 ), batch2.size (2 )});
302
- if (out.numel () == 0 )
303
- return out;
304
-
305
- if (self.defined () && beta.to <float >() != 0 .f ) {
306
- out = at::mul_out (
307
- out, self, at::native::wrapped_scalar_tensor (at::Scalar (beta)));
308
- } else {
309
- out.zero_ ();
310
- }
311
- return out;
312
- }
313
-
314
- Tensor b1;
315
- if (batch1.size (0 ) > 1 ) {
316
- b1 = batch1.transpose (0 , 1 ).contiguous ().view ({batch1.size (1 ), -1 });
317
- } else {
318
- b1 = batch1.contiguous ().view ({batch1.size (1 ), -1 });
319
- }
320
- auto b2 = batch2.contiguous ().view ({-1 , batch2.size (2 )});
321
- at::native::xpu::addmm_out (self, b1, b2, beta, alpha, out);
322
-
323
- return out;
324
- }
325
-
326
- Tensor& addbmm_ (
327
- Tensor& self,
328
- const Tensor& batch1,
329
- const Tensor& batch2,
330
- const Scalar& beta,
331
- const Scalar& alpha) {
332
- at::native::xpu::addbmm_out (self, batch1, batch2, beta, alpha, self);
333
- return self;
334
- }
335
-
336
- Tensor addbmm (
337
- const Tensor& self,
338
- const Tensor& batch1,
339
- const Tensor& batch2,
340
- const Scalar& beta,
341
- const Scalar& alpha) {
342
- Tensor out = at::empty ({0 }, self.options ());
343
- at::native::xpu::addbmm_out (self, batch1, batch2, beta, alpha, out);
344
- return out;
345
- }
346
-
347
245
Tensor& bmm_out (const Tensor& self, const Tensor& batch2, Tensor& result) {
348
246
checkBackend (" bmm_out" , {result, self, batch2}, Backend::XPU);
349
247
TORCH_CHECK (self.dim () == 3 , " expected 3D tensor" );
@@ -356,10 +254,8 @@ Tensor& bmm_out(const Tensor& self, const Tensor& batch2, Tensor& result) {
356
254
return result;
357
255
}
358
256
359
- if (self.is_complex () || self.scalar_type () == ScalarType::Double) {
360
- TORCH_CHECK (
361
- false , " Double and complex datatype matmul is not supported in oneDNN" );
362
- }
257
+ TORCH_CHECK (
258
+ !self.is_complex (), " Complex datatype matmul is not supported in oneDNN" );
363
259
onednn::matmul (result, self, batch2, at::Tensor (), true , onednn::Attr ());
364
260
return result;
365
261
}
0 commit comments