pytorch/pytorch · Commit f7cff4e

enable cuda-specific ut
1 parent ee9682f commit f7cff4e

1 file changed

test/inductor/test_flex_attention.py

Lines changed: 27 additions & 21 deletions
@@ -3342,10 +3342,10 @@ def global_causal(b, h, q_idx, kv_idx):
     def test_mixed_device_error_message(self, device):
         # Create tensors on different devices
         cpu_tensor = torch.randn(2, 2, 128, 16, device="cpu")
-        cuda_tensor = torch.randn(2, 2, 128, 16, device=device)
+        gpu_tensor = torch.randn(2, 2, 128, 16, device=device)
 
         # Use different devices for query, key, and value
-        query, key, value = cpu_tensor, cuda_tensor, cpu_tensor
+        query, key, value = cpu_tensor, gpu_tensor, cpu_tensor
 
         expected_error_message = (
             "Expected query, key, and value to have the same device type, "
@@ -3924,9 +3924,9 @@ def test_validate_small_embedding_size_error_message(self, device):
         not has_triton() or not HAS_WARP_SPEC,
         reason="FBCODE Triton is required for this test",
     )
-    def test_triton_template_warp_specialization(self):
+    def test_triton_template_warp_specialization(self, device):
         def make_tensor():
-            return torch.rand(4, 16, 4096, 64, device="cuda", dtype=torch.bfloat16)
+            return torch.rand(4, 16, 4096, 64, device=device, dtype=torch.bfloat16)
 
         q, k, v = make_tensor(), make_tensor(), make_tensor()
         flex_compiled = torch.compile(flex_attention, fullgraph=True)
@@ -4071,16 +4071,17 @@ def causal_mask(b, h, q, kv):
 
     @supported_platform
     def test_block_mask_device_change(self, device):
+        device = torch.device(device)
         offset = torch.zeros(8, device=device)
 
         def causal_mask(b, h, q, kv):
             return (q + (offset[b] * 128)) >= kv
 
         block_mask = create_block_mask(causal_mask, 1, 1, 512, 512, device=device)
-        assert block_mask.kv_indices.is_cuda
-        assert block_mask.kv_num_blocks.is_cuda
-        assert block_mask.q_indices.is_cuda
-        assert block_mask.q_num_blocks.is_cuda
+        assert block_mask.kv_indices.device.type == device.type
+        assert block_mask.kv_num_blocks.device.type == device.type
+        assert block_mask.q_indices.device.type == device.type
+        assert block_mask.q_num_blocks.device.type == device.type
 
         block_mask = block_mask.to("cpu")
         assert block_mask.kv_indices.is_cpu
@@ -4089,10 +4090,10 @@ def causal_mask(b, h, q, kv):
         assert block_mask.q_num_blocks.is_cpu
 
         block_mask = block_mask.to(device)
-        assert block_mask.kv_indices.is_cuda
-        assert block_mask.kv_num_blocks.is_cuda
-        assert block_mask.q_indices.is_cuda
-        assert block_mask.q_num_blocks.is_cuda
+        assert block_mask.kv_indices.device.type == device.type
+        assert block_mask.kv_num_blocks.device.type == device.type
+        assert block_mask.q_indices.device.type == device.type
+        assert block_mask.q_num_blocks.device.type == device.type
 
     @supported_platform
     def test_compiling_create_block_mask(self, device):
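
The two hunks above swap CUDA-only .is_cuda assertions for comparisons against the test's device argument, so the same check passes on CUDA, XPU, or CPU. A minimal, self-contained sketch of that pattern; the helper name below is hypothetical and not part of the test file:

import torch

def assert_same_device_type(t, device):
    # Compare device *types* ("cpu", "cuda", "xpu", ...); unlike .is_cuda,
    # this works for any backend the test happens to be instantiated for.
    assert t.device.type == torch.device(device).type

assert_same_device_type(torch.zeros(4), "cpu")  # passes on any machine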
@@ -4984,9 +4985,12 @@ def get_params(dtypes: list[torch.dtype]) -> list[Params]:
 
 
 supports_learnable_bias = unittest.skipUnless(
-    (torch.cuda.is_available() and has_triton())
-    and (torch.cuda.get_device_capability() == (8, 0) or torch.version.hip),
-    "Requires Triton + A100 or Triton + ROCm",
+    (torch.xpu.is_available() and has_triton())
+    or (
+        (torch.cuda.is_available() and has_triton())
+        and (torch.cuda.get_device_capability() == (8, 0) or torch.version.hip)
+    ),
+    "Requires Triton + A100 or Triton + ROCm or Triton + XPU",
 )
 
 
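
The hunk above widens the supports_learnable_bias gate so it also accepts XPU, alongside Triton + A100 (SM 8.0) and Triton + ROCm. A standalone sketch of equivalent gating with unittest.skipUnless, omitting the has_triton() check that comes from the test harness; the function and test names are illustrative only:

import unittest

import torch

def _learnable_bias_supported():
    # XPU, or CUDA on an A100-class device (compute capability 8.0), or a ROCm build.
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return True
    if torch.cuda.is_available():
        return torch.cuda.get_device_capability() == (8, 0) or bool(torch.version.hip)
    return False

class ExampleLearnableBiasTest(unittest.TestCase):
    @unittest.skipUnless(_learnable_bias_supported(), "Requires A100, ROCm, or XPU")
    def test_placeholder(self):
        self.assertTrue(True)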
@@ -5554,10 +5558,10 @@ def bias_func(score, b, h, q_idx, kv_idx):
             out_eager, out_compiled, out_gold, (bias,), names=["out", "bias"]
         )
 
-    def test_flex_attention_with_dynamic_max_autotune(self):
-        query = torch.randn(2, 16, 512, 64, device="cuda")
-        key = torch.randn(2, 16, 512, 64, device="cuda")
-        value = torch.randn(2, 16, 512, 64, device="cuda")
+    def test_flex_attention_with_dynamic_max_autotune(self, device):
+        query = torch.randn(2, 16, 512, 64, device=device)
+        key = torch.randn(2, 16, 512, 64, device=device)
+        value = torch.randn(2, 16, 512, 64, device=device)
         query.requires_grad = True
         key.requires_grad = True
         value.requires_grad = True
@@ -5571,7 +5575,9 @@ def causal(b, h, m, n):
             return m >= n
 
         mask_shape = (1, 1, M, N)
-        block_mask = torch.compile(create_block_mask)(causal, *mask_shape, "cuda")
+        block_mask = torch.compile(create_block_mask)(
+            causal, *mask_shape, device=device
+        )
 
         compiled_sdpa = torch.compile(
             flex_attention, dynamic=True, mode="max-autotune-no-cudagraphs"
@@ -5598,7 +5604,7 @@ def sliding_window(b, h, q_idx, kv_idx, val):
             return (q_idx - kv_idx).abs() < val
 
         sliding_window2 = functools.partial(
-            sliding_window, val=torch.randn((), device="cuda")
+            sliding_window, val=torch.randn((), device=device)
         )
         opt_fn = torch.compile(create_block_mask, fullgraph=True)
         create_block_mask(sliding_window2, None, None, 1024, 1024)
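
Every test touched by this commit now takes a device argument instead of hard-coding "cuda". In PyTorch's test suite that argument is usually injected by instantiating a generic test class once per available backend; a hedged sketch of that mechanism using the in-tree helpers (whether test_flex_attention.py wires this up exactly as shown is not visible in this diff):

import torch
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_utils import TestCase, run_tests

class ExampleDeviceTest(TestCase):
    def test_tensor_lands_on_device(self, device):
        # "device" arrives as a string such as "cpu", "cuda:0", or "xpu:0",
        # supplied by the per-backend instantiation below.
        t = torch.zeros(2, device=device)
        self.assertEqual(t.device.type, torch.device(device).type)

# Generates ExampleDeviceTestCPU, ExampleDeviceTestCUDA, ... in this module.
instantiate_device_type_tests(ExampleDeviceTest, globals())

if __name__ == "__main__":
    run_tests()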

0 commit comments