remove self.device & GPU_TYPE · pytorch/pytorch@f79f883 · GitHub

Commit f79f883

remove self.device & GPU_TYPE
1 parent fd14da6 commit f79f883


test/inductor/test_flex_decoding.py

Lines changed: 16 additions & 16 deletions
@@ -27,7 +27,7 @@
     flex_attention_supported_platform as supported_platform,
     instantiate_device_type_tests,
 )
-from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CUDA, HAS_GPU, HAS_XPU
+from torch.testing._internal.inductor_utils import HAS_CUDA, HAS_GPU
 from torch.utils._triton import has_triton
 
 
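For context on what this hunk drops: GPU_TYPE in inductor_utils is a single device string resolved once at import time, roughly along the lines of the sketch below (the _guess_gpu_type helper is a hypothetical name for illustration, not the actual inductor_utils source). A process-wide constant like this cannot vary per test, which is what the injected `device` argument used throughout the rest of the diff fixes.

```python
# Rough sketch of what a GPU_TYPE-style constant amounts to; _guess_gpu_type
# is a hypothetical helper, not the real inductor_utils implementation.
import torch


def _guess_gpu_type() -> str:
    if torch.cuda.is_available():
        return "cuda"
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return "xpu"
    return "cuda"  # fallback when no GPU is visible


GPU_TYPE = _guess_gpu_type()  # fixed for the whole process, unlike `device`
```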
@@ -722,19 +722,19 @@ def run_test_with_call_paged_attention(
     @expectedFailure  # tl.dot does not support embedding size less than 16
     @unittest.skipIf(SKIP_UT_ON_CPU, "Skip on CPU as not supported")
     @common_utils.parametrize("dtype", test_dtypes_fast)
-    def test_bw_decoding_fails(self, dtype):
+    def test_bw_decoding_fails(self, device, dtype):
         make_kv = functools.partial(
             torch.randn,
             (2, 2, 128, 4),
             dtype=dtype,
-            device=GPU_TYPE,
+            device=device,
             requires_grad=True,
         )
         make_q = functools.partial(
             torch.randn,
             (2, 2, 8, 4),
             dtype=dtype,
-            device=GPU_TYPE,
+            device=device,
             requires_grad=True,
         )
         q, k, v, backward_grad = make_q(), make_kv(), make_kv(), make_q()
@@ -1004,20 +1004,20 @@ def mask_mod(b, h, q, kv):
 
     @supported_platform
     @unittest.skipIf(SKIP_UT_ON_CPU, "Skip on CPU as not supported")
-    def test_non_divisible_multi_token_offset_mask_with_captured_buffer(self):
+    def test_non_divisible_multi_token_offset_mask_with_captured_buffer(self, device):
         KV_S = S - 3
         Q_S = 3
-        offset_kv = torch.randn(KV_S, device=GPU_TYPE, dtype=torch.bfloat16)
-        offset_q = torch.randn(Q_S, device=GPU_TYPE, dtype=torch.bfloat16)
-        offset_tensor = torch.tensor(S // 2 - 3, device=GPU_TYPE, dtype=torch.int32)
+        offset_kv = torch.randn(KV_S, device=device, dtype=torch.bfloat16)
+        offset_q = torch.randn(Q_S, device=device, dtype=torch.bfloat16)
+        offset_tensor = torch.tensor(S // 2 - 3, device=device, dtype=torch.int32)
 
         def score_mod(score, b, h, q, kv):
             return score + offset_kv[kv] + offset_q[q]
 
         def mask_mod(b, h, q, kv):
             return kv >= q + offset_tensor
 
-        block_mask = create_block_mask(mask_mod, B, 1, Q_S, KV_S, device=self.device)
+        block_mask = create_block_mask(mask_mod, B, 1, Q_S, KV_S, device=device)
         self.run_test(Q_S=Q_S, KV_S=KV_S, block_mask=block_mask, score_mod=score_mod)
 
     @supported_platform
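The hunk above shows the pattern this commit generalizes: the captured buffers, the offset tensor, and the block mask must all live on the device the test runs on. A minimal self-contained sketch of that pattern follows, with hypothetical shapes and names and assuming a device string such as "cuda"; it is an illustration, not the test's own run_test harness.

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

device = "cuda"  # assumption: any device flex_attention supports
B, H, Q_S, KV_S, D = 2, 4, 3, 13, 16

# Captured buffer indexed inside score_mod, as in the test above.
offset_kv = torch.randn(KV_S, device=device)


def score_mod(score, b, h, q, kv):
    return score + offset_kv[kv]


def mask_mod(b, h, q, kv):
    return kv >= q  # simple causal-style predicate


# Build the mask on the same device as the inputs.
block_mask = create_block_mask(mask_mod, B, H, Q_S, KV_S, device=device)
q = torch.randn(B, H, Q_S, D, device=device)
k = torch.randn(B, H, KV_S, D, device=device)
v = torch.randn(B, H, KV_S, D, device=device)
out = flex_attention(q, k, v, score_mod=score_mod, block_mask=block_mask)
```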
@@ -1682,19 +1682,19 @@ def mask_mod(b, h, q, kv):
     @unittest.skipIf(SKIP_UT_ON_CPU, "Skip on CPU as not supported")
     @common_utils.parametrize("dtype", test_dtypes)
     @common_utils.parametrize("score_mod", [_identity, _causal])
-    def test_logsumexp_correctness(self, dtype, score_mod):
+    def test_logsumexp_correctness(self, device, dtype, score_mod):
         make_kv = functools.partial(
             torch.randn,
             (B, Hkv, S, D),
             dtype=dtype,
-            device=GPU_TYPE,
+            device=device,
             requires_grad=True,
         )
         make_q = functools.partial(
             torch.randn,
             (B, Hkv, Hq // Hkv, D),
             dtype=dtype,
-            device=GPU_TYPE,
+            device=device,
             requires_grad=True,
         )
         q, k, v = make_q(), make_kv(), make_kv()
@@ -1734,19 +1734,19 @@ def eager_sdpa_hop(q, k, v, score_mod):
 
     @supported_platform
     @unittest.skipIf(SKIP_UT_ON_CPU, "Skip on CPU as not supported")
-    def test_logsumexp_only_return(self):
+    def test_logsumexp_only_return(self, device):
         make_q = functools.partial(
             torch.randn,
             (B, Hkv, Hq // Hkv, D),
             dtype=torch.float32,
-            device=GPU_TYPE,
+            device=device,
             requires_grad=True,
         )
         make_kv = functools.partial(
             torch.randn,
             (B, Hkv, S, D),
             dtype=torch.float32,
-            device=GPU_TYPE,
+            device=device,
             requires_grad=True,
         )
 
@@ -1998,7 +1998,7 @@ def causal_mask(b, h, q, kv):
         self._check_equal(golden_outs, ref_outs, paged_out, fudge_factor, "Out")
 
 
-instantiate_device_type_tests(TestFlexDecoding, globals(), only_for=test_device)
+instantiate_device_type_tests(TestFlexDecoding, globals(), only_for=test_device, allow_xpu=True)
 
 if __name__ == "__main__":
     from torch._inductor.test_case import run_tests
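The last hunk is what makes the `device` parameter appear in the tests above: instantiate_device_type_tests generates one concrete class per device type (and, with allow_xpu=True, includes XPU) and passes each test method a device string. A minimal sketch of that mechanism, using a hypothetical MyGenericTest rather than the real TestFlexDecoding:

```python
import torch
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_utils import TestCase, run_tests


class MyGenericTest(TestCase):
    def test_randn_on_device(self, device):
        # `device` is injected per generated class, e.g. "cpu", "cuda:0", "xpu:0".
        t = torch.randn(4, device=device)
        self.assertEqual(t.device.type, torch.device(device).type)


# Creates MyGenericTestCPU, MyGenericTestCUDA, and (with allow_xpu=True)
# MyGenericTestXPU in this module's namespace; each generated test method
# receives the matching device string instead of a fixed GPU_TYPE constant.
instantiate_device_type_tests(MyGenericTest, globals(), allow_xpu=True)

if __name__ == "__main__":
    run_tests()
```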
