Merge test_transformers.py · pytorch/pytorch@6e0b7d2 · GitHub
[go: up one dir, main page]

Skip to content

Commit 6e0b7d2

Browse files
DDEle and pytorchmergebot
authored and committed
Merge test_transformers.py
1 parent 74e0cd1 commit 6e0b7d2

File tree

3 files changed

+171
-238
lines changed

3 files changed

+171
-238
lines changed

.lintrunner.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1251,7 +1251,6 @@ exclude_patterns = [
12511251
'test/test_testing.py',
12521252
'test/test_torch.py',
12531253
'test/test_transformers.py',
1254-
'test/xpu/test_transformers.py',
12551254
'test/test_type_promotion.py',
12561255
'test/test_unary_ufuncs.py',
12571256
'test/test_vulkan.py',

test/test_transformers.py

Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3838,6 +3838,176 @@ def rand_nt(sequence_list, num_heads, head_dim):
38383838
}
38393839
)
38403840

3841+
class TestSDPAXpuOnly(NNTestCase):
    """XPU-only tests for ``scaled_dot_product_attention``.

    Mostly migrated from ``TestSDPACudaOnly`` in test/test_transformers.py.

    Note: ``SDPBackend.OVERRIDEABLE`` is not managed by ``sdpa_kernel``, so the
    math reference has to be called explicitly via
    ``torch.ops.aten._scaled_dot_product_attention_math``.
    """

    # NOTE: the ``type`` parameter shadows the builtin, but the name has to match
    # the ``@parametrize`` key and the ``rand_sdpa_tensor`` keyword argument.
    @parametrize("type", ["dense"])
    @parametrize("dropout", [0.0, 0.7])
    @parametrize("dtype", [torch.float64, torch.float32, torch.bfloat16, torch.half])
    @skipIfTorchDynamo()
    def test_fused_sdp_choice_xpu(self, device, type: str, dropout: float, dtype: torch.dtype):
        """Check which backend ``torch._fused_sdp_choice`` reports for dense inputs.

        MATH is expected when dropout is enabled or the dtype is not one of
        float32 / bfloat16 / float16; OVERRIDEABLE is expected otherwise.
        """
        # Migrated from test_fused_sdp_choice_cpu
        make_tensor = partial(rand_sdpa_tensor, type=type, device=device, dtype=dtype)
        size = SdpaShape(2, 8, 128, 64)
        q, k, v = make_tensor(size), make_tensor(size), make_tensor(size)

        fused_dtypes = (torch.float32, torch.bfloat16, torch.float16)
        if dropout > 0.0 or dtype not in fused_dtypes:
            expected_backend = SDPBackend.MATH
        else:
            expected_backend = SDPBackend.OVERRIDEABLE

        # assertEqual (rather than a bare ``assert``) still runs under ``python -O``
        # and matches the assertion style of the other tests in this class.
        self.assertEqual(
            torch._fused_sdp_choice(q, k, v, dropout_p=dropout),
            expected_backend.value,
        )

    def test_fused_attention_different_dk_dv(self, device):
        """Compare SDPA against the math reference when K and V head dims differ."""
        dtype = torch.bfloat16
        make_tensor = partial(torch.rand, device=device, dtype=dtype, requires_grad=True)
        batch, num_heads, head_dim_k, head_dim_v = 32, 16, 128, 64
        q_shape = SdpaShape(batch, num_heads, 1, head_dim_k)
        k_shape = SdpaShape(batch, num_heads, 2, head_dim_k)
        v_shape = SdpaShape(batch, num_heads, 2, head_dim_v)
        query, key, value = make_tensor(q_shape), make_tensor(k_shape), make_tensor(v_shape)

        # test that we do not dispatch to onednn for an unsupported case
        actual = F.scaled_dot_product_attention(
            query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False)

        # The reference is computed in float32 and cast back to the test dtype.
        math_ref = torch.ops.aten._scaled_dot_product_attention_math(
            query.float(), key.float(), value.float(), attn_mask=None, dropout_p=0.0, is_causal=False)[0]

        self.assertEqual(actual.contiguous(), math_ref.contiguous().to(dtype), atol=1e-3, rtol=1e-2)

    def test_onednn_attention_fail_d256(self, device):
        """Expect a ``RuntimeError`` when OVERRIDEABLE is forced with head dim 512."""
        # Test that onednn graph attention dispatching correctly bails out on d > 256
        b, h = 1, 2
        s_q, s_kv = 128, 128
        d_qk, d_v = 512, 512

        q = torch.randn(b, h, s_q, d_qk, device=device, dtype=torch.bfloat16)
        k = torch.randn(b, h, s_kv, d_qk, device=device, dtype=torch.bfloat16)
        v = torch.randn(b, h, s_kv, d_v, device=device, dtype=torch.bfloat16)

        with sdpa_kernel(backends=[SDPBackend.OVERRIDEABLE]):
            with self.assertRaisesRegex(RuntimeError, "No available kernel."):
                _ = F.scaled_dot_product_attention(q, k, v)

    @parametrize("type", ["dense"])
    @parametrize("is_contiguous", [True, False])
    def test_scaled_dot_product_attention_fused_kernels_packed(self, device, type: str, is_contiguous: bool):
        """Compare OVERRIDEABLE against the math reference on q/k/v chunked from one packed tensor.

        Runs with both the transposed (non-contiguous) views and contiguous copies.
        """
        make_tensor = partial(rand_sdpa_tensor, type=type, device=device, dtype=torch.float16, packed=True)

        batch_size, seq_len, num_heads, head_dim = 32, 64, 16, 64
        shape = SdpaShape(batch_size, num_heads, seq_len, head_dim)

        # Test Packed
        qkv = make_tensor(shape)
        query, key, value = qkv.chunk(3, dim=-1)

        query = query.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)

        if is_contiguous:
            query = query.contiguous()
            key = key.contiguous()
            value = value.contiguous()

        with sdpa_kernel(backends=[SDPBackend.OVERRIDEABLE]):
            actual = torch.nn.functional.scaled_dot_product_attention(
                query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False)
        # The reference always receives contiguous inputs, whatever layout the fused call saw.
        math_ref = torch.ops.aten._scaled_dot_product_attention_math(
            query.contiguous(), key.contiguous(), value.contiguous(), attn_mask=None, dropout_p=0.0, is_causal=False)[0]

        self.assertEqual(actual.contiguous(), math_ref.contiguous(), atol=2e-3, rtol=1e-2)

    @parametrize("fused_kernel", [SDPBackend.MATH, SDPBackend.OVERRIDEABLE])
    @parametrize("dtype", [torch.half, torch.bfloat16, torch.float32])
    @parametrize("batch_size,n_head,q_size,kv_size,head_dim", [
        (2, 5, 9216, 9216, 64),
        (2, 5, 9216, 77, 64),
        (2, 10, 2304, 2304, 64),
        (2, 10, 2304, 77, 64),
        (2, 20, 576, 576, 64),
        (2, 20, 576, 77, 64),
        (2, 20, 144, 144, 64),
        (2, 20, 144, 77, 64),
        (1, 32, 1, 32, 128),
        (4, 32, 1, 32, 128),
        (1, 32, 32, 32, 128),
        (4, 32, 32, 32, 128),
        (1, 32, 2016, 2016, 128),
        (4, 32, 2016, 2016, 128),
    ])
    @parametrize("mask_type", ["float", "causal"])
    @parametrize("train", [False])
    def test_scaled_dot_product_fused_attention_mask_vs_math(
        self,
        device,
        fused_kernel,
        dtype,
        batch_size,
        q_size,
        kv_size,
        n_head,
        head_dim,
        mask_type,
        train,
    ):
        """Compare a masked/causal SDPA kernel against a float32 math reference.

        ``fused_kernel`` selects the op under test (the math op or the
        overrideable fused op, both called directly through ``torch.ops.aten``).
        The reference is always the math op evaluated on float32 copies of the
        same inputs.

        Raises:
            ValueError: if ``fused_kernel`` is neither ``SDPBackend.MATH`` nor
                ``SDPBackend.OVERRIDEABLE``.
        """
        # Migrated from TestSDPACpuOnly
        if dtype is torch.bfloat16:
            tol = Tolerances(5e-2, 5e-2)
        elif dtype is torch.float16:
            tol = Tolerances(1e-2, 1e-2)
        else:
            tol = Tolerances(1e-5, 5e-6)

        mask_shape = [batch_size, 1, 1, kv_size]
        make_tensor = partial(rand_sdpa_tensor, type="dense", device=device, dtype=dtype, requires_grad=False)
        q_shape = SdpaShape(batch_size, n_head, q_size, head_dim)
        kv_shape = SdpaShape(batch_size, n_head, kv_size, head_dim)
        q = make_tensor(q_shape)
        k = make_tensor(kv_shape)
        v = make_tensor(kv_shape)
        # Independent copies that feed the float32 reference path.
        q2, k2, v2 = q.clone(), k.clone(), v.clone()

        # Only train=False is parametrized at the moment, so this branch is currently inactive.
        if train:
            q.requires_grad_(True)
            k.requires_grad_(True)
            v.requires_grad_(True)
            q2.requires_grad_(True)
            k2.requires_grad_(True)
            v2.requires_grad_(True)

        # (B, nh, T, hs)
        q = q.view(batch_size, q_size, n_head, head_dim).transpose(1, 2)
        k = k.view(batch_size, kv_size, n_head, head_dim).transpose(1, 2)
        v = v.view(batch_size, kv_size, n_head, head_dim).transpose(1, 2)

        attn_mask = None
        is_causal = False
        # "bool" is handled here but is not in the current mask_type parametrization.
        if mask_type == "bool":
            attn_mask = torch.randint(0, 2, size=mask_shape, dtype=torch.bool, device=device)
        elif mask_type == "float":
            attn_mask = torch.randn(mask_shape, dtype=dtype, device=device)
        elif mask_type == "causal":
            is_causal = True

        q2, k2, v2 = q2.float(), k2.float(), v2.float()
        q2 = q2.view(batch_size, q_size, n_head, head_dim).transpose(1, 2)
        k2 = k2.view(batch_size, kv_size, n_head, head_dim).transpose(1, 2)
        v2 = v2.view(batch_size, kv_size, n_head, head_dim).transpose(1, 2)
        attn_mask2 = attn_mask.float() if attn_mask is not None else None

        if fused_kernel == SDPBackend.MATH:
            actual = torch.ops.aten._scaled_dot_product_attention_math(
                q, k, v, attn_mask=attn_mask, dropout_p=0.0, is_causal=is_causal)[0]
        elif fused_kernel == SDPBackend.OVERRIDEABLE:
            actual = torch.ops.aten._scaled_dot_product_fused_attention_overrideable(
                q, k, v, attn_bias=attn_mask, dropout_p=0.0, is_causal=is_causal)[0]
        else:
            # Fail loudly instead of hitting an unbound ``actual`` below.
            raise ValueError(f"Unsupported fused_kernel for this test: {fused_kernel}")

        math_ref = torch.ops.aten._scaled_dot_product_attention_math(
            q2, k2, v2, attn_mask=attn_mask2, dropout_p=0.0, is_causal=is_causal)[0]

        self.assertEqual(actual.float(), math_ref, atol=tol.atol, rtol=tol.rtol)
4010+
38414011

38424012
class TestAttnBias(NNTestCase):
38434013

@@ -4014,6 +4184,7 @@ def test_is_causal_and_mask_fails(self, device):
40144184
# ``only_for`` is passed as a real tuple everywhere: ``("cuda")`` without a
# trailing comma is just the string "cuda", not a one-element tuple.
instantiate_device_type_tests(TestSDPACudaOnly, globals(), only_for=("cuda",))
instantiate_device_type_tests(TestSDPACpuOnly, globals(), only_for=("cpu",))
instantiate_device_type_tests(TestAttnBias, globals(), only_for=device_types)
instantiate_device_type_tests(TestSDPAXpuOnly, globals(), only_for=("xpu",), allow_xpu=True)

if __name__ == '__main__':
    run_tests()

0 commit comments

Comments
 (0)
0