 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
-#include <ATen/core/Tensor.h>
 #include <ATen/Config.h>
 #include <ATen/Context.h>
+#include <ATen/Dispatch.h>
+#include <ATen/core/Tensor.h>
 #include <ATen/native/mkldnn/Matmul.h>

 #if !AT_MKLDNN_ENABLED()
@@ -53,7 +54,7 @@ bool mkldnn_fp16_gemm(
     c10::Half *c, int64_t ldc) {
   return false;
 }
-bool mkldnn_bf32_gemm(
+bool mkldnn_reduced_f32_gemm(
     TransposeType transa, TransposeType transb,
     int64_t m, int64_t n, int64_t k,
     float alpha,
@@ -85,6 +86,13 @@ void mkldnn_matmul_i8i8i32(
   TORCH_INTERNAL_ASSERT(false, __func__, ": ATen not compiled with MKLDNN support");
 }

+bool use_mkldnn_tf32_matmul(
+    const Tensor& mat1,
+    const Tensor& mat2,
+    const Tensor& result) {
+  return false;
+}
+
 } // namespace at::native


@@ -107,6 +115,10 @@ static bool use_mkldnn_bf32_matmul() {
   return use_mkldnn_bf16_matmul() && at::globalContext().float32Precision("mkldnn", "matmul") == "bf16";
 }

+static bool use_mkldnn_tf32_matmul() {
+  return cpuinfo_has_x86_amx_fp16() && at::globalContext().float32Precision("mkldnn", "matmul") == "tf32";
+}
+
 // returns an ideep::tensor
 // - dims: shape e.g: {M,N}
 // - idtype: ideep data type e.g: (f32, bf16, f16)
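Note: the new tf32 gate above only fires when the CPU reports AMX-FP16 support and the MKLDNN matmul fp32 precision knob is set to "tf32". Below is a minimal usage sketch of how a caller might flip that knob from C++; the setter name setFloat32Precision() and its argument order are assumptions inferred from the float32Precision() getter used in this hunk, not something this diff confirms.

    #include <ATen/Context.h>

    // Hypothetical sketch: route fp32 MKLDNN matmuls through the TF32 path.
    // Assumes a setter symmetric to Context::float32Precision(backend, op);
    // verify the exact signature in ATen/Context.h before relying on it.
    void enable_mkldnn_tf32_matmul() {
      at::globalContext().setFloat32Precision("mkldnn", "matmul", "tf32");
    }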
@@ -144,7 +156,8 @@ mkldnn_gemm(
   bool bf16_usable = std::is_same_v<scalar_t, c10::BFloat16> && use_mkldnn_bf16_matmul();
   bool fp16_usable = std::is_same_v<scalar_t, c10::Half> && use_mkldnn_fp16_matmul();
   bool bf32_usable = std::is_same_v<scalar_t, float> && use_mkldnn_bf32_matmul();
-  if ( !(bf16_usable || fp16_usable || bf32_usable) ||
+  bool tf32_usable = std::is_same_v<scalar_t, float> && use_mkldnn_tf32_matmul();
+  if ( !(bf16_usable || fp16_usable || bf32_usable || tf32_usable) ||
       (m * n * k <= 16 * 16 * 16) || (alpha == 0.0f)) {
     return false;
   }
@@ -155,6 +168,7 @@ mkldnn_gemm(
     op_attr = ideep::attr_t::fuse_sum();
   }
   if (bf32_usable) op_attr.set_fpmath_mode(dnnl_fpmath_mode_bf16); // bf32 path
+  if (tf32_usable) op_attr.set_fpmath_mode(dnnl_fpmath_mode_tf32); // tf32 path

   // NOTE: View as c-contiguous to avoid extra reordering in mkldnn
   // Use identity: C = AB <=> C^T = B^T A^T
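Note: the two set_fpmath_mode() calls above cannot both fire in the same call, since float32Precision("mkldnn", "matmul") returns a single value, so bf32_usable and tf32_usable are mutually exclusive. A condensed sketch of the attribute setup under that assumption is below; the helper name and its signature are illustrative, not the file's actual code, and it assumes the ideep/oneDNN headers already pulled in by Matmul.cpp.

    #include <ideep.hpp>

    // Sketch: build matmul primitive attributes for the selected fp32 math mode.
    // bf32_usable / tf32_usable stand in for the gates computed in this hunk.
    inline ideep::attr_t make_matmul_attr(bool bf32_usable, bool tf32_usable, float beta) {
      ideep::attr_t op_attr;
      if (beta != 0.0f) {
        op_attr = ideep::attr_t::fuse_sum();             // fused add for the beta term
      }
      if (bf32_usable) {
        op_attr.set_fpmath_mode(dnnl_fpmath_mode_bf16);  // fp32 data, bf16 compute ("bf32")
      } else if (tf32_usable) {
        op_attr.set_fpmath_mode(dnnl_fpmath_mode_tf32);  // fp32 data, tf32 compute
      }
      return op_attr;
    }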
@@ -281,7 +295,7 @@ bool mkldnn_fp16_gemm(
   return mkldnn_gemm<c10::Half>(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
 }

-bool mkldnn_bf32_gemm(
+bool mkldnn_reduced_f32_gemm(
     TransposeType transa, TransposeType transb,
     int64_t m, int64_t n, int64_t k,
     float alpha,
@@ -339,13 +353,15 @@ void mkldnn_matmul(
   auto mat2_unsqueezed = mat2.dim() == 1 ? mat2.unsqueeze(1) : mat2;
   auto result_unsqueezed = result.dim() == 1 ? result.unsqueeze(1) : result;
   bool bf32_usable = mat1.scalar_type() == at::kFloat && use_mkldnn_bf32_matmul();
+  bool tf32_usable = mat1.scalar_type() == at::kFloat && use_mkldnn_tf32_matmul();

   ideep::attr_t op_attr;
   // "addmm", "addbmm" "baddbmm" in pytorch allow bias to be 2-D or 3-D tensor
   // but mkldnn matmul primitive only support bias be 1-D tensors
   // to address their differences, we use mkldnn post ops to perform a fused "add" after matrix multiplication is over
   if (beta != 0.0f) op_attr = ideep::attr_t::fuse_sum();
   if (bf32_usable) op_attr.set_fpmath_mode(dnnl_fpmath_mode_bf16); // bf32 path
+  if (tf32_usable) op_attr.set_fpmath_mode(dnnl_fpmath_mode_tf32); // tf32 path
   // If alpha = 0, dose not need actually do gemm computation
   if (alpha == 0)
     return;
@@ -412,70 +428,56 @@ static inline bool checksize(const Tensor& mat1, const Tensor& mat2){
   }
 }

-bool use_mkldnn_bf16_matmul(
+template <typename T>
+bool use_mkldnn_typed_matmul(
     const Tensor& mat1,
     const Tensor& mat2,
     const Tensor& result) {
+  bool dtype_check = false;
+  if constexpr (std::is_same_v<T, c10::BFloat16>) {
 #if defined(__aarch64__)
-  if (mkldnn_bf16_device_check_arm()) {
-    // onednn fastmath mode can leverage bf16 HW even for the fp32 input, e.g. Arm Neoverse V1
-    // so, don't restrict the mkldnn_matmul only for bf16 inputs, allow it for float as well
-    return (
-        use_mkldnn_bf16_matmul() &&
-        (mat1.scalar_type() == mat2.scalar_type()) && (!result.defined() || (mat1.scalar_type() == result.scalar_type())) &&
-        ((mat1.scalar_type() == kFloat) || (mat1.scalar_type() == kBFloat16)) &&
-        mat1.numel() != 0 &&
-        mat2.numel() != 0 &&
-        checksize(mat1, mat2));
-  } else
+    if (mkldnn_bf16_device_check_arm()) {
+      // onednn fastmath mode can leverage bf16 HW even for the fp32 input, e.g.
+      // Arm Neoverse V1 so, don't restrict the mkldnn_matmul only for bf16
+      // inputs, allow it for float as well
+      dtype_check = use_mkldnn_bf16_matmul() &&
+          ((mat1.scalar_type() == kFloat) || (mat1.scalar_type() == kBFloat16));
+    }
+#else
+    dtype_check = use_mkldnn_bf16_matmul() && (mat1.scalar_type() == kBFloat16);
 #endif
-  {
-    return (
-        use_mkldnn_bf16_matmul() &&
-        mat1.scalar_type() == kBFloat16 &&
-        mat2.scalar_type() == kBFloat16 &&
-        (!result.defined() || result.scalar_type() == kBFloat16) &&
-        mat1.numel() != 0 &&
-        mat2.numel() != 0 &&
-        checksize(mat1, mat2));
+  } else if constexpr (std::is_same_v<T, c10::Half>) {
+    dtype_check = use_mkldnn_fp16_matmul() && (mat1.scalar_type() == kHalf);
+  } else if constexpr (std::is_same_v<T, float>) {
+    dtype_check = (use_mkldnn_bf32_matmul() || use_mkldnn_tf32_matmul()) &&
+        (mat1.scalar_type() == kFloat);
   }
-}
-
-bool use_mkldnn_fp16_matmul(
-    const Tensor& mat1,
-    const Tensor& mat2,
-    const Tensor& result) {
-
-  return (
-      use_mkldnn_fp16_matmul() &&
-      mat1.scalar_type() == kHalf &&
-      mat2.scalar_type() == kHalf &&
-      (!result.defined() || result.scalar_type() == kHalf) &&
-      mat1.numel() != 0 &&
-      mat2.numel() != 0 &&
-      checksize(mat1, mat2));
-}
-
-bool use_mkldnn_bf32_matmul(
-    const Tensor& mat1,
-    const Tensor& mat2,
-    const Tensor& result) {
-
-  return (
-      use_mkldnn_bf32_matmul() &&
-      mat1.scalar_type() == kFloat &&
-      mat2.scalar_type() == kFloat &&
-      (!result.defined() || result.scalar_type() == kFloat) &&
-      mat1.numel() != 0 &&
-      mat2.numel() != 0 &&
-      checksize(mat1, mat2));
+  if (!dtype_check) {
+    return false;
+  }
+  bool size_check =
+      mat1.numel() != 0 && mat2.numel() != 0 && checksize(mat1, mat2);
+  dtype_check = (mat1.scalar_type() == mat2.scalar_type()) &&
+      (!result.defined() || result.scalar_type() == mat1.scalar_type());
+  return dtype_check && size_check;
 }

 bool use_mkldnn_matmul(
     const Tensor& mat1,
     const Tensor& mat2,
     const Tensor& result) {
-  return (use_mkldnn_bf16_matmul(mat1, mat2, result) || use_mkldnn_fp16_matmul(mat1, mat2, result) || use_mkldnn_bf32_matmul(mat1, mat2, result));
+  auto mat1_type = mat1.scalar_type();
+  if (mat1_type != kBFloat16 && mat1_type != kHalf && mat1_type != kFloat) {
+    return false;
+  }
+  return AT_DISPATCH_FLOATING_TYPES_AND2(
+      kBFloat16, kHalf, mat1.scalar_type(), "use_mkldnn_matmul", [&] {
+        return use_mkldnn_typed_matmul<scalar_t>(mat1, mat2, result);
+      });
 }

 static void _mkldnn_matmul_i8i8i32_with_primitive(