[inductor] enable bf32 for mkldnn linear pointwise/binary in inductor · pytorch/pytorch@595d83b · GitHub


Commit 595d83b

zhuhaozhe authored and yanbing-j committed
[inductor] enable bf32 for mkldnn linear pointwise/binary in inductor
ghstack-source-id: 56c90f0 Pull Request resolved: #127294
1 parent bf32ec5 commit 595d83b

File tree

3 files changed: +68 -13 lines changed

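This commit wires the bf32 path (fp32 tensors, bf16 math inside oneDNN) into the inductor mkldnn fusion for linear pointwise/binary ops, gated on torch.backends.mkldnn.matmul.fp32_precision == "bf16". A minimal usage sketch, assuming the standard torch.compile CPU flow; the module and shapes are illustrative and not part of this commit:

    import torch
    import torch.nn as nn

    # Ask oneDNN to use bf16 math for fp32 matmuls (the "bf32" mode this commit
    # enables for the fused linear ops in inductor).
    torch.backends.mkldnn.matmul.fp32_precision = "bf16"

    mod = nn.Sequential(nn.Linear(64, 128), nn.ReLU()).eval()
    x = torch.randn(32, 64)  # plain fp32 input

    with torch.no_grad():
        out = torch.compile(mod)(x)  # linear+relu can now fuse through mkldnn with bf16 math

    print(out.dtype)  # torch.float32: inputs/outputs stay fp32, only the compute is bf16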

aten/src/ATen/native/mkldnn/Linear.cpp (+11 -1)

@@ -56,6 +56,10 @@ std::tuple<Tensor, Tensor, Tensor> mkldnn_linear_backward(
 
 namespace at::native {
 
+static bool use_mkldnn_bf32_linear() {
+  return at::globalContext().float32Precision("mkldnn", "matmul") == "bf16";
+}
+
 Tensor mkldnn_linear(
     const Tensor& self,
     const Tensor& weight_t, const std::optional<Tensor>& bias_opt) {
@@ -231,7 +235,9 @@ Tensor mkldnn_linear_pointwise(
        it != fusion_unary_attr_map().end(), "Fusion behavior undefined.");
     op_attr = it->second(scalars, algorithm);
   }
-
+  if (use_mkldnn_bf32_linear() && input_t.scalar_type() == at::kFloat){
+    op_attr.set_fpmath_mode(dnnl_fpmath_mode_bf16);
+  }
   if (mkldnn_bias.has_value()) {
     ideep::inner_product_forward::compute</*reorder_src=*/false, /*reorder_weight=*/false>(
         mkldnn_input,
@@ -318,6 +324,10 @@ Tensor mkldnn_linear_pointwise_binary(
   auto other_desc = mkldnn_other.get_desc();
   auto op_attr = ideep::attr_t::fuse_binary(it_binary->second, other_desc);
 
+  if (use_mkldnn_bf32_linear() && input_t.scalar_type() == at::kFloat){
+    op_attr.set_fpmath_mode(dnnl_fpmath_mode_bf16);
+  }
+
   if (mkldnn_bias.has_value()) {
     ideep::inner_product_forward::compute_binary</*reorder_src=*/false, /*reorder_weight=*/false>(
         mkldnn_input,
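The helper use_mkldnn_bf32_linear() mirrors the Python-side setting through at::globalContext().float32Precision("mkldnn", "matmul"), and for fp32 inputs it sets the oneDNN fpmath mode to bf16 on the fused op attribute, so accumulation happens in bf16 while tensors stay fp32 at the boundaries. A hedged numerical sketch of that tradeoff; the tolerances are illustrative, and saving/restoring the precision flag is just defensive hygiene:

    import torch

    lin = torch.nn.Linear(256, 256).eval()
    x = torch.randn(8, 256)

    with torch.no_grad():
        ref = lin(x)  # strict fp32 reference

    prev = torch.backends.mkldnn.matmul.fp32_precision
    try:
        torch.backends.mkldnn.matmul.fp32_precision = "bf16"
        with torch.no_grad():
            bf32_out = torch.compile(lin)(x)  # fp32 in/out, bf16 math inside oneDNN
    finally:
        torch.backends.mkldnn.matmul.fp32_precision = prev

    # bf32 trades a little accuracy for speed, so compare with a loose tolerance.
    print(torch.allclose(bf32_out, ref, atol=1e-2, rtol=1e-2))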

test/inductor/test_mkldnn_pattern_matcher.py (+30 -5)

@@ -323,6 +323,7 @@ def test_conv2d_unary_cpu(self):
     def test_conv3d_unary_cpu(self):
         self._test_conv_unary_cpu_base(dim=5)
 
+    @bf32_on_and_off()
     def test_linear_unary(self):
         class M(torch.nn.Module):
             def __init__(
@@ -351,6 +352,8 @@ def forward(self, x):
            dtypes.append(torch.bfloat16)
         if torch.ops.mkldnn._is_mkldnn_fp16_supported():
             dtypes.append(torch.float16)
+        if torch.backends.mkldnn.matmul.fp32_precision == "bf16":
+            dtypes.append(torch.float32)
         options = itertools.product(unary_list, [True, False], dtypes)
         for unary_fn, bias, dtype in options:
             metrics.reset()
@@ -361,7 +364,7 @@ def forward(self, x):
 
             def matcher_check_fn():
                 match_nodes = unary_list[unary_fn]
-                if self._check_unary_is_decomposed(unary_fn):
+                if dtype != torch.float32 and self._check_unary_is_decomposed(unary_fn):
                     # Has extra dtype conversion nodes for autocast.
                     match_nodes += 2
                 self.assertEqual(
@@ -373,9 +376,15 @@ def matcher_check_fn():
                 )
 
             self._test_common(mod, (v,), matcher_check_fn, check_autocast=dtype)
-            # only generated 1 kernel for "to"
-            self.assertEqual(metrics.generated_kernel_count, 2 if TEST_ACL else 1)
+            expected_kernel_count = 1
+            if TEST_ACL:
+                expected_kernel_count = 2
+            elif dtype == torch.float32:
+                expected_kernel_count = 0
+            # only generated 1 kernel for "to_dtype"
+            self.assertEqual(metrics.generated_kernel_count, expected_kernel_count)
 
+    @bf32_on_and_off()
     @unittest.skipIf(not TEST_MKL, "Test requires MKL")
     def test_linear_fp32(self):
         class M(torch.nn.Module):
@@ -793,6 +802,7 @@ def test_conv2d_binary_broadcast_shapes_cpu(self):
     def test_conv3d_binary_broadcast_shapes_cpu(self):
         self._test_conv_binary_broadcast_shapes_base(dim=5)
 
+    @bf32_on_and_off()
     def test_linear_binary(self):
         class M(torch.nn.Module):
             def __init__(self, binary_fn, in_channels, out_channels, bias, **kwargs):
@@ -812,6 +822,8 @@ def forward(self, x, y):
            dtypes.append(torch.bfloat16)
         if torch.ops.mkldnn._is_mkldnn_fp16_supported():
             dtypes.append(torch.float16)
+        if torch.backends.mkldnn.matmul.fp32_precision == "bf16":
+            dtypes.append(torch.float32)
         options = itertools.product(
             binary_list, [[2, 3, 10], [2, 10]], [True, False], dtypes
         )
@@ -848,7 +860,13 @@ def matcher_check_fn():
                 matcher_check_fn,
                 check_autocast=dtype,
             )
-            self.assertEqual(metrics.generated_kernel_count, 2 if TEST_ACL else 1)
+            expected_kernel_count = 1
+            if TEST_ACL:
+                expected_kernel_count = 2
+            elif dtype == torch.float32:
+                expected_kernel_count = 0
+            # only generated 1 kernel for "to_dtype"
+            self.assertEqual(metrics.generated_kernel_count, expected_kernel_count)
 
     def test_linear_binary_broadcast_shapes_cpu(self):
         class M(torch.nn.Module):
@@ -911,7 +929,13 @@ def matcher_check_fn():
                 matcher_check_fn,
                 check_autocast=dtype,
             )
-            self.assertEqual(metrics.generated_kernel_count, 2 if TEST_ACL else 1)
+            expected_kernel_count = 1
+            if TEST_ACL:
+                expected_kernel_count = 2
+            elif dtype == torch.float32:
+                expected_kernel_count = 0
+            # only generated 1 kernel for "to_dtype"
+            self.assertEqual(metrics.generated_kernel_count, expected_kernel_count)
 
     @skipIfNoDynamoSupport
     @skipIfNoONEDNN
@@ -944,6 +968,7 @@ def matcher_check_fn():
 
         self._test_common(mod, (x1, x2), matcher_check_fn)
 
+    @bf32_on_and_off()
     def test_multi_linear_share_same_input(self):
         # llama pattern.
         class M(torch.nn.Module):
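The tests opt in through a bf32_on_and_off() decorator and only add torch.float32 to the dtype sweep when the mkldnn matmul precision is actually "bf16"; in that case there is no autocast "to_dtype" conversion to emit, so the expected generated-kernel count drops to 0 for fp32. The decorator itself is not part of this diff; the sketch below is a hypothetical stand-in, assuming it behaves like the existing tf32_on_and_off test helper (run once with the default precision, once with "bf16", then restore):

    import functools
    import torch

    def bf32_on_and_off_sketch():
        # Hypothetical stand-in for the decorator used in the tests; the real helper
        # lives in torch.testing internals and may differ in its details.
        def decorator(test_fn):
            @functools.wraps(test_fn)
            def wrapper(*args, **kwargs):
                prev = torch.backends.mkldnn.matmul.fp32_precision
                try:
                    for precision in (prev, "bf16"):
                        torch.backends.mkldnn.matmul.fp32_precision = precision
                        test_fn(*args, **kwargs)
                finally:
                    torch.backends.mkldnn.matmul.fp32_precision = prev
            return wrapper
        return decorator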

torch/_inductor/fx_passes/mkldnn_fusion.py (+27 -7)

@@ -935,10 +935,20 @@ def is_linear_add_bias(match):
     bias_meta = add_node.args[1].meta.get("val")
     if weight_meta is None or bias_meta is None:
         return False
-    assert weight_meta.dtype in (
-        torch.bfloat16,
-        torch.float16,
+
+    bf32_matmul_enabled = torch.backends.mkldnn.matmul.fp32_precision == "bf16"  # type: ignore[attr-defined]
+    use_bf16_for_fp32_weight = (
+        bf32_matmul_enabled and weight_meta.dtype == torch.float32
+    )
+    assert (
+        weight_meta.dtype
+        in (
+            torch.bfloat16,
+            torch.float16,
+        )
+        or use_bf16_for_fp32_weight
     )
+
     if bias_meta.dtype != weight_meta.dtype:
         return False
     return (
@@ -1098,10 +1108,15 @@ def is_const_or_cat_by_const(weight):
            torch.bfloat16,
            torch.float16,
        )
+        bf32_matmul_enabled = torch.backends.mkldnn.matmul.fp32_precision == "bf16"  # type: ignore[attr-defined]
+        use_bf16_for_fp32_weight = (
+            bf32_matmul_enabled and weight_meta_value.dtype == torch.float32
+        )
+        compute_with_lp = is_lp_weight or use_bf16_for_fp32_weight
         # on x86, for fp32, mkl should be enabled and batch_size should not be a free symbol.
         # on aarch64, use mkldnn op for fp32 as well if acl is enabled
         if (
-            not is_lp_weight
+            not compute_with_lp
             and not mkldnn._is_mkldnn_acl_supported()
             and ((not torch._C.has_mkl) or has_free_symbols(batch_size))
         ):
@@ -1308,9 +1323,14 @@ def linear(match, *args, **kwargs):
                torch.bfloat16,
                torch.float16,
            )
+            bf32_matmul_enabled = torch.backends.mkldnn.matmul.fp32_precision == "bf16"  # type: ignore[attr-defined]
+            use_bf16_for_fp32_weight = (
+                bf32_matmul_enabled and weight_dtype == torch.float32
+            )
+            compute_with_lp = is_lp_weight or use_bf16_for_fp32_weight
             batch_size = input.meta.get("val").shape[0]
             if has_free_symbols(batch_size):
-                assert is_lp_weight or mkldnn._is_mkldnn_acl_supported(), (
+                assert compute_with_lp or mkldnn._is_mkldnn_acl_supported(), (
                     f"only bf16/fp16 weight prepacking supports dynamic shape inputs but got {weight_dtype}"
                 )
             # For bfloat16 dynamic shape path, using input size hint to pack weight for a better performance.
@@ -1328,7 +1348,7 @@ def linear(match, *args, **kwargs):
             packed_weight_op = (
                 mkldnn._reorder_linear_weight
                 if (
-                    is_lp_weight
+                    compute_with_lp
                     or mkldnn._is_mkldnn_acl_supported()
                     or V.aot_compilation
                 )
@@ -1340,7 +1360,7 @@ def linear(match, *args, **kwargs):
 
             packed_linear_inputs: tuple[Any, ...] = (input, packed_weight_node)
             if (
-                is_lp_weight
+                compute_with_lp
                 or mkldnn._is_mkldnn_acl_supported()
                 or V.aot_compilation
             ):
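In the fusion pass the change boils down to one predicate: compute_with_lp is true for genuinely low-precision weights, or for fp32 weights when bf32 is enabled, and it replaces is_lp_weight at each decision point (weight prepacking, the dynamic-shape assert, mkldnn op selection). A standalone sketch of that predicate, with weight_dtype standing in for the weight metadata the pass actually inspects:

    import torch

    def use_low_precision_linear_path(weight_dtype: torch.dtype) -> bool:
        # bf16/fp16 weights always take the low-precision mkldnn path.
        is_lp_weight = weight_dtype in (torch.bfloat16, torch.float16)
        # fp32 weights take it only when mkldnn matmul is set to bf16 math (bf32).
        bf32_matmul_enabled = torch.backends.mkldnn.matmul.fp32_precision == "bf16"
        use_bf16_for_fp32_weight = bf32_matmul_enabled and weight_dtype == torch.float32
        return is_lp_weight or use_bf16_for_fp32_weight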
