8000 generalize cuda bias ut code · pytorch/pytorch@d2ee416 · GitHub
[go: up one dir, main page]

Skip to content

Commit d2ee416

Browse files
committed
generalize cuda bias ut code
1 parent c810417 commit d2ee416

File tree

2 files changed

+5
-5
lines changed

2 files changed

+5
-5
lines changed

test/inductor/test_flex_attention.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -417,7 +417,7 @@ def batch_reserve(paged_attention: PagedAttention, target_seq_len: Tensor):
417417
)
418418

419419

420-
@large_tensor_test_class("2GB", device="cuda")
420+
@large_tensor_test_class("2GB", device=test_device[0])
421421
class TestFlexAttention(InductorTestCase):
422422
def setUp(self):
423423
super().setUp()
@@ -4531,7 +4531,7 @@ def create_inputs(S):
45314531
flex_attention_call(*create_inputs(1024), block_mask=block_mask)
45324532

45334533

4534-
@large_tensor_test_class("2GB", device="cuda")
4534+
@large_tensor_test_class("2GB", device=test_device[0])
45354535
class TestPagedAttention(InductorTestCase):
45364536
def setUp(self):
45374537
super().setUp()
@@ -4982,7 +4982,7 @@ def get_params(dtypes: list[torch.dtype]) -> list[Params]:
49824982

49834983

49844984
@supports_learnable_bias
4985-
@large_tensor_test_class("2GB", device="cuda")
4985+
@large_tensor_test_class("2GB", device=test_device[0])
49864986
class TestLearnableBiases(InductorTestCase):
49874987
def setUp(self):
49884988
super().setUp()

torch/testing/_internal/common_device_type.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1342,8 +1342,8 @@ def dep_fn(self, *args, **kwargs):
13421342
# an additional array of the same size as the input.
13431343
if inductor and torch._inductor.config.cpp_wrapper and _device != "cpu":
13441344
size_bytes *= 2
1345-
1346-
if not _has_sufficient_memory(_device, size_bytes):
1345+
# TODO: Memory availability checks for Intel GPU
1346+
if device != "xpu" and not _has_sufficient_memory(_device, size_bytes):
13471347
raise unittest.SkipTest(f"Insufficient {_device} memory")
13481348

13491349
return fn(self, *args, **kwargs)

0 commit comments

Comments
 (0)
0