add ut for dispatch sdpa kernel to sbgemv in CPUBlas.cpp · pytorch/pytorch@c52849f · GitHub
Commit c52849f

taoye9 authored and pytorchmergebot committed
add ut for dispatch sdpa kernel to sbgemv in CPUBlas.cpp
1 parent ba0a974 commit c52849f

2 files changed: 22 additions, 1 deletion

test/test_linalg.py

Lines changed: 21 additions & 0 deletions
@@ -7770,6 +7770,27 @@ def test_fp16_mv_transposed_first_argument_arm_cpu(self, device, m, k):
         finally:
             torch._C._set_cpu_allow_fp16_reduced_precision_reduction(prev)
 
+    @onlyCPU
+    @dtypes(torch.bfloat16)
+    @parametrize("m", [32, 35, 36, 40, 64, 128])
+    @parametrize("k", [32, 35, 36, 40, 64, 128])
+    # NOTE: This is intended to cover sbgemv_ testcase in CPUBlas.cpp.
+    def test_lowprecision_gemv_cpu(self, device, dtype, m, k):
+        torch.manual_seed(1)
+        a = torch.rand((m, k), dtype=dtype, device=device)
+        b = torch.rand((k, 1), dtype=dtype, device=device)
+
+        ref = torch.mm(a.to(torch.float32), b.to(torch.float32))
+        res = torch.mm(a, b).to(torch.float32)
+        torch.testing.assert_close(res, ref, atol=1e-2, rtol=1e-2)
+
+        a = torch.rand((k, m), dtype=dtype, device=device)
+        b = torch.rand((k, 1), dtype=dtype, device=device)
+
+        ref = torch.mm(a.t().to(torch.float32), b.to(torch.float32))
+        res = torch.mm(a.t(), b).to(torch.float32)
+        torch.testing.assert_close(res, ref, atol=1e-2, rtol=1e-2)
+
     @slowTest
     @onlyNativeDeviceTypes
     # bfloat16 doesn't have sufficient precision to pass this test
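The new test drives torch.mm with a (k, 1) right-hand side, i.e. a matrix-vector product, once with a contiguous first operand and once with a transposed one, so both layouts of the bf16 gemv path are covered. Below is a minimal standalone sketch of the same check; whether it actually reaches sbgemv_ in CPUBlas.cpp depends on the CPU and the BLAS the build links against (that routing is my assumption and is not observable from Python), but the fp32-reference comparison is valid either way.

import torch

m, k = 64, 128
dtype = torch.bfloat16

# (m, k) @ (k, 1): a bf16 matrix-vector product on CPU.
a = torch.rand((m, k), dtype=dtype)
b = torch.rand((k, 1), dtype=dtype)
ref = torch.mm(a.to(torch.float32), b.to(torch.float32))  # fp32 reference
res = torch.mm(a, b).to(torch.float32)
torch.testing.assert_close(res, ref, atol=1e-2, rtol=1e-2)

# Transposed first operand: exercises the transposed-gemv layout.
a = torch.rand((k, m), dtype=dtype)
ref = torch.mm(a.t().to(torch.float32), b.to(torch.float32))
res = torch.mm(a.t(), b).to(torch.float32)
torch.testing.assert_close(res, ref, atol=1e-2, rtol=1e-2)

To select just the new case locally, a pytest-style invocation such as pytest test/test_linalg.py -k test_lowprecision_gemv_cpu should work.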

test/test_transformers.py

Lines changed: 1 addition & 1 deletion
@@ -2009,7 +2009,7 @@ def test_fused_sdp_choice_cpu(self, device, type: str, dropout: float, dtype: to
     @parametrize("fused_kernel", [SDPBackend.FLASH_ATTENTION])
     @parametrize("dtype", [torch.float64, torch.float32, torch.bfloat16, torch.float16])
     @parametrize("batch_size", [2, 12])
-    @parametrize("q_seq_len", [11, 514, 1030])
+    @parametrize("q_seq_len", [1, 11, 514, 1030])
     @parametrize("kv_seq_len", [17, 514])
     @parametrize("n_head", [1, 3])
     @parametrize("head_dim", [8])
