@@ -27,7 +27,7 @@
     flex_attention_supported_platform as supported_platform,
     instantiate_device_type_tests,
 )
-from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CUDA, HAS_GPU, HAS_XPU
+from torch.testing._internal.inductor_utils import HAS_CUDA, HAS_GPU
 from torch.utils._triton import has_triton
 
 
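Note: the dropped GPU_TYPE constant is replaced throughout by the `device` argument that `instantiate_device_type_tests` injects into each test method. A minimal sketch of that pattern, using a hypothetical ExampleTest class (not from this file):

import torch
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_utils import TestCase, run_tests


class ExampleTest(TestCase):
    # Device-type tests take `device` (a string such as "cpu" or "cuda:0")
    # as the first argument after `self`; the harness fills it in.
    def test_add(self, device):
        x = torch.ones(4, device=device)
        self.assertEqual((x + x).sum().item(), 8.0)


# Generates one concrete class per available device type, e.g.
# ExampleTestCPU.test_add_cpu and ExampleTestCUDA.test_add_cuda.
instantiate_device_type_tests(ExampleTest, globals())

if __name__ == "__main__":
    run_tests()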
@@ -722,19 +722,19 @@ def run_test_with_call_paged_attention(
     @expectedFailure  # tl.dot does not support embedding size less than 16
     @unittest.skipIf(SKIP_UT_ON_CPU, "Skip on CPU as not supported")
     @common_utils.parametrize("dtype", test_dtypes_fast)
-    def test_bw_decoding_fails(self, dtype):
+    def test_bw_decoding_fails(self, device, dtype):
         make_kv = functools.partial(
             torch.randn,
             (2, 2, 128, 4),
             dtype=dtype,
-            device=GPU_TYPE,
+            device=device,
             requires_grad=True,
         )
         make_q = functools.partial(
             torch.randn,
             (2, 2, 8, 4),
             dtype=dtype,
-            device=GPU_TYPE,
+            device=device,
             requires_grad=True,
         )
         q, k, v, backward_grad = make_q(), make_kv(), make_kv(), make_q()
@@ -1004,20 +1004,20 @@ def mask_mod(b, h, q, kv):
 
     @supported_platform
     @unittest.skipIf(SKIP_UT_ON_CPU, "Skip on CPU as not supported")
-    def test_non_divisible_multi_token_offset_mask_with_captured_buffer(self):
+    def test_non_divisible_multi_token_offset_mask_with_captured_buffer(self, device):
         KV_S = S - 3
         Q_S = 3
-        offset_kv = torch.randn(KV_S, device=GPU_TYPE, dtype=torch.bfloat16)
-        offset_q = torch.randn(Q_S, device=GPU_TYPE, dtype=torch.bfloat16)
-        offset_tensor = torch.tensor(S // 2 - 3, device=GPU_TYPE, dtype=torch.int32)
+        offset_kv = torch.randn(KV_S, device=device, dtype=torch.bfloat16)
+        offset_q = torch.randn(Q_S, device=device, dtype=torch.bfloat16)
+        offset_tensor = torch.tensor(S // 2 - 3, device=device, dtype=torch.int32)
 
         def score_mod(score, b, h, q, kv):
             return score + offset_kv[kv] + offset_q[q]
 
         def mask_mod(b, h, q, kv):
             return kv >= q + offset_tensor
 
-        block_mask = create_block_mask(mask_mod, B, 1, Q_S, KV_S, device=self.device)
+        block_mask = create_block_mask(mask_mod, B, 1, Q_S, KV_S, device=device)
         self.run_test(Q_S=Q_S, KV_S=KV_S, block_mask=block_mask, score_mod=score_mod)
 
     @supported_platform
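As context for the `create_block_mask` change above: the function takes the target device explicitly, so passing the injected `device` (rather than `self.device` or a hardcoded GPU_TYPE) keeps the BlockMask on the same device as the tensors the mask_mod captures. A standalone sketch of the API with an offset mask shaped like the test's (shapes and the offset are illustrative; device="cpu" assumes your build supports the eager CPU fallback, otherwise pass "cuda"):

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention


def offset_mask(b, h, q, kv):
    # Same predicate shape as in the test: kv index must be at least q + offset.
    return kv >= q + 4


# B=None / H=None broadcast the mask over batch and heads.
block_mask = create_block_mask(offset_mask, B=None, H=None, Q_LEN=8, KV_LEN=64, device="cpu")

q = torch.randn(1, 1, 8, 16)
k = torch.randn(1, 1, 64, 16)
v = torch.randn(1, 1, 64, 16)
out = flex_attention(q, k, v, block_mask=block_mask)  # shape (1, 1, 8, 16)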
@@ -1682,19 +1682,19 @@ def mask_mod(b, h, q, kv):
     @unittest.skipIf(SKIP_UT_ON_CPU, "Skip on CPU as not supported")
     @common_utils.parametrize("dtype", test_dtypes)
     @common_utils.parametrize("score_mod", [_identity, _causal])
-    def test_logsumexp_correctness(self, dtype, score_mod):
+    def test_logsumexp_correctness(self, device, dtype, score_mod):
         make_kv = functools.partial(
             torch.randn,
             (B, Hkv, S, D),
             dtype=dtype,
-            device=GPU_TYPE,
+            device=device,
             requires_grad=True,
         )
         make_q = functools.partial(
             torch.randn,
             (B, Hkv, Hq // Hkv, D),
             dtype=dtype,
-            device=GPU_TYPE,
+            device=device,
             requires_grad=True,
         )
         q, k, v = make_q(), make_kv(), make_kv()
@@ -1734,19 +1734,19 @@ def eager_sdpa_hop(q, k, v, score_mod):
 
     @supported_platform
     @unittest.skipIf(SKIP_UT_ON_CPU, "Skip on CPU as not supported")
-    def test_logsumexp_only_return(self):
+    def test_logsumexp_only_return(self, device):
         make_q = functools.partial(
             torch.randn,
             (B, Hkv, Hq // Hkv, D),
             dtype=torch.float32,
-            device=GPU_TYPE,
+            device=device,
             requires_grad=True,
         )
         make_kv = functools.partial(
             torch.randn,
             (B, Hkv, S, D),
             dtype=torch.float32,
-            device=GPU_TYPE,
+            device=device,
             requires_grad=True,
         )
 
@@ -1998,7 +1998,7 @@ def causal_mask(b, h, q, kv):
         self._check_equal(golden_outs, ref_outs, paged_out, fudge_factor, "Out")
 
 
-instantiate_device_type_tests(TestFlexDecoding, globals(), only_for=test_device)
+instantiate_device_type_tests(TestFlexDecoding, globals(), only_for=test_device, allow_xpu=True)
 
 if __name__ == "__main__":
     from torch._inductor.test_case import run_tests
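The instantiation change is the piece that actually enables Intel GPUs: `instantiate_device_type_tests` skips XPU by default, and `allow_xpu=True` opts the class in, while `only_for=test_device` still restricts generation to the device types listed in `test_device`. Annotated form of the call (comments are explanatory, not in the file):

instantiate_device_type_tests(
    TestFlexDecoding,
    globals(),
    only_for=test_device,  # generate classes only for these device types
    allow_xpu=True,        # XPU is opt-in; without this no TestFlexDecodingXPU is created
)

Generated tests carry a device suffix, so a single backend can be selected with unittest's -k filter, e.g. `python test/inductor/test_flex_decoding.py -k xpu`.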