Revert "Add torch._scaled_mm for CPU (#139975)" · pytorch/pytorch@babb2dc · GitHub

Commit babb2dc

Revert "Add torch._scaled_mm for CPU (#139975)"
This reverts commit 6f7e67c. Reverted #139975 on behalf of https://github.com/wdvr due to failing inductor mkldnn_pattern_matcher_cpu tests ([comment](#139975 (comment)))
1 parent 525ca80 commit babb2dc
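
For reference, the reverted change wired the existing `torch._scaled_mm` operator (FP8 matrix multiply with per-tensor scales) up to a CPU backend; the operator keeps the schema shown in the native_functions.yaml hunk further down and continues to dispatch on CUDA. The following usage sketch is for orientation only — it assumes a CUDA device with FP8 support, and the shapes, scale values, and choice of `torch.float8_e4m3fn` are illustrative, not taken from this commit:

```python
import torch

# Schema (see the native_functions.yaml hunk below):
# _scaled_mm(Tensor self, Tensor mat2, Tensor scale_a, Tensor scale_b,
#            Tensor? bias=None, Tensor? scale_result=None,
#            ScalarType? out_dtype=None, bool use_fast_accum=False) -> Tensor
device = "cuda"  # after this revert only the CUDA kernels remain registered
a = torch.randn(16, 32, device=device).to(torch.float8_e4m3fn)
# The CUDA FP8 kernel generally requires mat2 in column-major layout, hence the transpose.
b = torch.randn(8, 32, device=device).to(torch.float8_e4m3fn).t()
scale_a = torch.tensor(1.0, device=device)  # per-tensor scales as fp32 scalar tensors
scale_b = torch.tensor(1.0, device=device)
out = torch._scaled_mm(a, b, scale_a, scale_b, out_dtype=torch.bfloat16)
print(out.shape)  # torch.Size([16, 8])
```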

File tree

12 files changed (+586, −922 lines)


aten/src/ATen/native/Blas.cpp

Lines changed: 0 additions & 86 deletions
@@ -7,11 +7,6 @@
 #include <ATen/Config.h>
 
 #include <ATen/native/mkldnn/Matmul.h>
-#include <ATen/native/mkldnn/Linear.h>
-#include <ATen/native/Resize.h>
-#if !defined(__s390x__) && !defined(__powerpc__)
-#include <cpuinfo.h>
-#endif
 
 #ifndef AT_PER_OPERATOR_HEADERS
 #include <ATen/CPUFunctions.h>
@@ -29,12 +24,6 @@
 #include <ATen/ops/mv_native.h>
 #include <ATen/ops/scalar_tensor_native.h>
 #include <ATen/ops/vdot_native.h>
-#include <ATen/ops/_scaled_mm_native.h>
-#include <ATen/ops/mul.h>
-#include <ATen/ops/matmul.h>
-#endif
-#if AT_MKLDNN_ENABLED()
-#include <ideep.hpp>
 #endif
 
 namespace at::meta {
@@ -233,79 +222,4 @@ Tensor vdot(const Tensor &self, const Tensor &other){
 
 }
 
-static Tensor&
-_scaled_mm_out_cpu_emulated(const Tensor& mat1, const Tensor& mat2,
-          const Tensor& scale_a,
-          const Tensor& scale_b,
-          const std::optional<at::Tensor>& bias,
-          const std::optional<at::Tensor>& scale_result,
-          std::optional<c10::ScalarType> out_dtype,
-          bool use_fast_accum,
-          Tensor& out) {
-  TORCH_CHECK(mat1.dim() == 2, "mat1 must be a matrix");
-  TORCH_CHECK(mat2.dim() == 2, "mat2 must be a matrix");
-  TORCH_CHECK(
-      mat1.sizes()[1] == mat2.sizes()[0], "mat1 and mat2 shapes cannot be multiplied (",
-      mat1.sizes()[0], "x", mat1.sizes()[1], " and ", mat2.sizes()[0], "x", mat2.sizes()[1], ")");
-
-  TORCH_INTERNAL_ASSERT((scale_a.numel() == 1 && scale_b.numel() == 1), "Now _scaled_mm only supports per-tensor scaling for CPU backend.");
-  TORCH_CHECK(!bias || bias->numel() == mat2.sizes()[1], "Bias must be size ", mat2.sizes()[1],
-       " but got ", bias->numel());
-
-  // Check types
-  TORCH_CHECK(!out_dtype || *out_dtype == out.scalar_type(), "out_dtype must match output matrix type");
-  TORCH_CHECK(isFloat8Type(mat1.scalar_type()), "Expected mat1 to be Float8 matrix got ", mat1.scalar_type());
-  TORCH_CHECK(isFloat8Type(mat2.scalar_type()), "Expected mat2 to be Float8 matrix got ", mat2.scalar_type());
-
-  auto mat1_c = mat1.contiguous();
-  auto mat2_c = mat2.contiguous();
-  IntArrayRef mat1_sizes = mat1_c.sizes();
-  IntArrayRef mat2_sizes = mat2_c.sizes();
-  at::native::resize_output(out, {mat1_sizes[0], mat2_sizes[1]});
-
-  float input_scale = scale_a.item<float>();
-  float weight_scale = scale_b.item<float>();
-  auto fp32_mat1 = at::mul(mat1.to(kFloat), input_scale);
-  auto fp32_mat2 = at::mul(mat2_c.to(kFloat), weight_scale);
-  auto out_tmp = at::matmul(fp32_mat1, fp32_mat2);
-  if (bias) {
-    out_tmp.add_(bias.value());
-  }
-  out_tmp = out_tmp.to(out.scalar_type());
-  out.copy_(out_tmp);
-  return out;
-}
-
-Tensor&
-_scaled_mm_out_cpu(const Tensor& mat1, const Tensor& mat2,
-          const Tensor& scale_a,
-          const Tensor& scale_b,
-          const std::optional<at::Tensor>& bias,
-          const std::optional<at::Tensor>& scale_result,
-          std::optional<c10::ScalarType> out_dtype,
-          bool use_fast_accum,
-          Tensor& out) {
-#if AT_MKLDNN_ENABLED() && !(IDEEP_VERSION_MAJOR <= 2 || (IDEEP_VERSION_MAJOR == 3 && IDEEP_VERSION_MINOR < 5))
-  if (at::globalContext().userEnabledMkldnn() && cpuinfo_has_x86_amx_int8()) {
-    return mkldnn_scaled_mm(mat1, mat2, scale_a, scale_b, bias, scale_result, out_dtype, use_fast_accum, out);
-  } else
-#endif
-  {
-    return _scaled_mm_out_cpu_emulated(mat1, mat2, scale_a, scale_b, bias, scale_result, out_dtype, use_fast_accum, out);
-  }
-}
-
-Tensor
-_scaled_mm_cpu(const Tensor& mat_a, const Tensor& mat_b,
-          const Tensor& scale_a,
-          const Tensor& scale_b,
-          const std::optional<at::Tensor>& bias,
-          const std::optional<at::Tensor>& scale_result,
-          std::optional<c10::ScalarType> out_dtype,
-          bool use_fast_accum) {
-  const auto out_dtype_ = out_dtype.value_or(mat_a.scalar_type());
-  Tensor out = at::empty({0}, mat_a.options().dtype(out_dtype_));
-  return _scaled_mm_out_cpu(mat_a, mat_b, scale_a, scale_b, bias, scale_result, out_dtype, use_fast_accum, out);
-}
-
 } // namespace at::native
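
For context on the first removed function: `_scaled_mm_out_cpu_emulated` simply dequantized both FP8 operands with their per-tensor scales and fell back to an ordinary FP32 matmul. A rough Python equivalent of that math follows — a sketch only, assuming a PyTorch build with FP8 dtypes; the function name, the defaulting of `out_dtype` to `mat1`'s dtype, and the example inputs mirror the removed C++ but are not part of this commit:

```python
import torch

def scaled_mm_emulated(mat1, mat2, scale_a, scale_b, bias=None, out_dtype=None):
    # Mirrors the removed CPU fallback: dequantize, FP32 matmul, optional bias, cast.
    out_dtype = out_dtype if out_dtype is not None else mat1.dtype
    fp32_mat1 = mat1.to(torch.float32) * scale_a.item()   # per-tensor scale for mat1
    fp32_mat2 = mat2.to(torch.float32) * scale_b.item()   # per-tensor scale for mat2
    out = fp32_mat1 @ fp32_mat2
    if bias is not None:
        out = out + bias
    return out.to(out_dtype)

# Illustrative call with per-tensor scales (values and shapes are made up):
a = torch.randn(16, 32).to(torch.float8_e4m3fn)
b = torch.randn(32, 8).to(torch.float8_e4m3fn)
res = scaled_mm_emulated(a, b, torch.tensor(0.5), torch.tensor(0.25),
                         out_dtype=torch.bfloat16)
```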

aten/src/ATen/native/mkldnn/Linear.cpp

Lines changed: 1 addition & 126 deletions
@@ -4,7 +4,6 @@
 #include <ATen/core/Tensor.h>
 #include <torch/library.h>
 #include <ATen/native/mkldnn/Linear.h>
-#include <ATen/native/Resize.h>
 
 #ifndef AT_PER_OPERATOR_HEADERS
 #include <ATen/Functions.h>
@@ -47,20 +46,9 @@ std::tuple<Tensor, Tensor, Tensor> mkldnn_linear_backward(
   TORCH_CHECK(false, "mkldnn_linear_backward: ATen not compiled with MKLDNN support");
 }
 
-Tensor&
-mkldnn_scaled_mm(const Tensor& mat1, const Tensor& mat2,
-          const Tensor& scale_a,
-          const Tensor& scale_b,
-          const std::optional<at::Tensor>& bias,
-          const std::optional<at::Tensor>& scale_result,
-          std::optional<c10::ScalarType> out_dtype,
-          bool use_fast_accum,
-          Tensor& out) {
-  TORCH_INTERNAL_ASSERT(false, "mkldnn_scaled_mm: ATen not compiled with MKLDNN support");
-}
-
 } // namespace at::native
 
+
 #else // AT_MKLDNN_ENABLED
 
 #include <ATen/native/mkldnn/MKLDNNCommon.h>
@@ -459,119 +447,6 @@ TORCH_LIBRARY_IMPL(mkldnn, MkldnnCPU, m) {
       TORCH_FN(mkldnn_linear_pointwise_binary));
 }
 
-Tensor&
-mkldnn_scaled_mm(const Tensor& mat1, const Tensor& mat2,
-          const Tensor& scale_a,
-          const Tensor& scale_b,
-          const std::optional<at::Tensor>& bias,
-          const std::optional<at::Tensor>& scale_result,
-          std::optional<c10::ScalarType> out_dtype,
-          bool use_fast_accum,
-          Tensor& out) {
-  TORCH_CHECK(mat1.dim() == 2, "mat1 must be a matrix");
-  TORCH_CHECK(mat2.dim() == 2, "mat2 must be a matrix");
-  TORCH_CHECK(
-      mat1.sizes()[1] == mat2.sizes()[0], "mat1 and mat2 shapes cannot be multiplied (",
-      mat1.sizes()[0], "x", mat1.sizes()[1], " and ", mat2.sizes()[0], "x", mat2.sizes()[1], ")");
-
-  TORCH_INTERNAL_ASSERT((scale_a.numel() == 1 && scale_b.numel() == 1), "Now _scaled_mm only supports per-tensor scaling for CPU backend.");
-  TORCH_CHECK(!bias || bias->numel() == mat2.sizes()[1], "Bias must be size ", mat2.sizes()[1],
-       " but got ", bias->numel());
-
-  // Check types
-  TORCH_CHECK(!out_dtype || *out_dtype == out.scalar_type(), "out_dtype must match output matrix type");
-  TORCH_CHECK(isFloat8Type(mat1.scalar_type()), "Expected mat1 to be Float8 matrix got ", mat1.scalar_type());
-  TORCH_CHECK(isFloat8Type(mat2.scalar_type()), "Expected mat2 to be Float8 matrix got ", mat2.scalar_type());
-  // TODO: This check of mat1 and mat2 must have the same data type will be removed after oneDNN v3.6.
-  TORCH_CHECK(mat1.scalar_type() == mat2.scalar_type(), "Expected mat1 and mat2 must have the same data type");
-
-  // Validation checks have passed lets resize the output to actual size
-  auto mat1_c = mat1.contiguous();
-  auto mat2_c = mat2.contiguous();
-  IntArrayRef mat1_sizes = mat1_c.sizes();
-  IntArrayRef mat2_sizes = mat2_c.sizes();
-  at::native::resize_output(out, {mat1_sizes[0], mat2_sizes[1]});
-
-  float input_scale = scale_a.item<float>();
-  float weight_scale = scale_b.item<float>();
-  auto src = at::native::itensor_view_from_dense(mat1_c);
-  auto weight_t = at::native::itensor_view_from_dense(mat2_c);
-  bool with_bias = bias.has_value();
-  int64_t K = mat1_sizes[1], M = mat1_sizes[0],
-      N = mat2_sizes[1];
-
-  std::vector<int64_t> src_dims = {M, K};
-  std::vector<int64_t> weight_dims = {K, N};
-  std::vector<int64_t> dst_dims = {M, N};
-
-  ideep::tensor dst = at::native::itensor_view_from_dense(out);
-  auto src_desc = ideep::tensor::desc(
-      src_dims,
-      get_mkldnn_dtype(mat1.scalar_type()),
-      ideep::format_tag::any);
-  auto weights_desc = ideep::tensor::desc(
-      weight_dims,
-      get_mkldnn_dtype(mat2.scalar_type()),
-      ideep::format_tag::any);
-  auto dst_desc = ideep::tensor::desc(
-      dst_dims,
-      get_mkldnn_dtype(out.scalar_type()),
-      ideep::format_tag::any);
-  ideep::tensor onednn_bias;
-  if (with_bias) {
-    auto bias_value = bias.value();
-    if (bias_value.dim() == 1) {
-      auto b_reshape = bias_value.reshape({1, bias_value.size(0)});
-      onednn_bias = at::native::itensor_view_from_dense(b_reshape);
-    } else {
-      onednn_bias = at::native::itensor_view_from_dense(bias_value);
-    }
-  }
-  auto bias_desc = ideep::tensor::desc();
-  if (with_bias) {
-    bias_desc = ideep::tensor::desc(onednn_bias.get_dims(),
-        get_mkldnn_dtype(bias.value().scalar_type()),
-        ideep::format_tag::any);
-  }
-  auto op_attr = ideep::attr_t();
-  if (input_scale != 1.0f) {
-    op_attr.set_scales_mask(DNNL_ARG_SRC, 0);
-  }
-  if (weight_scale != 1.0f) {
-    op_attr.set_scales_mask(DNNL_ARG_WEIGHTS, 0);
-  }
-
-  op_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
-  auto engine = ideep::engine::cpu_engine();
-  dnnl::matmul::primitive_desc primitive_desc = with_bias
-      ? dnnl::matmul::primitive_desc(
-            engine, src_desc, weights_desc, bias_desc, dst_desc, op_attr)
-      : dnnl::matmul::primitive_desc(
-            engine, src_desc, weights_desc, dst_desc, op_attr);
-  auto primitive = dnnl::matmul(primitive_desc);
-
-  // Prepare args and execute primitive
-  ideep::tensor scratchpad(primitive_desc.scratchpad_desc());
-  ideep::exec_args args;
-  args.insert({DNNL_ARG_SRC, src});
-  args.insert({DNNL_ARG_WEIGHTS, weight_t});
-  args.insert({DNNL_ARG_DST, dst});
-  args.insert({DNNL_ARG_SCRATCHPAD, scratchpad});
-  if (with_bias) {
-    args.insert({DNNL_ARG_BIAS, onednn_bias});
-  }
-  ideep::tensor src_scales_t = ideep::tensor(ideep::scale_t(1, input_scale));
-  ideep::tensor wei_scales_t = ideep::tensor(ideep::scale_t(1, weight_scale));
-
-  if (input_scale != 1.0f) {
-    args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, src_scales_t});
-  }
-  args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, wei_scales_t});
-
-  primitive.execute(ideep::stream::default_stream(), args);
-  return out;
-}
-
 } // namespace at
 
 #endif // AT_MKLDNN_ENABLED

aten/src/ATen/native/mkldnn/Linear.h

Lines changed: 0 additions & 12 deletions
@@ -35,15 +35,3 @@ C10_API Tensor mkl_linear(
 } // namespace at
 
 #endif // AT_MKLDNN_ENABLED()
-
-namespace at::native {
-Tensor&
-mkldnn_scaled_mm(const Tensor& mat1, const Tensor& mat2,
-    const Tensor& scale_a,
-    const Tensor& scale_b,
-    const std::optional<at::Tensor>& bias,
-    const std::optional<at::Tensor>& scale_result,
-    std::optional<c10::ScalarType> out_dtype,
-    bool use_fast_accum,
-    Tensor& out);
-} // namespace at::native

aten/src/ATen/native/mkldnn/MKLDNNCommon.cpp

Lines changed: 1 addition & 25 deletions
@@ -54,15 +54,9 @@ ideep::tensor::data_type get_mkldnn_dtype(ScalarType type) {
     case ScalarType::Byte:
       return ideep::tensor::data_type::u8;
     case ScalarType::BFloat16:
-#if !(IDEEP_VERSION_MAJOR <= 2 || (IDEEP_VERSION_MAJOR == 3 && IDEEP_VERSION_MINOR < 5))
      return ideep::tensor::data_type::bf16;
     case ScalarType::Half:
       return ideep::tensor::data_type::f16;
-    case ScalarType::Float8_e4m3fn:
-      return ideep::tensor::data_type::f8_e4m3;
-    case ScalarType::Float8_e5m2:
-      return ideep::tensor::data_type::f8_e5m2;
-#endif
     default:
       TORCH_CHECK(false, "get_mkldnn_dtype: unsupported data type");
   }
@@ -167,26 +161,8 @@ ideep::tensor itensor_view_from_dense(const Tensor& tensor, bool from_const_data
               const_cast<void*>(tensor.const_data_ptr()) :
               tensor.data_ptr()};
   }
-#if !(IDEEP_VERSION_MAJOR <= 2 || (IDEEP_VERSION_MAJOR == 3 && IDEEP_VERSION_MINOR < 5))
-  else if (tensor.scalar_type() == ScalarType::Float8_e4m3fn) {
-    return {{tensor.sizes().vec(),
-             ideep::tensor::data_type::f8_e4m3,
-             tensor.strides().vec()},
-            from_const_data_ptr ?
-              const_cast<void*>(tensor.const_data_ptr()) :
-              tensor.data_ptr()};
-  }
-  else if (tensor.scalar_type() == ScalarType::Float8_e5m2) {
-    return {{tensor.sizes().vec(),
-             ideep::tensor::data_type::f8_e5m2,
-             tensor.strides().vec()},
-            from_const_data_ptr ?
-              const_cast<void*>(tensor.const_data_ptr()) :
-              tensor.data_ptr()};
-  }
-#endif
   else {
-    TORCH_CHECK(false, "itensor_view_from_dense expects float/bfloat16/half/int8/fp8 tensor input");
+    TORCH_CHECK(false, "itensor_view_from_dense expects float/bfloat16/half/int8 tensor input");
   }
 }

aten/src/ATen/native/native_functions.yaml

Lines changed: 0 additions & 2 deletions
@@ -7066,13 +7066,11 @@
 - func: _scaled_mm(Tensor self, Tensor mat2, Tensor scale_a, Tensor scale_b, Tensor? bias=None, Tensor? scale_result=None, ScalarType? out_dtype=None, bool use_fast_accum=False) -> Tensor
   variants: function
   dispatch:
-    CPU: _scaled_mm_cpu
     CUDA: _scaled_mm_cuda
 
 - func: _scaled_mm.out(Tensor self, Tensor mat2, Tensor scale_a, Tensor scale_b, Tensor? bias=None, Tensor? scale_result=None, ScalarType? out_dtype=None, bool use_fast_accum=False, *, Tensor(a!) out) -> Tensor(a!)
   variants: function
   dispatch:
-    CPU: _scaled_mm_out_cpu
     CUDA: _scaled_mm_out_cuda
 
 # NOTE [ Sparse: autograd and API ]
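
With the two CPU dispatch entries removed, calling the operator on CPU tensors should once again fail at dispatch time rather than run the reverted kernels. A hedged illustration of what that looks like from Python (the exact exception type and message are assumptions, not taken from this commit):

```python
import torch

a = torch.randn(16, 32).to(torch.float8_e4m3fn)  # CPU tensors
b = torch.randn(32, 8).to(torch.float8_e4m3fn)
scale = torch.tensor(1.0)

try:
    torch._scaled_mm(a, b, scale, scale)
except (NotImplementedError, RuntimeError) as err:
    # Expected after this revert: no CPU kernel is registered for aten::_scaled_mm.
    print(type(err).__name__, err)
```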
