@@ -323,6 +323,7 @@ def test_conv2d_unary_cpu(self):
     def test_conv3d_unary_cpu(self):
         self._test_conv_unary_cpu_base(dim=5)

+    @bf32_on_and_off()
     def test_linear_unary(self):
         class M(torch.nn.Module):
             def __init__(
@@ -351,6 +352,8 @@ def forward(self, x):
             dtypes.append(torch.bfloat16)
         if torch.ops.mkldnn._is_mkldnn_fp16_supported():
             dtypes.append(torch.float16)
+        if torch.backends.mkldnn.matmul.fp32_precision == "bf16":
+            dtypes.append(torch.float32)
         options = itertools.product(unary_list, [True, False], dtypes)
         for unary_fn, bias, dtype in options:
             metrics.reset()
@@ -361,7 +364,7 @@ def forward(self, x):

             def matcher_check_fn():
                 match_nodes = unary_list[unary_fn]
-                if self._check_unary_is_decomposed(unary_fn):
+                if dtype != torch.float32 and self._check_unary_is_decomposed(unary_fn):
                     # Has extra dtype conversion nodes for autocast.
                     match_nodes += 2
                 self.assertEqual(
@@ -373,9 +376,15 @@ def matcher_check_fn():
                 )

             self._test_common(mod, (v,), matcher_check_fn, check_autocast=dtype)
-            # only generated 1 kernel for "to"
-            self.assertEqual(metrics.generated_kernel_count, 2 if TEST_ACL else 1)
+            expected_kernel_count = 1
+            if TEST_ACL:
+                expected_kernel_count = 2
+            elif dtype == torch.float32:
+                expected_kernel_count = 0
+            # only generated 1 kernel for "to_dtype"
+            self.assertEqual(metrics.generated_kernel_count, expected_kernel_count)

+    @bf32_on_and_off()
     @unittest.skipIf(not TEST_MKL, "Test requires MKL")
     def test_linear_fp32(self):
         class M(torch.nn.Module):
@@ -793,6 +802,7 @@ def test_conv2d_binary_broadcast_shapes_cpu(self):
     def test_conv3d_binary_broadcast_shapes_cpu(self):
         self._test_conv_binary_broadcast_shapes_base(dim=5)

+    @bf32_on_and_off()
     def test_linear_binary(self):
         class M(torch.nn.Module):
             def __init__(self, binary_fn, in_channels, out_channels, bias, **kwargs):
@@ -812,6 +822,8 @@ def forward(self, x, y):
             dtypes.append(torch.bfloat16)
         if torch.ops.mkldnn._is_mkldnn_fp16_supported():
             dtypes.append(torch.float16)
+        if torch.backends.mkldnn.matmul.fp32_precision == "bf16":
+            dtypes.append(torch.float32)
         options = itertools.product(
             binary_list, [[2, 3, 10], [2, 10]], [True, False], dtypes
         )
@@ -848,7 +860,13 @@ def matcher_check_fn():
                 matcher_check_fn,
                 check_autocast=dtype,
             )
-            self.assertEqual(metrics.generated_kernel_count, 2 if TEST_ACL else 1)
+            expected_kernel_count = 1
+            if TEST_ACL:
+                expected_kernel_count = 2
+            elif dtype == torch.float32:
+                expected_kernel_count = 0
+            # only generated 1 kernel for "to_dtype"
+            self.assertEqual(metrics.generated_kernel_count, expected_kernel_count)

     def test_linear_binary_broadcast_shapes_cpu(self):
         class M(torch.nn.Module):
@@ -911,7 +929,13 @@ def matcher_check_fn():
                 matcher_check_fn,
                 check_autocast=dtype,
             )
-            self.assertEqual(metrics.generated_kernel_count, 2 if TEST_ACL else 1)
+            expected_kernel_count = 1
+            if TEST_ACL:
+                expected_kernel_count = 2
+            elif dtype == torch.float32:
+                expected_kernel_count = 0
+            # only generated 1 kernel for "to_dtype"
+            self.assertEqual(metrics.generated_kernel_count, expected_kernel_count)

     @skipIfNoDynamoSupport
     @skipIfNoONEDNN
@@ -944,6 +968,7 @@ def matcher_check_fn():

         self._test_common(mod, (x1, x2), matcher_check_fn)

+    @bf32_on_and_off()
     def test_multi_linear_share_same_input(self):
         # llama pattern.
         class M(torch.nn.Module):
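Note on the changes above (not part of the diff itself): the new `@bf32_on_and_off()` decorator presumably runs each decorated linear test both with and without oneDNN's bf16-for-fp32 ("bf32") matmul mode, and the dtype lists now also include `torch.float32` whenever `torch.backends.mkldnn.matmul.fp32_precision == "bf16"`, so fp32 inputs exercise that path. Below is a minimal sketch of what such a decorator could look like, assuming it only toggles that backend knob around the test body; the real helper ships with PyTorch's test utilities and its exact behavior may differ.

# Sketch only: a hypothetical bf32_on_and_off-style decorator. It assumes
# torch.backends.mkldnn.matmul.fp32_precision is the knob the tests read,
# as shown in the diff; the actual decorator used by the suite may differ.
import functools

import torch


def bf32_on_and_off_sketch():
    def decorator(test_fn):
        @functools.wraps(test_fn)
        def wrapper(self, *args, **kwargs):
            # Remember the current fp32 matmul precision so it can be restored.
            old = torch.backends.mkldnn.matmul.fp32_precision
            try:
                # Run the test once with bf32 enabled and once with the
                # previous setting, mirroring the "on and off" idea.
                for setting in ("bf16", old):
                    torch.backends.mkldnn.matmul.fp32_precision = setting
                    test_fn(self, *args, **kwargs)
            finally:
                torch.backends.mkldnn.matmul.fp32_precision = old
        return wrapper
    return decorator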