[FlexAttention] Remove old constraint that was causing assert failure · pytorch/pytorch@50cb7ca · GitHub
Commit 50cb7ca

[FlexAttention] Remove old constraint that was causing assert failure
ghstack-source-id: de1f894
Pull Request resolved: #151521
1 parent cd7bc60 commit 50cb7ca
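
For context, here is a minimal sketch of the call pattern this change unblocks: compiling flex_attention on inputs whose last dimension is not stride-1. The shapes and the column_major helper below are illustrative only (the new test in this commit uses the same transpose/contiguous trick); only torch.compile and the public torch.nn.attention.flex_attention entry point are assumed.

    import torch
    from torch.nn.attention.flex_attention import flex_attention

    B, H, S, D = 2, 4, 128, 64  # illustrative sizes, not taken from the PR

    def column_major(shape, device="cuda", dtype=torch.float16):
        # Make the last two dims column-major so that stride(-1) != 1.
        t = torch.randn(shape, device=device, dtype=dtype)
        return t.transpose(-1, -2).contiguous().transpose(-1, -2)

    q, k, v = (column_major((B, H, S, D)) for _ in range(3))
    assert q.stride(-1) != 1  # non contiguous last dim

    compiled = torch.compile(flex_attention, fullgraph=True)
    out = compiled(q, k, v)  # previously hit "Query must be contiguous in the last dimension"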

2 files changed: +65 -11 lines changed

test/inductor/test_flex_attention.py

Lines changed: 63 additions & 6 deletions
@@ -2658,6 +2658,7 @@ def test_strided_backwards(self):
             (1, 0, 2, 3),  # Reverse order
             (0, 2, 1, 3),  # Mixed order
             (2, 0, 1, 3),  # Another mixed order
+            (0, 1, 3, 2),  # Non contiguous last dim
         ],
     )
     @common_utils.parametrize("shape", [(2, 1, 128, 16), (4, 2, 64, 16)])
@@ -2707,12 +2708,7 @@ def test_flex_attention_stride_ordering(self, device, mode, permute_order, shape
     @common_utils.parametrize("mode", ["eager", "inductor"])
     @common_utils.parametrize(
         "permute_order",
-        [
-            (0, 1, 2, 3),
-            (1, 0, 2, 3),
-            (0, 2, 1, 3),
-            (2, 0, 1, 3),
-        ],
+        [(0, 1, 2, 3), (1, 0, 2, 3), (0, 2, 1, 3), (2, 0, 1, 3), (0, 1, 3, 2)],
     )
     @common_utils.parametrize("shape", [(2, 5, 128, 16), (4, 2, 64, 16)])
     def test_flex_attention_backward_stride_ordering(self, mode, permute_order, shape):
@@ -2754,6 +2750,67 @@ def test_flex_attention_backward_stride_ordering(self, mode, permute_order, shap
                 f"Mode: {mode}, Stride order mismatch for {name}: grad {input_stride_order}, input {orig_stride_order}.",
             )
 
+    @supported_platform
+    def test_non_contiguous_last_dim(self, device):
+        """Test flex_attention with tensors having non contiguous last dimension."""
+        B, H, S, D = 4, 8, 128, 64
+        dtype = torch.float16 if device == "cuda" else torch.float32
+
+        def column_major_tensor():
+            tensor = torch.randn(
+                (B, H, S, D),
+                dtype=dtype,
+                device=device,
+            )
+            # Column major in last 2 dims
+            return tensor.transpose(-1, -2).contiguous().transpose(-1, -2)
+
+        q = column_major_tensor()
+        k = column_major_tensor()
+        v = column_major_tensor()
+
+        if not self.test_inference_only:
+            q.requires_grad_(True)
+            k.requires_grad_(True)
+            v.requires_grad_(True)
+
+        self.assertNotEqual(q.stride()[-1], 1)
+        self.assertNotEqual(k.stride()[-1], 1)
+        self.assertNotEqual(v.stride()[-1], 1)
+
+        q_ref, k_ref, v_ref = query_key_value_clones(q, k, v)
+        q_gold, k_gold, v_gold = query_key_value_clones(q, k, v, torch.float64)
+
+        golden_out = flex_attention(q_gold, k_gold, v_gold)
+        ref_out = flex_attention(q_ref, k_ref, v_ref)
+
+        flex_compiled = torch.compile(flex_attention, fullgraph=True)
+        compiled_out = flex_compiled(q, k, v)
+
+        self._check_out(golden_out, ref_out, compiled_out)
+
+        if not self.test_inference_only:
+            backward_grad = torch.randn_like(ref_out)
+
+            golden_out.backward(backward_grad.to(torch.float64))
+            ref_out.backward(backward_grad)
+            compiled_out.backward(backward_grad)
+
+            self._check_out_and_grad(
+                golden_out,
+                ref_out,
+                compiled_out,
+                q_gold,
+                q_ref,
+                q,
+                k_gold,
+                k_ref,
+                k,
+                v_gold,
+                v_ref,
+                v,
+            )
+
     @supported_platform
     @common_utils.parametrize("compile", [True, False])
     def test_fully_masked_out_rows_0_check(self, device, compile: bool):
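
To exercise just the new test locally, something like the following should work with PyTorch's standard unittest-style runner (the concrete test name gains a device suffix such as _cuda once the device tests are instantiated, which the -k substring filter still matches):

    python test/inductor/test_flex_attention.py -k test_non_contiguous_last_dim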

torch/_inductor/kernel/flex_attention.py

Lines changed: 2 additions & 5 deletions
@@ -1258,7 +1258,6 @@ def set_head_dim_values(
     )
 
 
-# TODO: We probably also need a layout constraint?
 @register_lowering(torch.ops.higher_order.flex_attention, type_promotion_kind=None)
 def flex_attention(
     query,
@@ -1413,11 +1412,9 @@ def flex_attention(
     else:
         kernel_options.setdefault("IS_DIVISIBLE", True)
 
-    # Reuse query strides for output layout despite different last dimension.
-    # This works because only the last dim differs and we check it is contiguous.
+    # NB it is okay that the v_head_dim is different
+    # We are using these to match fill order of the output.
     q_strides = query.get_stride()
-    assert q_strides[-1] == 1, "Query must be contiguous in the last dimension"
-
     # Construct output layout with strides matching the query.
     out_size = [B, Hq, seq_len_q, v_head_dim]
     out_strides = infer_dense_strides(out_size, q_strides)
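
The replacement comment notes that the query strides are only used to match the fill order of the output, so it no longer matters that v_head_dim can differ from the query's last dimension. As a rough illustration of that idea (a sketch only, not the inductor's actual infer_dense_strides implementation), dense strides for the output size can be derived by filling dimensions in the same order as the query:

    import torch

    def dense_strides_like(size, reference_strides):
        # Sketch: order dims from smallest to largest reference stride (the
        # "fill order"), then lay out a dense tensor of `size` in that order.
        fill_order = sorted(range(len(size)), key=lambda d: reference_strides[d])
        strides = [0] * len(size)
        running = 1
        for d in fill_order:
            strides[d] = running
            running *= size[d]
        return strides

    # A query stored column-major in its last two dims, as in the new test above.
    q = torch.randn(4, 8, 128, 64).transpose(-1, -2).contiguous().transpose(-1, -2)
    print(q.stride())                                      # (65536, 8192, 1, 128)
    print(dense_strides_like([4, 8, 128, 32], q.stride())) # [32768, 4096, 1, 128]: smaller head dim, same fill order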
