Merge test_transformers.py · pytorch/pytorch@06da205 · GitHub

Commit 06da205

Merge test_transformers.py
1 parent 711b683 commit 06da205

File tree

3 files changed: +171 −238 lines changed


.lintrunner.toml

Lines changed: 0 additions & 1 deletion
@@ -1251,7 +1251,6 @@ exclude_patterns = [
     'test/test_testing.py',
     'test/test_torch.py',
     'test/test_transformers.py',
-    'test/xpu/test_transformers.py',
     'test/test_type_promotion.py',
     'test/test_unary_ufuncs.py',
     'test/test_vulkan.py',

test/test_transformers.py

Lines changed: 171 additions & 0 deletions
@@ -3841,6 +3841,176 @@ def rand_nt(sequence_list, num_heads, head_dim):
             }
         )
 
+class TestSDPAXpuOnly(NNTestCase):
+    """Used to test XPU-only functionality of scaled_dot_product_attention.
+    Mostly migrated from TestSDPACudaOnly in test/test_transformers.py.
+
+    Note that SDPBackend.OVERRIDEABLE is not managed by sdpa_kernel, so the math
+    reference has to be called explicitly via torch.ops.aten._scaled_dot_product_attention_math.
+    """
+
+    @parametrize("type", ["dense"])
+    @parametrize("dropout", [0.0, 0.7])
+    @parametrize("dtype", [torch.float64, torch.float32, torch.bfloat16, torch.half])
+    @skipIfTorchDynamo()
+    def test_fused_sdp_choice_xpu(self, device, type: str, dropout: float, dtype: torch.dtype):
+        # Migrated from test_fused_sdp_choice_cpu
+        make_tensor = partial(rand_sdpa_tensor, type=type, device=device, dtype=dtype)
+        size = SdpaShape(2, 8, 128, 64)
+        q, k, v = make_tensor(size), make_tensor(size), make_tensor(size)
+        if dropout > 0.0 or dtype not in [torch.float32, torch.bfloat16, torch.float16]:
+            assert torch._fused_sdp_choice(q, k, v, dropout_p=dropout) == SDPBackend.MATH.value
+        else:
+            assert torch._fused_sdp_choice(q, k, v, dropout_p=dropout) == SDPBackend.OVERRIDEABLE.value
+
+    def test_fused_attention_different_dk_dv(self, device):
+        dtype = torch.bfloat16
+        make_tensor = partial(torch.rand, device=device, dtype=dtype, requires_grad=True)
+        batch, num_heads, head_dim_k, head_dim_v = 32, 16, 128, 64
+        q_shape = SdpaShape(batch, num_heads, 1, head_dim_k)
+        k_shape = SdpaShape(batch, num_heads, 2, head_dim_k)
+        v_shape = SdpaShape(batch, num_heads, 2, head_dim_v)
+        query, key, value = make_tensor(q_shape), make_tensor(k_shape), make_tensor(v_shape)
+
+        # Test that we do not dispatch to onednn for an unsupported case
+        actual = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False)
+
+        math_ref = torch.ops.aten._scaled_dot_product_attention_math(
+            query.float(), key.float(), value.float(), attn_mask=None, dropout_p=0.0, is_causal=False)[0]
+
+        self.assertEqual(actual.contiguous(), math_ref.contiguous().to(dtype), atol=1e-3, rtol=1e-2)
+
+    def test_onednn_attention_fail_d256(self, device):
+        # Test that onednn graph attention dispatching correctly bails out on d > 256
+        b, h = 1, 2
+        s_q, s_kv = 128, 128
+        d_qk, d_v = 512, 512
+
+        q = torch.randn(b, h, s_q, d_qk, device=device, dtype=torch.bfloat16)
+        k = torch.randn(b, h, s_kv, d_qk, device=device, dtype=torch.bfloat16)
+        v = torch.randn(b, h, s_kv, d_v, device=device, dtype=torch.bfloat16)
+
+        with sdpa_kernel(backends=[SDPBackend.OVERRIDEABLE]):
+            with self.assertRaisesRegex(RuntimeError, "No available kernel."):
+                _ = F.scaled_dot_product_attention(q, k, v)
+
+    @parametrize("type", ["dense"])
+    @parametrize("is_contiguous", [True, False])
+    def test_scaled_dot_product_attention_fused_kernels_packed(self, device, type: str, is_contiguous: bool):
+        make_tensor = partial(rand_sdpa_tensor, type=type, device=device, dtype=torch.float16, packed=True)
+
+        batch_size, seq_len, num_heads, head_dim = 32, 64, 16, 64
+        shape = SdpaShape(batch_size, num_heads, seq_len, head_dim)
+
+        # Test Packed
+        qkv = make_tensor(shape)
+        query, key, value = qkv.chunk(3, dim=-1)
+
+        query = query.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
+
+        if is_contiguous:
+            query = query.contiguous()
+            key = key.contiguous()
+            value = value.contiguous()
+
+        with sdpa_kernel(backends=[SDPBackend.OVERRIDEABLE]):
+            actual = torch.nn.functional.scaled_dot_product_attention(
+                query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False)
+        math_ref = torch.ops.aten._scaled_dot_product_attention_math(
+            query.contiguous(), key.contiguous(), value.contiguous(), attn_mask=None, dropout_p=0.0, is_causal=False)[0]
+
+        self.assertEqual(actual.contiguous(), math_ref.contiguous(), atol=2e-3, rtol=1e-2)
+
+    @parametrize("fused_kernel", [SDPBackend.MATH, SDPBackend.OVERRIDEABLE])
+    @parametrize("dtype", [torch.half, torch.bfloat16, torch.float32])
+    @parametrize("batch_size,n_head,q_size,kv_size,head_dim", [
+        (2, 5, 9216, 9216, 64),
+        (2, 5, 9216, 77, 64),
+        (2, 10, 2304, 2304, 64),
+        (2, 10, 2304, 77, 64),
+        (2, 20, 576, 576, 64),
+        (2, 20, 576, 77, 64),
+        (2, 20, 144, 144, 64),
+        (2, 20, 144, 77, 64),
+        (1, 32, 1, 32, 128),
+        (4, 32, 1, 32, 128),
+        (1, 32, 32, 32, 128),
+        (4, 32, 32, 32, 128),
+        (1, 32, 2016, 2016, 128),
+        (4, 32, 2016, 2016, 128),
+    ])
+    @parametrize("mask_type", ["float", "causal"])
+    @parametrize("train", [False])
+    def test_scaled_dot_product_fused_attention_mask_vs_math(
+        self,
+        device,
+        fused_kernel,
+        dtype,
+        batch_size,
+        q_size,
+        kv_size,
+        n_head,
+        head_dim,
+        mask_type,
+        train,
+    ):
+        # Migrated from TestSDPACpuOnly
+        tol = Tolerances(1e-5, 5e-6)
+        if dtype is torch.bfloat16:
+            tol = Tolerances(5e-2, 5e-2)
+        if dtype is torch.float16:
+            tol = Tolerances(1e-2, 1e-2)
+        mask_shape = [batch_size, 1, 1, kv_size]
+        make_tensor = partial(rand_sdpa_tensor, type="dense", device=device, dtype=dtype, requires_grad=False)
+        q_shape = SdpaShape(batch_size, n_head, q_size, head_dim)
+        kv_shape = SdpaShape(batch_size, n_head, kv_size, head_dim)
+        q = make_tensor(q_shape)
+        k = make_tensor(kv_shape)
+        v = make_tensor(kv_shape)
+        q2, k2, v2 = q.clone(), k.clone(), v.clone()
+
+        if train:
+            q.requires_grad_(True)
+            k.requires_grad_(True)
+            v.requires_grad_(True)
+            q2.requires_grad_(True)
+            k2.requires_grad_(True)
+            v2.requires_grad_(True)
+
+        # (B, nh, T, hs)
+        q = q.view(batch_size, q_size, n_head, head_dim).transpose(1, 2)
+        k = k.view(batch_size, kv_size, n_head, head_dim).transpose(1, 2)
+        v = v.view(batch_size, kv_size, n_head, head_dim).transpose(1, 2)
+        attn_mask = None
+        is_causal = False
+        if mask_type == "bool":
+            attn_mask = torch.randint(0, 2, size=mask_shape, dtype=torch.bool, device=device)
+        elif mask_type == "float":
+            attn_mask = torch.randn(mask_shape, dtype=dtype, device=device)
+        elif mask_type == "causal":
+            is_causal = True
+
+        q2, k2, v2 = q2.float(), k2.float(), v2.float()
+        q2 = q2.view(batch_size, q_size, n_head, head_dim).transpose(1, 2)
+        k2 = k2.view(batch_size, kv_size, n_head, head_dim).transpose(1, 2)
+        v2 = v2.view(batch_size, kv_size, n_head, head_dim).transpose(1, 2)
+        attn_mask2 = attn_mask.float() if attn_mask is not None else None
+
+        if fused_kernel == SDPBackend.MATH:
+            actual = torch.ops.aten._scaled_dot_product_attention_math(
+                q, k, v, attn_mask=attn_mask, dropout_p=0.0, is_causal=is_causal)[0]
+        elif fused_kernel == SDPBackend.OVERRIDEABLE:
+            actual = torch.ops.aten._scaled_dot_product_fused_attention_overrideable(
+                q, k, v, attn_bias=attn_mask, dropout_p=0.0, is_causal=is_causal)[0]
+
+        math_ref = torch.ops.aten._scaled_dot_product_attention_math(
+            q2, k2, v2, attn_mask=attn_mask2, dropout_p=0.0, is_causal=is_causal)[0]
+
+        self.assertEqual(actual.float(), math_ref, atol=tol.atol, rtol=tol.rtol)
+
 
 class TestAttnBias(NNTestCase):
 
@@ -4080,6 +4250,7 @@ def test_scaled_dot_product_fused_attention_overrideable_backward(self):
 instantiate_device_type_tests(TestSDPACudaOnly, globals(), only_for=("cuda"))
 instantiate_device_type_tests(TestSDPACpuOnly, globals(), only_for=("cpu"))
 instantiate_device_type_tests(TestAttnBias, globals(), only_for=device_types)
+instantiate_device_type_tests(TestSDPAXpuOnly, globals(), only_for="xpu", allow_xpu=True)
 
 if __name__ == '__main__':
     run_tests()
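
For context on the pattern the new TestSDPAXpuOnly docstring describes (running the overrideable path through the public API, then calling the math reference explicitly via torch.ops.aten._scaled_dot_product_attention_math rather than selecting it through sdpa_kernel), here is a minimal standalone sketch. It assumes an XPU-enabled PyTorch build where SDPBackend.OVERRIDEABLE is available; the shapes, dtype, and tolerances are illustrative and not taken from this commit.

import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

# Illustrative shapes (batch, heads, seq_len, head_dim); assumes an "xpu" device is present.
q = torch.randn(2, 8, 128, 64, device="xpu", dtype=torch.float16)
k = torch.randn(2, 8, 128, 64, device="xpu", dtype=torch.float16)
v = torch.randn(2, 8, 128, 64, device="xpu", dtype=torch.float16)

# Route the public API through the overrideable (oneDNN-backed) implementation.
with sdpa_kernel(backends=[SDPBackend.OVERRIDEABLE]):
    actual = F.scaled_dot_product_attention(q, k, v)

# Call the math reference explicitly, as the docstring describes, and compare in float32.
math_ref = torch.ops.aten._scaled_dot_product_attention_math(
    q.float(), k.float(), v.float())[0]

torch.testing.assert_close(actual.float(), math_ref, atol=2e-3, rtol=1e-2)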
