Enable TF32 as fp32 internal precision for matmul · pytorch/pytorch@e3f1ea3 · GitHub
Commit e3f1ea3

Enable TF32 as fp32 internal precision for matmul
Enable TF32 as fp32 internal precision for Linear
Enable TF32 as fp32 internal precision for conv

ghstack-source-id: 5365a8c
Pull Request resolved: #157520
1 parent 1ea9cde commit e3f1ea3
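
In user terms, this commit makes the "high" float32 matmul precision setting request TF32 as the fp32 internal precision on the mkldnn (CPU) backend as well as on CUDA. A minimal C++ sketch of the observable effect, assuming a libtorch build that already contains this commit (only the string comparison that the diff itself uses is relied on):

// Sketch: query the per-backend fp32 internal precision after asking for
// "high" matmul precision. With this commit, mkldnn now reports "tf32".
#include <ATen/Context.h>
#include <iostream>

int main() {
  at::globalContext().setFloat32MatmulPrecision("high");

  const bool mkldnn_tf32 =
      at::globalContext().float32Precision("mkldnn", "matmul") == "tf32";
  const bool cuda_tf32 =
      at::globalContext().float32Precision("cuda", "matmul") == "tf32";

  std::cout << std::boolalpha
            << "mkldnn matmul uses tf32 internally: " << mkldnn_tf32 << "\n"
            << "cuda matmul uses tf32 internally:   " << cuda_tf32 << "\n";
  return 0;
}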

File tree

14 files changed: +266 -141 lines

aten/src/ATen/Context.cpp

Lines changed: 6 additions & 3 deletions

@@ -27,7 +27,7 @@ namespace {
   These const variables defined the fp32 precisions for different backend
   We have "generic", "cuda", "mkldnn" backend now and we can choose fp32
   prevision from "ieee", "tf32", "bf16" and "none". The "ieee" precision means
-  IEEE standard floating point format "tf32" and "bf16" means we are allowed to
+  IEEE standard floating point format, "tf32" and "bf16" means we are allowed to
   use "tf32" or "bf16" as internal computation data types for fp32 computations.
   And "none" means it is override-able by parent's node
@@ -40,7 +40,7 @@ namespace {
 */
 const std::map<std::string, std::vector<std::string>> _fp32_precisions = {
     {"generic", {{"ieee", "tf32", "bf16", "none"}}},
-    {"mkldnn", {{"ieee", "bf16", "none"}}},
+    {"mkldnn", {{"ieee", "tf32", "bf16", "none"}}},
     {"cuda", {{"ieee", "tf32", "none"}}}};

 // Check whether the backend and op are legal
@@ -368,6 +368,9 @@ Float32MatmulPrecision Context::float32MatmulPrecision() const {
   invalid = invalid ||
       (float32Precision("mkldnn", "matmul") == "bf16" &&
        float32_matmul_precision != at::Float32MatmulPrecision::MEDIUM);
+  invalid = invalid ||
+      (float32Precision("mkldnn", "matmul") == "tf32" &&
+       float32_matmul_precision != at::Float32MatmulPrecision::HIGH);
   TORCH_CHECK(
       !invalid,
       "PyTorch is checking the matmul precision without a specific backend name,",
@@ -401,7 +404,7 @@ void Context::setFloat32MatmulPrecision(const std::string &s) {
   } else if (s_ == "high") {
     float32_matmul_precision = at::Float32MatmulPrecision::HIGH;
     setFloat32Precision("cuda", "matmul", "tf32");
-    setFloat32Precision("mkldnn", "matmul", "ieee");
+    setFloat32Precision("mkldnn", "matmul", "tf32");
     return true;
   } else if (s_ == "medium") {
     float32_matmul_precision = at::Float32MatmulPrecision::MEDIUM;
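
The _fp32_precisions table drives the legality check referenced by the "Check whether the backend and op are legal" comment: a precision can only be set for a backend if it appears in that backend's list. A simplified standalone illustration of that idea (not the actual ATen implementation, just the mechanism):

// Standalone sketch: each backend advertises the fp32 internal precisions it
// accepts; a setter is only legal if the requested precision is in the list.
#include <algorithm>
#include <map>
#include <stdexcept>
#include <string>
#include <vector>

static const std::map<std::string, std::vector<std::string>> kFp32Precisions = {
    {"generic", {"ieee", "tf32", "bf16", "none"}},
    {"mkldnn", {"ieee", "tf32", "bf16", "none"}},  // "tf32" newly allowed by this commit
    {"cuda", {"ieee", "tf32", "none"}}};

void check_fp32_precision(const std::string& backend, const std::string& precision) {
  auto it = kFp32Precisions.find(backend);
  if (it == kFp32Precisions.end() ||
      std::find(it->second.begin(), it->second.end(), precision) == it->second.end()) {
    throw std::invalid_argument(precision + " is not a valid fp32 precision for " + backend);
  }
}

int main() {
  check_fp32_precision("mkldnn", "tf32");   // legal after this commit
  // check_fp32_precision("cuda", "bf16");  // would throw: not in cuda's list
  return 0;
}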

aten/src/ATen/native/CPUBlas.cpp

Lines changed: 1 addition & 1 deletion

@@ -202,7 +202,7 @@ void gemm(
     float *c, int64_t ldc) {
   internal::normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc);
 #if AT_MKLDNN_ENABLED()
-  if (mkldnn_bf32_gemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc)) {
+  if (mkldnn_reduced_f32_gemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc)) {
     return;
   }
 #endif
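
The only change here is the rename of the oneDNN fp32 fast-path hook, since it now covers both the bf16 and tf32 reduced-precision modes. The surrounding pattern stays the same: offer the GEMM to the reduced-precision path, and fall back to the plain fp32 kernel if it declines. A self-contained sketch of that pattern (placeholder names and signatures, not the ATen ones):

// Sketch of the "try the fast path, otherwise fall through" dispatch.
#include <cstdint>
#include <iostream>
#include <vector>

// Placeholder for mkldnn_reduced_f32_gemm: pretend the backend declines,
// e.g. because the precision is "ieee" or the CPU lacks the needed ISA.
bool reduced_f32_gemm(int64_t, int64_t, int64_t,
                      const float*, const float*, float*) {
  return false;
}

// Naive fp32 reference GEMM (row-major, C = A * B) used as the fallback.
void reference_gemm(int64_t m, int64_t n, int64_t k,
                    const float* a, const float* b, float* c) {
  for (int64_t i = 0; i < m; ++i)
    for (int64_t j = 0; j < n; ++j) {
      float acc = 0.f;
      for (int64_t p = 0; p < k; ++p) acc += a[i * k + p] * b[p * n + j];
      c[i * n + j] = acc;
    }
}

void gemm(int64_t m, int64_t n, int64_t k,
          const float* a, const float* b, float* c) {
  if (reduced_f32_gemm(m, n, k, a, b, c)) return;  // fast path took it
  reference_gemm(m, n, k, a, b, c);                // full fp32 fallback
}

int main() {
  std::vector<float> a(4, 1.f), b(4, 2.f), c(4, 0.f);
  gemm(2, 2, 2, a.data(), b.data(), c.data());
  std::cout << c[0] << "\n";  // 4
  return 0;
}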

aten/src/ATen/native/mkldnn/Conv.cpp

Lines changed: 39 additions & 0 deletions

@@ -160,6 +160,10 @@ static bool mkldnn_conv_enabled_fpmath_mode_bf16(){
       mkldnn_bf16_device_check();
 }

+static bool mkldnn_conv_enabled_fpmath_mode_tf32(){
+  return at::globalContext().float32Precision("mkldnn", "conv") == "tf32" &&
+      cpuinfo_has_x86_amx_fp16();
+}

 static inline at::MemoryFormat mkldnn_convolution_memory_format(int64_t dims, bool is_channels_last) {
   auto memory_format = at::MemoryFormat::Contiguous;
@@ -271,6 +275,10 @@ static Tensor _mkldnn_convolution(
       input_t.scalar_type() == at::kFloat) {
     op_attr.set_fpmath_mode(dnnl_fpmath_mode_bf16);
   }
+  if (mkldnn_conv_enabled_fpmath_mode_tf32() &&
+      input_t.scalar_type() == at::kFloat) {
+    op_attr.set_fpmath_mode(dnnl_fpmath_mode_tf32);
+  }
   _mkldnn_convolution_out(
       input_t,
       weight_t,
@@ -455,6 +463,9 @@ Tensor mkldnn_convolution_pointwise_binary(
   if (mkldnn_conv_enabled_fpmath_mode_bf16() && input_t.scalar_type() ==at::kFloat){
     op_attr.set_fpmath_mode(dnnl_fpmath_mode_bf16);
   }
+  if (mkldnn_conv_enabled_fpmath_mode_tf32() && input_t.scalar_type() ==at::kFloat){
+    op_attr.set_fpmath_mode(dnnl_fpmath_mode_tf32);
+  }

   if (bias.defined()) {
     const ideep::tensor b = itensor_from_tensor(bias);
@@ -597,6 +608,10 @@ Tensor& mkldnn_convolution_pointwise_binary_(
       input_t.scalar_type() == at::kFloat) {
     op_attr.set_fpmath_mode(dnnl_fpmath_mode_bf16);
   }
+  if (mkldnn_conv_enabled_fpmath_mode_tf32() &&
+      input_t.scalar_type() == at::kFloat) {
+    op_attr.set_fpmath_mode(dnnl_fpmath_mode_tf32);
+  }
   _mkldnn_convolution_out(
       input_t,
       weight_t,
@@ -718,6 +733,9 @@ Tensor _mkldnn_convolution_transpose(
   if (mkldnn_conv_enabled_fpmath_mode_bf16() && input_t.scalar_type() ==at::kFloat){
     op_attr.set_fpmath_mode(dnnl_fpmath_mode_bf16);
   }
+  if (mkldnn_conv_enabled_fpmath_mode_tf32() && input_t.scalar_type() ==at::kFloat){
+    op_attr.set_fpmath_mode(dnnl_fpmath_mode_tf32);
+  }

   if (bias.defined()) {
     const ideep::tensor b = itensor_from_tensor(bias, /*from_const_data_ptr*/true);
@@ -808,6 +826,10 @@ Tensor mkldnn_convolution_backward_input(
       weight.scalar_type() == at::kFloat) {
     op_attr.set_fpmath_mode(dnnl_fpmath_mode_bf16);
   }
+  if (mkldnn_conv_enabled_fpmath_mode_tf32() &&
+      weight.scalar_type() == at::kFloat) {
+    op_attr.set_fpmath_mode(dnnl_fpmath_mode_tf32);
+  }
   ideep::convolution_backward_data::compute_v2(
       grad_y,
       w,
@@ -828,6 +850,11 @@ Tensor mkldnn_convolution_backward_input(
     TORCH_WARN_ONCE(
         "Unexpected ideep version to support fpmath_mode_bf16, please update ideep version to align with pytorch main branch");
   }
+  if (mkldnn_conv_enabled_fpmath_mode_tf32() &&
+      weight.scalar_type() == at::kFloat) {
+    TORCH_WARN_ONCE(
+        "Unexpected ideep version to support fpmath_mode_tf32, please update ideep version to align with pytorch main branch");
+  }
 #endif

   if (grad_output.is_mkldnn()) {
@@ -858,6 +885,10 @@ std::tuple<Tensor, Tensor> mkldnn_convolution_backward_weights(
       input.scalar_type() == at::kFloat) {
     op_attr.set_fpmath_mode(dnnl_fpmath_mode_bf16);
   }
+  if (mkldnn_conv_enabled_fpmath_mode_tf32() &&
+      input.scalar_type() == at::kFloat) {
+    op_attr.set_fpmath_mode(dnnl_fpmath_mode_tf32);
+  }
   if (bias_defined) {
     ideep::convolution_backward_weights::compute_v2(
         x,
@@ -1011,6 +1042,10 @@ Tensor mkldnn_convolution_transpose_backward_input(
       weight.scalar_type() == at::kFloat) {
     op_attr.set_fpmath_mode(dnnl_fpmath_mode_bf16);
   }
+  if (mkldnn_conv_enabled_fpmath_mode_tf32() &&
+      weight.scalar_type() == at::kFloat) {
+    op_attr.set_fpmath_mode(dnnl_fpmath_mode_tf32);
+  }
   ideep::convolution_transpose_backward_data::compute_v3(
       grad_y,
       w,
@@ -1053,6 +1088,10 @@ std::tuple<Tensor,Tensor> mkldnn_convolution_transpose_backward_weights(
       input.scalar_type() == at::kFloat) {
     op_attr.set_fpmath_mode(dnnl_fpmath_mode_bf16);
   }
+  if (mkldnn_conv_enabled_fpmath_mode_tf32() &&
+      input.scalar_type() == at::kFloat) {
+    op_attr.set_fpmath_mode(dnnl_fpmath_mode_tf32);
+  }
   if (bias_defined) {
     ideep::convolution_transpose_backward_weights::compute_v3(
         x,
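
All of these hunks attach the same attribute to the oneDNN primitive: a TF32 fpmath mode for nominally fp32 convolutions. A minimal sketch of what the ideep op_attr.set_fpmath_mode(dnnl_fpmath_mode_tf32) calls boil down to at the oneDNN C++ API level (assumes oneDNN 3.x headers, where dnnl::fpmath_mode::tf32 is defined; this is not PyTorch code):

// Sketch: request TF32 internal math for fp32 primitives built with this attr.
#include <dnnl.hpp>
#include <iostream>

int main() {
  dnnl::primitive_attr attr;
  // Any convolution/inner-product/matmul primitive created with this attribute
  // is allowed to down-convert fp32 inputs to TF32 internally.
  attr.set_fpmath_mode(dnnl::fpmath_mode::tf32);
  std::cout << "fpmath mode set to tf32 on the primitive attribute\n";
  return 0;
}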

aten/src/ATen/native/mkldnn/Linear.cpp

Lines changed: 12 additions & 0 deletions

@@ -73,6 +73,11 @@ static bool use_mkldnn_bf32_linear() {
       mkldnn_bf16_device_check();
 }

+static bool use_mkldnn_tf32_linear() {
+  return at::globalContext().float32Precision("mkldnn", "matmul") == "tf32" &&
+      cpuinfo_has_x86_amx_fp16();
+}
+
 Tensor mkldnn_linear(
     const Tensor& self,
     const Tensor& weight_t, const std::optional<Tensor>& bias_opt) {
@@ -259,6 +264,9 @@ Tensor mkldnn_linear_pointwise(
   if (use_mkldnn_bf32_linear() && input_t.scalar_type() == at::kFloat){
     op_attr.set_fpmath_mode(dnnl_fpmath_mode_bf16);
   }
+  if (use_mkldnn_tf32_linear() && input_t.scalar_type() == at::kFloat){
+    op_attr.set_fpmath_mode(dnnl_fpmath_mode_tf32);
+  }
   if (mkldnn_bias.has_value()) {
     ideep::inner_product_forward::compute</*reorder_src=*/false, /*reorder_weight=*/false>(
         mkldnn_input,
@@ -352,6 +360,10 @@ Tensor mkldnn_linear_pointwise_binary(
     op_attr.set_fpmath_mode(dnnl_fpmath_mode_bf16);
   }

+  if (use_mkldnn_tf32_linear() && input_t.scalar_type() == at::kFloat){
+    op_attr.set_fpmath_mode(dnnl_fpmath_mode_tf32);
+  }
+
   if (mkldnn_bias.has_value()) {
     ideep::inner_product_forward::compute_binary</*reorder_src=*/false, /*reorder_weight=*/false>(
         mkldnn_input,
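
use_mkldnn_tf32_linear() gates the TF32 path on both the configured mkldnn fp32 precision and AMX-FP16 hardware support. A hedged standalone sketch of that check (assumes libtorch with this commit plus the cpuinfo library; should_use_tf32_linear is an illustrative name, not an ATen function, and ATen handles cpuinfo initialization internally):

// Sketch: TF32 is only requested when the user asked for it AND the CPU can do it.
#include <ATen/Context.h>
#include <cpuinfo.h>
#include <iostream>

bool should_use_tf32_linear() {
  const bool wants_tf32 =
      at::globalContext().float32Precision("mkldnn", "matmul") == "tf32";
  const bool hw_ok = cpuinfo_initialize() && cpuinfo_has_x86_amx_fp16();
  return wants_tf32 && hw_ok;
}

int main() {
  // With this commit, "high" maps the mkldnn matmul precision to "tf32".
  at::globalContext().setFloat32MatmulPrecision("high");
  std::cout << std::boolalpha << should_use_tf32_linear() << "\n";
  return 0;
}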

aten/src/ATen/native/mkldnn/Matmul.cpp

Lines changed: 58 additions & 56 deletions

@@ -1,7 +1,8 @@
 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
-#include <ATen/core/Tensor.h>
 #include <ATen/Config.h>
 #include <ATen/Context.h>
+#include <ATen/Dispatch.h>
+#include <ATen/core/Tensor.h>
 #include <ATen/native/mkldnn/Matmul.h>

 #if !AT_MKLDNN_ENABLED()
@@ -53,7 +54,7 @@ bool mkldnn_fp16_gemm(
     c10::Half *c, int64_t ldc) {
   return false;
 }
-bool mkldnn_bf32_gemm(
+bool mkldnn_reduced_f32_gemm(
     TransposeType transa, TransposeType transb,
     int64_t m, int64_t n, int64_t k,
     float alpha,
@@ -85,6 +86,13 @@ void mkldnn_matmul_i8i8i32(
   TORCH_INTERNAL_ASSERT(false, __func__, ": ATen not compiled with MKLDNN support");
 }

+bool use_mkldnn_tf32_matmul(
+    const Tensor& mat1,
+    const Tensor& mat2,
+    const Tensor& result) {
+  return false;
+}
+
 } // namespace at::native


@@ -107,6 +115,10 @@ static bool use_mkldnn_bf32_matmul() {
   return use_mkldnn_bf16_matmul() && at::globalContext().float32Precision("mkldnn", "matmul") == "bf16";
 }

+static bool use_mkldnn_tf32_matmul() {
+  return cpuinfo_has_x86_amx_fp16() && at::globalContext().float32Precision("mkldnn", "matmul") == "tf32";
+}
+
 // returns an ideep::tensor
 // - dims: shape e.g: {M,N}
 // - idtype: ideep data type e.g: (f32, bf16, f16)
@@ -144,7 +156,8 @@ mkldnn_gemm(
   bool bf16_usable = std::is_same_v<scalar_t, c10::BFloat16> && use_mkldnn_bf16_matmul();
   bool fp16_usable = std::is_same_v<scalar_t, c10::Half> && use_mkldnn_fp16_matmul();
   bool bf32_usable = std::is_same_v<scalar_t, float> && use_mkldnn_bf32_matmul();
-  if ( !(bf16_usable || fp16_usable || bf32_usable) ||
+  bool tf32_usable = std::is_same_v<scalar_t, float> && use_mkldnn_tf32_matmul();
+  if ( !(bf16_usable || fp16_usable || bf32_usable || tf32_usable) ||
       (m * n * k <= 16 * 16 * 16) || (alpha == 0.0f)) {
     return false;
   }
@@ -155,6 +168,7 @@ mkldnn_gemm(
     op_attr = ideep::attr_t::fuse_sum();
   }
   if (bf32_usable) op_attr.set_fpmath_mode(dnnl_fpmath_mode_bf16); // bf32 path
+  if (tf32_usable) op_attr.set_fpmath_mode(dnnl_fpmath_mode_tf32); // tf32 path

   // NOTE: View as c-contiguous to avoid extra reordering in mkldnn
   // Use identity: C = AB <=> C^T = B^T A^T
@@ -281,7 +295,7 @@ bool mkldnn_fp16_gemm(
   return mkldnn_gemm<c10::Half>(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
 }

-bool mkldnn_bf32_gemm(
+bool mkldnn_reduced_f32_gemm(
     TransposeType transa, TransposeType transb,
     int64_t m, int64_t n, int64_t k,
     float alpha,
@@ -339,13 +353,15 @@ void mkldnn_matmul(
   auto mat2_unsqueezed = mat2.dim() == 1 ? mat2.unsqueeze(1) : mat2;
   auto result_unsqueezed = result.dim() == 1 ? result.unsqueeze(1) : result;
   bool bf32_usable = mat1.scalar_type() == at::kFloat && use_mkldnn_bf32_matmul();
+  bool tf32_usable = mat1.scalar_type() == at::kFloat && use_mkldnn_tf32_matmul();

   ideep::attr_t op_attr;
   // "addmm", "addbmm" "baddbmm" in pytorch allow bias to be 2-D or 3-D tensor
   // but mkldnn matmul primitive only support bias be 1-D tensors
   // to address their differences, we use mkldnn post ops to perform a fused "add" after matrix multiplication is over
   if (beta != 0.0f) op_attr = ideep::attr_t::fuse_sum();
   if (bf32_usable) op_attr.set_fpmath_mode(dnnl_fpmath_mode_bf16); // bf32 path
+  if (tf32_usable) op_attr.set_fpmath_mode(dnnl_fpmath_mode_tf32); // tf32 path
   // If alpha = 0, dose not need actually do gemm computation
   if (alpha == 0)
     return;
@@ -412,70 +428,56 @@ static inline bool checksize(const Tensor& mat1, const Tensor& mat2){
   }
 }

-bool use_mkldnn_bf16_matmul(
+template <typename T>
+bool use_mkldnn_typed_matmul(
     const Tensor& mat1,
     const Tensor& mat2,
     const Tensor& result) {
+  bool dtype_check = false;
+  if constexpr (std::is_same_v<T, c10::BFloat16>) {
 #if defined(__aarch64__)
-  if (mkldnn_bf16_device_check_arm()) {
-     //onednn fastmath mode can leverage bf16 HW even for the fp32 input, e.g. Arm Neoverse V1
-     //so, don't restrict the mkldnn_matmul only for bf16 inputs, allow it for float as well
-     return (
-        use_mkldnn_bf16_matmul() &&
-        (mat1.scalar_type() == mat2.scalar_type()) && (!result.defined() || (mat1.scalar_type() == result.scalar_type())) &&
-        ((mat1.scalar_type() == kFloat) || (mat1.scalar_type() == kBFloat16)) &&
-        mat1.numel() != 0 &&
-        mat2.numel() != 0 &&
-        checksize(mat1, mat2));
-  } else
+    if (mkldnn_bf16_device_check_arm()) {
+      // onednn fastmath mode can leverage bf16 HW even for the fp32 input, e.g.
+      // Arm Neoverse V1 so, don't restrict the mkldnn_matmul only for bf16
+      // inputs, allow it for float as well
+      dtype_check = use_mkldnn_bf16_matmul() &&
+          ((mat1.scalar_type() == kFloat) || (mat1.scalar_type() == kBFloat16));
+    }
+#else
+    dtype_check = dtype_check && use_mkldnn_bf16_matmul() &&
+        (mat1.scalar_type() == kBFloat16);
 #endif
-  {
-     return (
-        use_mkldnn_bf16_matmul() &&
-        mat1.scalar_type() == kBFloat16 &&
-        mat2.scalar_type() == kBFloat16 &&
-        (!result.defined() || result.scalar_type() == kBFloat16) &&
-        mat1.numel() != 0 &&
-        mat2.numel() != 0 &&
-        checksize(mat1, mat2));
+  } else if constexpr (std::is_same_v<T, c10::Half>) {
+    dtype_check = dtype_check && use_mkldnn_fp16_matmul() &&
+        (mat1.scalar_type() == kHalf);
+  } else if constexpr (std::is_same_v<T, float>) {
+    dtype_check = dtype_check &&
+        (use_mkldnn_bf32_matmul() || use_mkldnn_tf32_matmul()) &&
+        (mat1.scalar_type() == kFloat);
   }
-}
-
-bool use_mkldnn_fp16_matmul(
-    const Tensor& mat1,
-    const Tensor& mat2,
-    const Tensor& result) {
-
-  return (
-     use_mkldnn_fp16_matmul() &&
-     mat1.scalar_type() == kHalf &&
-     mat2.scalar_type() == kHalf &&
-     (!result.defined() || result.scalar_type() == kHalf) &&
-     mat1.numel() != 0 &&
-     mat2.numel() != 0 &&
-     checksize(mat1, mat2));
-}
-
-bool use_mkldnn_bf32_matmul(
-    const Tensor& mat1,
-    const Tensor& mat2,
-    const Tensor& result) {
-
-  return (
-     use_mkldnn_bf32_matmul() &&
-     mat1.scalar_type() == kFloat &&
-     mat2.scalar_type() == kFloat &&
-     (!result.defined() || result.scalar_type() == kFloat) &&
-     mat1.numel() != 0 &&
-     mat2.numel() != 0 &&
-     checksize(mat1, mat2));
+  if (!dtype_check) {
+    return false;
+  }
+  bool size_check =
+      mat1.numel() != 0 && mat2.numel() != 0 && checksize(mat1, mat2);
+  dtype_check = (mat1.scalar_type() == mat2.scalar_type()) &&
+      (!result.defined() || result.scalar_type() == mat1.scalar_type());
+  return dtype_check && size_check;
 }

 bool use_mkldnn_matmul(
     const Tensor& mat1,
     const Tensor& mat2,
     const Tensor& result) {
-  return (use_mkldnn_bf16_matmul(mat1, mat2, result) || use_mkldnn_fp16_matmul(mat1, mat2, result) || use_mkldnn_bf32_matmul(mat1, mat2, result));
+  auto mat1_type = mat1.scalar_type();
+  if (mat1_type != kBFloat16 || mat1_type != kHalf || mat1_type != kFloat) {
+    return false;
+  }
+  AT_DISPATCH_FLOATING_TYPES_AND2(
+      kBFloat16, kHalf, mat1.scalar_type(), "use_mkldnn_matmul", [&] {
+        return use_mkldnn_typed_matmul<scalar_t>(mat1, mat2, result);
+      });
+  return false;
 }

 static void _mkldnn_matmul_i8i8i32_with_primitive(
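
The refactor above folds the per-dtype checks into one templated helper selected through AT_DISPATCH_FLOATING_TYPES_AND2. A small standalone sketch of how that dispatch macro binds scalar_t to the runtime dtype (the helper below is illustrative only, not the real use_mkldnn_typed_matmul; assumes a libtorch build):

// Sketch: the macro switches on the runtime ScalarType and instantiates the
// lambda with scalar_t bound to the matching C++ type (float, double, plus the
// two extra types listed, here BFloat16 and Half).
#include <ATen/ATen.h>
#include <ATen/Dispatch.h>
#include <iostream>
#include <type_traits>

template <typename T>
bool is_reduced_precision_type() {
  return std::is_same_v<T, at::BFloat16> || std::is_same_v<T, at::Half>;
}

int main() {
  at::Tensor t = at::ones({2, 2}, at::kBFloat16);
  bool reduced = false;
  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::kBFloat16, at::kHalf, t.scalar_type(), "is_reduced_precision", [&] {
        // Result is passed out through a captured variable.
        reduced = is_reduced_precision_type<scalar_t>();
      });
  std::cout << std::boolalpha << reduced << "\n";  // true for a bf16 tensor
  return 0;
}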
