[inductor] enable bf32 for mkldnn linear pointwise/binary in inductor · pytorch/pytorch@9f1da17 · GitHub

Commit 9f1da17

zhuhaozhe authored and yanbing-j committed
[inductor] enable bf32 for mkldnn linear pointwise/binary in inductor
ghstack-source-id: b81d02d
Pull Request resolved: #127294
1 parent ed5f016 commit 9f1da17
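
This commit lets inductor's mkldnn path run fp32 linear (with pointwise/binary fusion) in oneDNN's bf16 fpmath mode whenever the mkldnn matmul fp32 precision is set to "bf16". Below is a minimal, illustrative sketch of how the feature would be exercised; it assumes torch.backends.mkldnn.matmul.fp32_precision is writable (the tests in this commit only read it), and the module sizes are made up.

import torch
import torch._inductor.config as inductor_config

# Assumption: the fp32_precision knob is settable; the tests in this commit only read it.
torch.backends.mkldnn.matmul.fp32_precision = "bf16"
# Weight prepacking / mkldnn fusion in inductor runs under freezing.
inductor_config.freezing = True

class LinearRelu(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(128, 64)

    def forward(self, x):
        # fp32 linear + relu: a candidate for the mkldnn_linear_pointwise fusion on CPU
        return torch.relu(self.linear(x))

mod = LinearRelu().eval()
x = torch.randn(32, 128)  # fp32 input; bf32 means bf16 compute on fp32 tensors

with torch.no_grad():
    out = torch.compile(mod)(x)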

File tree

3 files changed (+68, -13 lines)


aten/src/ATen/native/mkldnn/Linear.cpp

Lines changed: 11 additions & 1 deletion
@@ -68,6 +68,10 @@ mkldnn_scaled_mm(const Tensor& mat1, const Tensor& mat2,
 
 namespace at::native {
 
+static bool use_mkldnn_bf32_linear() {
+  return at::globalContext().float32Precision("mkldnn", "matmul") == "bf16";
+}
+
 Tensor mkldnn_linear(
     const Tensor& self,
     const Tensor& weight_t, const std::optional<Tensor>& bias_opt) {
@@ -251,7 +255,9 @@ Tensor mkldnn_linear_pointwise(
       it != fusion_unary_attr_map().end(), "Fusion behavior undefined.");
     op_attr = it->second(scalars, algorithm);
   }
-
+  if (use_mkldnn_bf32_linear() && input_t.scalar_type() == at::kFloat){
+    op_attr.set_fpmath_mode(dnnl_fpmath_mode_bf16);
+  }
   if (mkldnn_bias.has_value()) {
     ideep::inner_product_forward::compute</*reorder_src=*/false, /*reorder_weight=*/false>(
         mkldnn_input,
@@ -341,6 +347,10 @@ Tensor mkldnn_linear_pointwise_binary(
   auto op_attr = ideep::attr_t::fuse_binary(it_binary->second, other_desc);
   auto aprop_kind = ideep::prop_kind::forward_inference;
 
+  if (use_mkldnn_bf32_linear() && input_t.scalar_type() == at::kFloat){
+    op_attr.set_fpmath_mode(dnnl_fpmath_mode_bf16);
+  }
+
   if (mkldnn_bias.has_value()) {
     ideep::inner_product_forward::compute_binary</*reorder_src=*/false, /*reorder_weight=*/false>(
         mkldnn_input,

test/inductor/test_mkldnn_pattern_matcher.py

Lines changed: 30 additions & 5 deletions
@@ -343,6 +343,7 @@ def test_conv3d_unary(self, device):
         self.device = device
         self._test_conv_unary_base(dim=5)
 
+    @bf32_on_and_off()
     def test_linear_unary(self, device):
         self.device = device
 
@@ -373,6 +374,8 @@ def forward(self, x):
             dtypes.append(torch.bfloat16)
         if torch.ops.mkldnn._is_mkldnn_fp16_supported():
            dtypes.append(torch.float16)
+        if torch.backends.mkldnn.matmul.fp32_precision == "bf16":
+            dtypes.append(torch.float32)
         options = itertools.product(unary_list, [True, False], dtypes)
         for unary_fn, bias, dtype in options:
             metrics.reset()
@@ -383,7 +386,7 @@ def forward(self, x):
 
             def matcher_check_fn():
                 match_nodes = unary_list[unary_fn]
-                if self._check_unary_is_decomposed(unary_fn):
+                if dtype != torch.float32 and self._check_unary_is_decomposed(unary_fn):
                     # Has extra dtype conversion nodes for autocast.
                     match_nodes += 2
                 self.assertEqual(
@@ -395,9 +398,15 @@ def matcher_check_fn():
                 )
 
             self._test_common(mod, (v,), matcher_check_fn, check_autocast=dtype)
-            # only generated 1 kernel for "to"
-            self.assertEqual(metrics.generated_kernel_count, 2 if TEST_ACL else 1)
+            expected_kernel_count = 1
+            if TEST_ACL:
+                expected_kernel_count = 2
+            elif dtype == torch.float32:
+                expected_kernel_count = 0
+            # only generated 1 kernel for "to_dtype"
+            self.assertEqual(metrics.generated_kernel_count, expected_kernel_count)
 
+    @bf32_on_and_off()
     @unittest.skipIf(not TEST_MKL, "Test requires MKL")
     def test_linear_fp32(self, device):
         self.device = device
@@ -825,6 +834,7 @@ def test_conv2d_binary_broadcast_shapes_cpu(self):
     def test_conv3d_binary_broadcast_shapes_cpu(self):
         self._test_conv_binary_broadcast_shapes_base(dim=5)
 
+    @bf32_on_and_off()
     def test_linear_binary(self, device):
         self.device = device
 
@@ -846,6 +856,8 @@ def forward(self, x, y):
             dtypes.append(torch.bfloat16)
         if torch.ops.mkldnn._is_mkldnn_fp16_supported():
             dtypes.append(torch.float16)
+        if torch.backends.mkldnn.matmul.fp32_precision == "bf16":
+            dtypes.append(torch.float32)
         options = itertools.product(
             binary_list, [[2, 3, 10], [2, 10]], [True, False], dtypes
         )
@@ -882,7 +894,13 @@ def matcher_check_fn():
                 matcher_check_fn,
                 check_autocast=dtype,
             )
-            self.assertEqual(metrics.generated_kernel_count, 2 if TEST_ACL else 1)
+            expected_kernel_count = 1
+            if TEST_ACL:
+                expected_kernel_count = 2
+            elif dtype == torch.float32:
+                expected_kernel_count = 0
+            # only generated 1 kernel for "to_dtype"
+            self.assertEqual(metrics.generated_kernel_count, expected_kernel_count)
 
     def test_linear_binary_broadcast_shapes_cpu(self):
         class M(torch.nn.Module):
@@ -945,7 +963,13 @@ def matcher_check_fn():
                 matcher_check_fn,
                 check_autocast=dtype,
             )
-            self.assertEqual(metrics.generated_kernel_count, 2 if TEST_ACL else 1)
+            expected_kernel_count = 1
+            if TEST_ACL:
+                expected_kernel_count = 2
+            elif dtype == torch.float32:
+                expected_kernel_count = 0
+            # only generated 1 kernel for "to_dtype"
+            self.assertEqual(metrics.generated_kernel_count, expected_kernel_count)
 
     @skipIfNoDynamoSupport
     @skipIfNoONEDNN
@@ -978,6 +1002,7 @@ def matcher_check_fn():
 
         self._test_common(mod, (x1, x2), matcher_check_fn)
 
+    @bf32_on_and_off()
     def test_multi_linear_share_same_input(self, device):
         self.device = device
 
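The @bf32_on_and_off() decorator applied above is not part of this diff. Purely as an assumption about its intent, a plausible sketch is shown below: run the wrapped test once with the current fp32 precision and once with mkldnn matmul precision forced to "bf16", restoring the original setting afterwards. The implementation details are hypothetical, not the actual test utility.

import functools
import torch

def bf32_on_and_off():
    # Hypothetical stand-in for the test suite's bf32_on_and_off() helper;
    # the real implementation is not shown in this commit.
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            old = torch.backends.mkldnn.matmul.fp32_precision
            try:
                for precision in (old, "bf16"):
                    torch.backends.mkldnn.matmul.fp32_precision = precision
                    fn(*args, **kwargs)
            finally:
                torch.backends.mkldnn.matmul.fp32_precision = old
        return wrapper
    return decorator
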
torch/_inductor/fx_passes/mkldnn_fusion.py

Lines changed: 27 additions & 7 deletions
@@ -935,10 +935,20 @@ def is_linear_add_bias(match):
             bias_meta = add_node.args[1].meta.get("val")
             if weight_meta is None or bias_meta is None:
                 return False
-            assert weight_meta.dtype in (
-                torch.bfloat16,
-                torch.float16,
+
+            bf32_matmul_enabled = torch.backends.mkldnn.matmul.fp32_precision == "bf16"  # type: ignore[attr-defined]
+            use_bf16_for_fp32_weight = (
+                bf32_matmul_enabled and weight_meta.dtype == torch.float32
+            )
+            assert (
+                weight_meta.dtype
+                in (
+                    torch.bfloat16,
+                    torch.float16,
+                )
+                or use_bf16_for_fp32_weight
             )
+
             if bias_meta.dtype != weight_meta.dtype:
                 return False
             return (
@@ -1098,10 +1108,15 @@ def is_const_or_cat_by_const(weight):
             torch.bfloat16,
             torch.float16,
         )
+        bf32_matmul_enabled = torch.backends.mkldnn.matmul.fp32_precision == "bf16"  # type: ignore[attr-defined]
+        use_bf16_for_fp32_weight = (
+            bf32_matmul_enabled and weight_meta_value.dtype == torch.float32
+        )
+        compute_with_lp = is_lp_weight or use_bf16_for_fp32_weight
         # on x86, for fp32, mkl should be enabled and batch_size should not be a free symbol.
         # on aarch64, use mkldnn op for fp32 as well if acl is enabled
         if (
-            not is_lp_weight
+            not compute_with_lp
             and not mkldnn._is_mkldnn_acl_supported()
             and ((not torch._C.has_mkl) or has_free_symbols(batch_size))
         ):
@@ -1308,9 +1323,14 @@ def linear(match, *args, **kwargs):
                 torch.bfloat16,
                 torch.float16,
             )
+            bf32_matmul_enabled = torch.backends.mkldnn.matmul.fp32_precision == "bf16"  # type: ignore[attr-defined]
+            use_bf16_for_fp32_weight = (
+                bf32_matmul_enabled and weight_dtype == torch.float32
+            )
+            compute_with_lp = is_lp_weight or use_bf16_for_fp32_weight
             batch_size = input.meta.get("val").shape[0]
             if has_free_symbols(batch_size):
-                assert is_lp_weight or mkldnn._is_mkldnn_acl_supported(), (
+                assert compute_with_lp or mkldnn._is_mkldnn_acl_supported(), (
                     f"only bf16/fp16 weight prepacking supports dynamic shape inputs but got {weight_dtype}"
                 )
             # For bfloat16 dynamic shape path, using input size hint to pack weight for a better performance.
@@ -1328,7 +1348,7 @@ def linear(match, *args, **kwargs):
             packed_weight_op = (
                 mkldnn._reorder_linear_weight
                 if (
-                    is_lp_weight
+                    compute_with_lp
                     or mkldnn._is_mkldnn_acl_supported()
                     or V.aot_compilation
                 )
@@ -1340,7 +1360,7 @@ def linear(match, *args, **kwargs):
 
             packed_linear_inputs: tuple[Any, ...] = (input, packed_weight_node)
             if (
-                is_lp_weight
+                compute_with_lp
                 or mkldnn._is_mkldnn_acl_supported()
                 or V.aot_compilation
             ):
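
The fusion-pass change above repeats the same gating logic in three places. Condensed into a single predicate (the helper name below is illustrative, not code from the patch), the decision that replaces the old is_lp_weight-only check is:

import torch

def can_use_lp_weight_path(weight_dtype: torch.dtype) -> bool:
    # bf16/fp16 weights qualify as before; fp32 weights now also qualify
    # when mkldnn matmul fp32 precision is set to "bf16" (bf32).
    is_lp_weight = weight_dtype in (torch.bfloat16, torch.float16)
    bf32_matmul_enabled = torch.backends.mkldnn.matmul.fp32_precision == "bf16"
    return is_lp_weight or (bf32_matmul_enabled and weight_dtype == torch.float32)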

0 commit comments
