@@ -3342,10 +3342,10 @@ def global_causal(b, h, q_idx, kv_idx):
    def test_mixed_device_error_message(self, device):
        # Create tensors on different devices
        cpu_tensor = torch.randn(2, 2, 128, 16, device="cpu")
-        cuda_tensor = torch.randn(2, 2, 128, 16, device=device)
+        gpu_tensor = torch.randn(2, 2, 128, 16, device=device)

        # Use different devices for query, key, and value
-        query, key, value = cpu_tensor, cuda_tensor, cpu_tensor
+        query, key, value = cpu_tensor, gpu_tensor, cpu_tensor

        expected_error_message = (
            "Expected query, key, and value to have the same device type, "
@@ -3924,9 +3924,9 @@ def test_validate_small_embedding_size_error_message(self, device):
        not has_triton() or not HAS_WARP_SPEC,
        reason="FBCODE Triton is required for this test",
    )
-    def test_triton_template_warp_specialization(self):
+    def test_triton_template_warp_specialization(self, device):
        def make_tensor():
-            return torch.rand(4, 16, 4096, 64, device="cuda", dtype=torch.bfloat16)
+            return torch.rand(4, 16, 4096, 64, device=device, dtype=torch.bfloat16)

        q, k, v = make_tensor(), make_tensor(), make_tensor()
        flex_compiled = torch.compile(flex_attention, fullgraph=True)
@@ -4071,16 +4071,17 @@ def causal_mask(b, h, q, kv):

    @supported_platform
    def test_block_mask_device_change(self, device):
+        device = torch.device(device)
        offset = torch.zeros(8, device=device)

        def causal_mask(b, h, q, kv):
            return (q + (offset[b] * 128)) >= kv

        block_mask = create_block_mask(causal_mask, 1, 1, 512, 512, device=device)
-        assert block_mask.kv_indices.is_cuda
-        assert block_mask.kv_num_blocks.is_cuda
-        assert block_mask.q_indices.is_cuda
-        assert block_mask.q_num_blocks.is_cuda
+        assert block_mask.kv_indices.device.type == device.type
+        assert block_mask.kv_num_blocks.device.type == device.type
+        assert block_mask.q_indices.device.type == device.type
+        assert block_mask.q_num_blocks.device.type == device.type

        block_mask = block_mask.to("cpu")
        assert block_mask.kv_indices.is_cpu
@@ -4089,10 +4090,10 @@ def causal_mask(b, h, q, kv):
        assert block_mask.q_num_blocks.is_cpu

        block_mask = block_mask.to(device)
-        assert block_mask.kv_indices.is_cuda
-        assert block_mask.kv_num_blocks.is_cuda
-        assert block_mask.q_indices.is_cuda
-        assert block_mask.q_num_blocks.is_cuda
+        assert block_mask.kv_indices.device.type == device.type
+        assert block_mask.kv_num_blocks.device.type == device.type
+        assert block_mask.q_indices.device.type == device.type
+        assert block_mask.q_num_blocks.device.type == device.type

    @supported_platform
    def test_compiling_create_block_mask(self, device):
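The `device` argument these tests now accept is supplied by PyTorch's device-type test framework, which stamps out one copy of each generic test per available backend and passes the device string in; that is what lets the assertions above compare `.device.type` instead of hard-coding `.is_cuda`. A minimal standalone sketch of that pattern, assuming the internal `torch.testing._internal.common_device_type` helpers (everything below is illustrative, not part of this diff):

import torch
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_utils import TestCase, run_tests


class TestDeviceAgnostic(TestCase):
    # `device` is injected per backend ("cpu", "cuda:0", "xpu:0", ...).
    def test_round_trip(self, device):
        t = torch.zeros(8, device=device)
        # Device-agnostic check: compare device *types*, not .is_cuda.
        self.assertEqual(t.device.type, torch.device(device).type)


# Clones TestDeviceAgnostic into TestDeviceAgnosticCPU, TestDeviceAgnosticCUDA, etc.
instantiate_device_type_tests(TestDeviceAgnostic, globals())

if __name__ == "__main__":
    run_tests()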
@@ -4984,9 +4985,12 @@ def get_params(dtypes: list[torch.dtype]) -> list[Params]:


supports_learnable_bias = unittest.skipUnless(
-    (torch.cuda.is_available() and has_triton())
-    and (torch.cuda.get_device_capability() == (8, 0) or torch.version.hip),
-    "Requires Triton + A100 or Triton + ROCm",
+    (torch.xpu.is_available() and has_triton())
+    or (
+        (torch.cuda.is_available() and has_triton())
+        and (torch.cuda.get_device_capability() == (8, 0) or torch.version.hip)
+    ),
+    "Requires Triton + A100 or Triton + ROCm or Triton + XPU",
)

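The decorator above is the standard capability-gating pattern: the predicate is evaluated once at import time, so each backend query has to be guarded by its own availability check first. A reduced standalone sketch of the same pattern (illustrative only; the helper names here are made up for the example, and `hasattr` guards older builds without `torch.xpu`):

import unittest

import torch


def _has_sm80_cuda() -> bool:
    # get_device_capability() must only be called when a CUDA device exists.
    return torch.cuda.is_available() and torch.cuda.get_device_capability() == (8, 0)


requires_supported_gpu = unittest.skipUnless(
    _has_sm80_cuda()
    or bool(torch.version.hip)
    or (hasattr(torch, "xpu") and torch.xpu.is_available()),
    "Requires an A100, a ROCm build, or an XPU device",
)


@requires_supported_gpu
class TestGated(unittest.TestCase):
    def test_runs_only_on_supported_hardware(self):
        self.assertTrue(True)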
@@ -5554,10 +5558,10 @@ def bias_func(score, b, h, q_idx, kv_idx):
            out_eager, out_compiled, out_gold, (bias,), names=["out", "bias"]
        )

-    def test_flex_attention_with_dynamic_max_autotune(self):
-        query = torch.randn(2, 16, 512, 64, device="cuda")
-        key = torch.randn(2, 16, 512, 64, device="cuda")
-        value = torch.randn(2, 16, 512, 64, device="cuda")
+    def test_flex_attention_with_dynamic_max_autotune(self, device):
+        query = torch.randn(2, 16, 512, 64, device=device)
+        key = torch.randn(2, 16, 512, 64, device=device)
+        value = torch.randn(2, 16, 512, 64, device=device)
        query.requires_grad = True
        key.requires_grad = True
        value.requires_grad = True
@@ -5571,7 +5575,9 @@ def causal(b, h, m, n):
            return m >= n

        mask_shape = (1, 1, M, N)
-        block_mask = torch.compile(create_block_mask)(causal, *mask_shape, "cuda")
+        block_mask = torch.compile(create_block_mask)(
+            causal, *mask_shape, device=device
+        )

        compiled_sdpa = torch.compile(
            flex_attention, dynamic=True, mode="max-autotune-no-cudagraphs"
@@ -5598,7 +5604,7 @@ def sliding_window(b, h, q_idx, kv_idx, val):
            return (q_idx - kv_idx).abs() < val

        sliding_window2 = functools.partial(
-            sliding_window, val=torch.randn((), device="cuda")
+            sliding_window, val=torch.randn((), device=device)
        )
        opt_fn = torch.compile(create_block_mask, fullgraph=True)
        create_block_mask(sliding_window2, None, None, 1024, 1024)
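The `functools.partial` trick in this last hunk — binding an extra captured-tensor argument onto a `mask_mod` before handing it to `create_block_mask` — also works standalone. A minimal sketch, assuming the public `torch.nn.attention.flex_attention` API and using `device="cpu"` so it runs anywhere (the window size here is an arbitrary example value):

import functools

import torch
from torch.nn.attention.flex_attention import create_block_mask


def sliding_window(b, h, q_idx, kv_idx, val):
    # Keep only key/value positions within `val` tokens of the query index.
    return (q_idx - kv_idx).abs() < val


# Bind the window size as a captured 0-dim tensor, as the test above does.
window_mask = functools.partial(sliding_window, val=torch.tensor(128))

# B=None / H=None broadcast the mask over batch and heads.
block_mask = create_block_mask(window_mask, None, None, 1024, 1024, device="cpu")
print(block_mask)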