@@ -7770,6 +7770,27 @@ def test_fp16_mv_transposed_first_argument_arm_cpu(self, device, m, k):
7770
7770
finally :
7771
7771
torch ._C ._set_cpu_allow_fp16_reduced_precision_reduction (prev )
7772
7772
7773
@onlyCPU
@dtypes(torch.bfloat16)
@parametrize("m", [32, 35, 36, 40, 64, 128])
@parametrize("k", [32, 35, 36, 40, 64, 128])
# NOTE: This is intended to cover sbgemv_ testcase in CPUBlas.cpp.
def test_lowprecision_gemv_cpu(self, device, dtype, m, k):
    """Check low-precision matrix-by-column products against float32.

    Two cases run back to back under one fixed seed:

    1. an ``(m, k)`` matrix multiplied by a ``(k, 1)`` column;
    2. a ``(k, m)`` matrix passed to ``torch.mm`` as its transpose and
       multiplied by a ``(k, 1)`` column.

    In each case the product computed in ``dtype`` is upcast to float32
    and compared with the product of the float32-upcast operands, using
    ``atol = rtol = 1e-2``.
    """
    torch.manual_seed(1)

    # The matrix is always drawn before the vector so the random stream
    # is consumed in the same order for both layouts.
    for transposed in (False, True):
        stored_shape = (k, m) if transposed else (m, k)
        mat = torch.rand(stored_shape, dtype=dtype, device=device)
        if transposed:
            # Hand torch.mm a transposed view rather than a fresh tensor.
            mat = mat.t()
        vec = torch.rand((k, 1), dtype=dtype, device=device)

        expected = torch.mm(mat.to(torch.float32), vec.to(torch.float32))
        actual = torch.mm(mat, vec).to(torch.float32)
        torch.testing.assert_close(actual, expected, atol=1e-2, rtol=1e-2)
7793
+
7773
7794
@slowTest
7774
7795
@onlyNativeDeviceTypes
7775
7796
# bfloat16 doesn't have sufficient precision to pass this test
0 commit comments