[inductor] enable bf32 for mkldnn linear pointwise/binary in inductor · pytorch/pytorch@e8c5477 · GitHub

Commit e8c5477

zhuhaozhe authored and yanbing-j committed
[inductor] enable bf32 for mkldnn linear pointwise/binary in inductor
ghstack-source-id: 649bdf8
Pull Request resolved: #127294
1 parent f242526 commit e8c5477
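For orientation, below is a minimal usage sketch of what this change enables. It assumes torch.backends.mkldnn.matmul.fp32_precision is writable from Python (the commit itself only reads it); the module and shapes are illustrative.

import torch

# Assumed-writable knob; this commit reads it both in ATen (Linear.cpp) and in the
# inductor mkldnn fusion pass.
torch.backends.mkldnn.matmul.fp32_precision = "bf16"

mod = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU()).eval()
x = torch.randn(8, 64)  # fp32 weights and input; oneDNN may compute internally in bf16

with torch.no_grad():
    out = torch.compile(mod)(x)  # inductor can now fuse this into a bf32 mkldnn linear + relu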

3 files changed: +68 -13 lines changed

aten/src/ATen/native/mkldnn/Linear.cpp

Lines changed: 11 additions & 1 deletion
@@ -56,6 +56,10 @@ std::tuple<Tensor, Tensor, Tensor> mkldnn_linear_backward(
 
 namespace at::native {
 
+static bool use_mkldnn_bf32_linear() {
+  return at::globalContext().float32Precision("mkldnn", "matmul") == "bf16";
+}
+
 Tensor mkldnn_linear(
     const Tensor& self,
     const Tensor& weight_t, const std::optional<Tensor>& bias_opt) {
@@ -231,7 +235,9 @@ Tensor mkldnn_linear_pointwise(
       it != fusion_unary_attr_map().end(), "Fusion behavior undefined.");
     op_attr = it->second(scalars, algorithm);
   }
-
+  if (use_mkldnn_bf32_linear() && input_t.scalar_type() == at::kFloat){
+    op_attr.set_fpmath_mode(dnnl_fpmath_mode_bf16);
+  }
   if (mkldnn_bias.has_value()) {
     ideep::inner_product_forward::compute</*reorder_src=*/false, /*reorder_weight=*/false>(
         mkldnn_input,
@@ -318,6 +324,10 @@ Tensor mkldnn_linear_pointwise_binary(
   auto other_desc = mkldnn_other.get_desc();
   auto op_attr = ideep::attr_t::fuse_binary(it_binary->second, other_desc);
 
+  if (use_mkldnn_bf32_linear() && input_t.scalar_type() == at::kFloat){
+    op_attr.set_fpmath_mode(dnnl_fpmath_mode_bf16);
+  }
+
   if (mkldnn_bias.has_value()) {
     ideep::inner_product_forward::compute_binary</*reorder_src=*/false, /*reorder_weight=*/false>(
         mkldnn_input,
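The new C++ gate reads the same setting that Python exposes as torch.backends.mkldnn.matmul.fp32_precision. A rough Python-side restatement of the gate, with a hypothetical helper name:

import torch

def would_use_bf32_linear(inp: torch.Tensor) -> bool:
    # Mirrors use_mkldnn_bf32_linear() && input_t.scalar_type() == at::kFloat above.
    return (
        torch.backends.mkldnn.matmul.fp32_precision == "bf16"
        and inp.dtype == torch.float32
    )

print(would_use_bf32_linear(torch.randn(2, 8)))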

test/inductor/test_mkldnn_pattern_matcher.py

Lines changed: 30 additions & 5 deletions
@@ -326,6 +326,7 @@ def test_conv2d_unary_cpu(self):
     def test_conv3d_unary_cpu(self):
         self._test_conv_unary_cpu_base(dim=5)
 
+    @bf32_on_and_off()
     def test_linear_unary(self):
         class M(torch.nn.Module):
             def __init__(
@@ -354,6 +355,8 @@ def forward(self, x):
             dtypes.append(torch.bfloat16)
         if torch.ops.mkldnn._is_mkldnn_fp16_supported():
             dtypes.append(torch.float16)
+        if torch.backends.mkldnn.matmul.fp32_precision == "bf16":
+            dtypes.append(torch.float32)
         options = itertools.product(unary_list, [True, False], dtypes)
         for unary_fn, bias, dtype in options:
             metrics.reset()
@@ -364,7 +367,7 @@ def forward(self, x):
 
             def matcher_check_fn():
                 match_nodes = unary_list[unary_fn]
-                if self._check_unary_is_decomposed(unary_fn):
+                if dtype != torch.float32 and self._check_unary_is_decomposed(unary_fn):
                     # Has extra dtype conversion nodes for autocast.
                     match_nodes += 2
                 self.assertEqual(
@@ -376,9 +379,15 @@ def matcher_check_fn():
                 )
 
             self._test_common(mod, (v,), matcher_check_fn, check_autocast=dtype)
-            # only generated 1 kernel for "to"
-            self.assertEqual(metrics.generated_kernel_count, 2 if TEST_ACL else 1)
+            expected_kernel_count = 1
+            if TEST_ACL:
+                expected_kernel_count = 2
+            elif dtype == torch.float32:
+                expected_kernel_count = 0
+            # only generated 1 kernel for "to_dtype"
+            self.assertEqual(metrics.generated_kernel_count, expected_kernel_count)
 
+    @bf32_on_and_off()
     @unittest.skipIf(not TEST_MKL, "Test requires MKL")
     def test_linear_fp32(self):
         class M(torch.nn.Module):
@@ -796,6 +805,7 @@ def test_conv2d_binary_broadcast_shapes_cpu(self):
     def test_conv3d_binary_broadcast_shapes_cpu(self):
         self._test_conv_binary_broadcast_shapes_base(dim=5)
 
+    @bf32_on_and_off()
     def test_linear_binary(self):
         class M(torch.nn.Module):
             def __init__(self, binary_fn, in_channels, out_channels, bias, **kwargs):
@@ -815,6 +825,8 @@ def forward(self, x, y):
             dtypes.append(torch.bfloat16)
         if torch.ops.mkldnn._is_mkldnn_fp16_supported():
             dtypes.append(torch.float16)
+        if torch.backends.mkldnn.matmul.fp32_precision == "bf16":
+            dtypes.append(torch.float32)
         options = itertools.product(
             binary_list, [[2, 3, 10], [2, 10]], [True, False], dtypes
         )
@@ -851,7 +863,13 @@ def matcher_check_fn():
                 matcher_check_fn,
                 check_autocast=dtype,
             )
-            self.assertEqual(metrics.generated_kernel_count, 2 if TEST_ACL else 1)
+            expected_kernel_count = 1
+            if TEST_ACL:
+                expected_kernel_count = 2
+            elif dtype == torch.float32:
+                expected_kernel_count = 0
+            # only generated 1 kernel for "to_dtype"
+            self.assertEqual(metrics.generated_kernel_count, expected_kernel_count)
 
     def test_linear_binary_broadcast_shapes_cpu(self):
         class M(torch.nn.Module):
@@ -914,7 +932,13 @@ def matcher_check_fn():
                 matcher_check_fn,
                 check_autocast=dtype,
            )
-            self.assertEqual(metrics.generated_kernel_count, 2 if TEST_ACL else 1)
+            expected_kernel_count = 1
+            if TEST_ACL:
+                expected_kernel_count = 2
+            elif dtype == torch.float32:
+                expected_kernel_count = 0
+            # only generated 1 kernel for "to_dtype"
+            self.assertEqual(metrics.generated_kernel_count, expected_kernel_count)
 
     @skipIfNoDynamoSupport
     @skipIfNoONEDNN
@@ -947,6 +971,7 @@ def matcher_check_fn():
 
         self._test_common(mod, (x1, x2), matcher_check_fn)
 
+    @bf32_on_and_off()
     def test_multi_linear_share_same_input(self):
         # llama pattern.
         class M(torch.nn.Module):
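For context, bf32_on_and_off is expected to run each decorated test under both fp32-precision settings, which is why the tests only add torch.float32 to dtypes when fp32_precision is "bf16". A hypothetical, simplified sketch of such a decorator (the real helper lives in torch.testing._internal and may differ):

import functools
import torch

def bf32_on_and_off_sketch():
    # Hypothetical sketch only: run the test once with the current fp32 precision and once
    # with mkldnn matmul fp32_precision forced to "bf16" (assumes the attribute is writable).
    def decorator(test_fn):
        @functools.wraps(test_fn)
        def wrapper(self, *args, **kwargs):
            old = torch.backends.mkldnn.matmul.fp32_precision
            for precision in (old, "bf16"):
                try:
                    torch.backends.mkldnn.matmul.fp32_precision = precision
                    test_fn(self, *args, **kwargs)
                finally:
                    torch.backends.mkldnn.matmul.fp32_precision = old
        return wrapper
    return decorator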

torch/_inductor/fx_passes/mkldnn_fusion.py

Lines changed: 27 additions & 7 deletions
@@ -935,10 +935,20 @@ def is_linear_add_bias(match):
         bias_meta = add_node.args[1].meta.get("val")
         if weight_meta is None or bias_meta is None:
             return False
-        assert weight_meta.dtype in (
-            torch.bfloat16,
-            torch.float16,
+
+        bf32_matmul_enabled = torch.backends.mkldnn.matmul.fp32_precision == "bf16"  # type: ignore[attr-defined]
+        use_bf16_for_fp32_weight = (
+            bf32_matmul_enabled and weight_meta.dtype == torch.float32
+        )
+        assert (
+            weight_meta.dtype
+            in (
+                torch.bfloat16,
+                torch.float16,
+            )
+            or use_bf16_for_fp32_weight
         )
+
         if bias_meta.dtype != weight_meta.dtype:
             return False
         return (
@@ -1098,10 +1108,15 @@ def is_const_or_cat_by_const(weight):
                 torch.bfloat16,
                 torch.float16,
             )
+            bf32_matmul_enabled = torch.backends.mkldnn.matmul.fp32_precision == "bf16"  # type: ignore[attr-defined]
+            use_bf16_for_fp32_weight = (
+                bf32_matmul_enabled and weight_meta_value.dtype == torch.float32
+            )
+            compute_with_lp = is_lp_weight or use_bf16_for_fp32_weight
             # on x86, for fp32, mkl should be enabled and batch_size should not be a free symbol.
             # on aarch64, use mkldnn op for fp32 as well if acl is enabled
             if (
-                not is_lp_weight
+                not compute_with_lp
                 and not mkldnn._is_mkldnn_acl_supported()
                 and ((not torch._C.has_mkl) or has_free_symbols(batch_size))
             ):
@@ -1308,9 +1323,14 @@ def linear(match, *args, **kwargs):
                 torch.bfloat16,
                 torch.float16,
             )
+            bf32_matmul_enabled = torch.backends.mkldnn.matmul.fp32_precision == "bf16"  # type: ignore[attr-defined]
+            use_bf16_for_fp32_weight = (
+                bf32_matmul_enabled and weight_dtype == torch.float32
+            )
+            compute_with_lp = is_lp_weight or use_bf16_for_fp32_weight
             batch_size = input.meta.get("val").shape[0]
             if has_free_symbols(batch_size):
-                assert is_lp_weight or mkldnn._is_mkldnn_acl_supported(), (
+                assert compute_with_lp or mkldnn._is_mkldnn_acl_supported(), (
                     f"only bf16/fp16 weight prepacking supports dynamic shape inputs but got {weight_dtype}"
                 )
             # For bfloat16 dynamic shape path, using input size hint to pack weight for a better performance.
@@ -1328,7 +1348,7 @@ def linear(match, *args, **kwargs):
             packed_weight_op = (
                 mkldnn._reorder_linear_weight
                 if (
-                    is_lp_weight
+                    compute_with_lp
                     or mkldnn._is_mkldnn_acl_supported()
                     or V.aot_compilation
                 )
@@ -1340,7 +1360,7 @@ def linear(match, *args, **kwargs):
 
             packed_linear_inputs: tuple[Any, ...] = (input, packed_weight_node)
             if (
-                is_lp_weight
+                compute_with_lp
                 or mkldnn._is_mkldnn_acl_supported()
                 or V.aot_compilation
             ):
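Condensed, the fusion-pass change reduces to the predicate below (variable values are illustrative; is_lp_weight and weight_dtype come from the surrounding pass):

import torch

weight_dtype = torch.float32  # illustrative value
is_lp_weight = weight_dtype in (torch.bfloat16, torch.float16)

bf32_matmul_enabled = torch.backends.mkldnn.matmul.fp32_precision == "bf16"
use_bf16_for_fp32_weight = bf32_matmul_enabled and weight_dtype == torch.float32

# fp32 weights now qualify for mkldnn weight prepacking and the packed linear path
# when bf32 matmul is enabled.
compute_with_lp = is_lp_weight or use_bf16_for_fp32_weight
print(compute_with_lp)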

0 commit comments
