reformat · pytorch/pytorch@ee9682f · GitHub

Commit ee9682f

reformat
1 parent f79f883 commit ee9682f

File tree

2 files changed: +37 -58 lines changed

test/inductor/test_flex_attention.py (+32 -51)
@@ -2539,9 +2539,7 @@ def mask(b, h, q, kv):
         k.grad = None
         v.grad = None
 
-        block_mask2 = create_block_mask(
-            mask, None, None, 2048, 2048, device=device
-        )
+        block_mask2 = create_block_mask(mask, None, None, 2048, 2048, device=device)
         # Reuse the 1st version with q/k/v(seqlen=2048) and block_mask(seqlen=2048)
         out2 = torch.compile(flex_attention, dynamic=True)(
             q, k, v, block_mask=block_mask2
@@ -2732,7 +2730,9 @@ def test_flex_attention_stride_ordering(self, device, mode, permute_order, shape
         ],
     )
     @common_utils.parametrize("shape", [(2, 5, 128, 16), (4, 2, 64, 16)])
-    def test_flex_attention_backward_stride_ordering(self, device, mode, permute_order, shape):
+    def test_flex_attention_backward_stride_ordering(
+        self, device, mode, permute_order, shape
+    ):
         from torch._inductor.ir import get_stride_order
 
         dtype = torch.float32
@@ -2865,15 +2865,11 @@ def causal_mask(b, h, q_idx, kv_idx):
         score_mod_sparse_flex = functools.partial(
             flex_attention,
             score_mod=causal,
-            block_mask=create_block_mask(
-                causal_mask, 1, 1, 2048, 2048, device=device
-            ),
+            block_mask=create_block_mask(causal_mask, 1, 1, 2048, 2048, device=device),
         )
         mask_mod_sparse_flex = functools.partial(
             flex_attention,
-            block_mask=create_block_mask(
-                causal_mask, 1, 1, 2048, 2048, device=device
-            ),
+            block_mask=create_block_mask(causal_mask, 1, 1, 2048, 2048, device=device),
         )
         for attention_call in [
             no_sparse_flex,
@@ -2892,9 +2888,7 @@ def causal_mask(b, h, q_idx, kv_idx):
             )
             for _ in range(3)
         ]
-        gradOut = torch.randn(
-            2, 2, 2048, 64, device=device, dtype=torch.float16
-        )
+        gradOut = torch.randn(2, 2, 2048, 64, device=device, dtype=torch.float16)
         out_ref = torch.nn.functional.scaled_dot_product_attention(
             *inputs, is_causal=True
         )
@@ -2949,9 +2943,7 @@ def mod(b, h, q, kv):
     @unittest.skipIf(SKIP_UT_ON_CPU, "Skip on CPU as not supported")
     def test_head_bias_req_grad(self, device):
         B, H, S, D = 1, 4, 256, 64
-        bias = torch.randn(
-            H, device=device, dtype=torch.float16, requires_grad=True
-        )
+        bias = torch.randn(H, device=device, dtype=torch.float16, requires_grad=True)
 
         bias_flex = bias.detach().clone().requires_grad_(True)
 
@@ -3017,9 +3009,7 @@ def rel_pos_1d(score, b, h, q_idx, kv_idx):
 
         # 2-dimensional bias:
         B, H, S, D = 1, 1, 256, 64
-        bias = torch.randn(
-            S, S, device=device, dtype=torch.float16, requires_grad=True
-        )
+        bias = torch.randn(S, S, device=device, dtype=torch.float16, requires_grad=True)
 
         bias_flex = bias.detach().clone().requires_grad_(True)
 
@@ -3048,9 +3038,7 @@ def rel_pos_2d(score, b, h, q_idx, kv_idx):
 
         # 2-dimensional bias + index multiple
         B, H, S, D = 1, 1, 256, 64
-        bias = torch.randn(
-            S, S, device=device, dtype=torch.float16, requires_grad=True
-        )
+        bias = torch.randn(S, S, device=device, dtype=torch.float16, requires_grad=True)
 
         bias_flex = bias.detach().clone().requires_grad_(True)
 
@@ -3079,9 +3067,7 @@ def rel_pos_2d(score, b, h, q_idx, kv_idx):
 
         # 2-dimensional bias + transposed:
         B, H, S, D = 1, 1, 256, 64
-        bias = torch.randn(
-            S, S, device=device, dtype=torch.float16, requires_grad=True
-        )
+        bias = torch.randn(S, S, device=device, dtype=torch.float16, requires_grad=True)
 
         bias_flex = bias.detach().clone().requires_grad_(True)
 
@@ -3868,7 +3854,7 @@ def forward(self, arg0_1: "i32[]", arg1_1: "i32[]", arg2_1: "i32[]", arg3_1: "i3
             joint_graph,
             expected_joint_graph,
         )
-
+
     @supported_platform
     @unittest.skipIf(TEST_ON_CUDA, "Testing CPU error message")
     def test_cpu_error_message_return_lse(self, device):
@@ -3992,9 +3978,7 @@ def test_block_mask_attributes(self, device):
         def causal_mask(b, h, q, kv):
             return (q + (offset[b] * 128)) >= kv
 
-        block_mask = create_block_mask(
-            causal_mask, 4, 2, 2048, 2048, device=device
-        )
+        block_mask = create_block_mask(causal_mask, 4, 2, 2048, 2048, device=device)
         self.assertEqual(block_mask.shape, (4, 2, 2048, 2048))
         self.assertEqual(block_mask[0].shape, (2, 2048, 2048))
         self.assertEqual(block_mask[0, 0].shape, (2048, 2048))
@@ -4005,9 +3989,7 @@ def causal_mask(b, h, q, kv):
         self.assertEqual(block_mask.sparsity(), block_mask[1].sparsity())
 
         offset = torch.arange(8, device=device)
-        block_mask = create_block_mask(
-            causal_mask, 8, 1, 2048, 2048, device=device
-        )
+        block_mask = create_block_mask(causal_mask, 8, 1, 2048, 2048, device=device)
         self.assertEqual(block_mask.sparsity(), 29.1015625)
         self.assertTrue(block_mask.sparsity() < block_mask[0].sparsity())
         self.assertTrue(block_mask[0].sparsity() > block_mask[1].sparsity())
@@ -4163,9 +4145,7 @@ def test_block_mask_viz(self, device):
         def causal_mask(b, h, q, kv):
             return q >= kv
 
-        block_mask = create_block_mask(
-            causal_mask, 1, 1, 2048, 2048, device=device
-        )
+        block_mask = create_block_mask(causal_mask, 1, 1, 2048, 2048, device=device)
 
         def replace_non_printable(s):
             def replace(c):
@@ -4368,9 +4348,7 @@ def test_no_q_info(self, device, compile: bool):
         def causal_mask(b, h, q_idx, kv_idx):
             return q_idx >= kv_idx
 
-        block_mask = create_block_mask(
-            causal_mask, 1, 1, 2048, 2048, device=device
-        )
+        block_mask = create_block_mask(causal_mask, 1, 1, 2048, 2048, device=device)
         # manually set q_num_blocks and q_indices to None
         block_mask.q_num_blocks = None
         block_mask.q_indices = None
@@ -4494,9 +4472,7 @@ def test_eager_tracing_correctness(self, device):
         seq_len = 256
         batch_size = 1
 
-        make_tensor = functools.partial(
-            torch.randn, device=device, dtype=torch.float16
-        )
+        make_tensor = functools.partial(torch.randn, device=device, dtype=torch.float16)
         q = make_tensor(*(batch_size, q_heads, seq_len, qk_dims))
         k = make_tensor(*(batch_size, kv_heads, seq_len, qk_dims))
         v = make_tensor(*(batch_size, kv_heads, seq_len, v_dims))
@@ -4556,16 +4532,12 @@ def create_inputs(S):
             )
             return q, k, v
 
-        block_mask = create_block_mask(
-            mask_mod, None, None, 1024, 1024, device=device
-        )
+        block_mask = create_block_mask(mask_mod, None, None, 1024, 1024, device=device)
         flex_attention_call(*create_inputs(1024), block_mask=block_mask)
         with self.assertRaisesRegex(ValueError, "block_mask was created for"):
             flex_attention_call(*create_inputs(2048), block_mask=block_mask)
 
-        block_mask = create_block_mask(
-            mask_mod, None, None, 1023, 1023, device=device
-        )
+        block_mask = create_block_mask(mask_mod, None, None, 1023, 1023, device=device)
         with self.assertRaisesRegex(ValueError, "block_mask was created for"):
             flex_attention_call(*create_inputs(1024), block_mask=block_mask)
 
@@ -5634,10 +5606,19 @@ def sliding_window(b, h, q_idx, kv_idx, val):
         opt_fn(sliding_window2, None, None, 1024, 1024)
 
 
-instantiate_device_type_tests(TestFlexAttention, globals(), only_for=test_device, allow_xpu=True)
-instantiate_device_type_tests(TestPagedAttention, globals(), only_for=test_device, allow_xpu=True)
-instantiate_device_type_tests(TestBlockMask, globals(), only_for=test_device, allow_xpu=True)
-instantiate_device_type_tests(TestLearnableBiases, globals(), only_for=test_device, allow_xpu=True)
+instantiate_device_type_tests(
+    TestFlexAttention, globals(), only_for=test_device, allow_xpu=True
+)
+instantiate_device_type_tests(
+    TestPagedAttention, globals(), only_for=test_device, allow_xpu=True
+)
+instantiate_device_type_tests(
+    TestBlockMask, globals(), only_for=test_device, allow_xpu=True
+)
+instantiate_device_type_tests(
+    TestLearnableBiases, globals(), only_for=test_device, allow_xpu=True
+)
+
 
 if __name__ == "__main__":
     from torch._inductor.test_case import run_tests
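
For reference, the calls being re-wrapped above all come from the FlexAttention API. Below is a minimal sketch of the create_block_mask / flex_attention pattern these tests exercise; the causal mask_mod, the shapes, and the CUDA device are illustrative assumptions, not taken verbatim from the diff.

import functools

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention


def causal_mask(b, h, q_idx, kv_idx):
    # mask_mod: return True where a query position may attend to a kv position
    return q_idx >= kv_idx


device = "cuda"  # assumption: the tests above parametrize over devices
B, H, S, D = 2, 2, 2048, 64
make_tensor = functools.partial(torch.randn, device=device, dtype=torch.float16)
q, k, v = make_tensor(B, H, S, D), make_tensor(B, H, S, D), make_tensor(B, H, S, D)

# A BlockMask is built for a fixed (Q_LEN, KV_LEN); passing inputs of a different
# sequence length raises the "block_mask was created for ..." ValueError tested above.
block_mask = create_block_mask(causal_mask, None, None, S, S, device=device)
out = torch.compile(flex_attention, dynamic=True)(q, k, v, block_mask=block_mask)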

test/inductor/test_flex_decoding.py (+5 -7)
@@ -27,8 +27,7 @@
     flex_attention_supported_platform as supported_platform,
     instantiate_device_type_tests,
 )
-from torch.testing._internal.inductor_utils import HAS_CUDA, HAS_GPU
-from torch.utils._triton import has_triton
+from torch.testing._internal.inductor_utils import HAS_GPU
 
 
 Tolerances = namedtuple("Tolerances", ["atol", "rtol"])
@@ -42,10 +41,7 @@
     and torch.utils._triton.has_triton()
     and torch.cuda.get_device_capability() >= (8, 0)
 )
-TEST_ON_XPU = (
-    torch.xpu.is_available()
-    and torch.utils._triton.has_triton()
-)
+TEST_ON_XPU = torch.xpu.is_available() and torch.utils._triton.has_triton()
 
 if HAS_GPU:
     if TEST_ON_CUDA:
@@ -1998,7 +1994,9 @@ def causal_mask(b, h, q, kv):
         self._check_equal(golden_outs, ref_outs, paged_out, fudge_factor, "Out")
 
 
-instantiate_device_type_tests(TestFlexDecoding, globals(), only_for=test_device, allow_xpu=True)
+instantiate_device_type_tests(
+    TestFlexDecoding, globals(), only_for=test_device, allow_xpu=True
+)
 
 if __name__ == "__main__":
     from torch._inductor.test_case import run_tests
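
Both files end with the same device-type instantiation tail that this commit re-wraps. A hedged, self-contained sketch of that pattern follows; the trivial test class and the only_for list are assumptions for illustration, while instantiate_device_type_tests and run_tests are the helpers actually used above.

import torch
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_utils import TestCase


class TestFlexDecodingSketch(TestCase):
    def test_smoke(self, device):
        # device-type instantiation passes a concrete device string to each test
        self.assertEqual(torch.ones(1, device=device).item(), 1.0)


# Generates per-device classes (e.g. TestFlexDecodingSketchCPU) into globals(),
# restricted to the devices named in only_for.
instantiate_device_type_tests(TestFlexDecodingSketch, globals(), only_for=("cpu",))

if __name__ == "__main__":
    from torch._inductor.test_case import run_tests

    run_tests()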
