8000 [FlexAttention] Remove old constraint that was causing assert failure · pytorch/pytorch@b65ff0e · GitHub

Commit b65ff0e

[FlexAttention] Remove old constraint that was causing assert failure
ghstack-source-id: 5dc7a4c
Pull Request resolved: #151521
1 parent cd7bc60 commit b65ff0e
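
For context, the constraint being removed is a lowering-time assert that the query's last dimension must have stride 1. Below is a minimal repro sketch (my own illustration, not part of the commit) of the kind of input that used to trip it; it assumes a CUDA device, fp16 inputs, and the public flex_attention API, and the helper name is made up:

import torch
from torch.nn.attention.flex_attention import flex_attention

B, H, S, D = 2, 4, 128, 64

def non_contig_last_dim(b, h, s, d, device="cuda", dtype=torch.float16):
    # Hypothetical helper: allocate with the last two dims swapped, then
    # transpose back, so the shape is (b, h, s, d) but the last dim's
    # stride is s rather than 1.
    return torch.randn(b, h, d, s, device=device, dtype=dtype).transpose(-1, -2)

q = non_contig_last_dim(B, H, S, D)
k = non_contig_last_dim(B, H, S, D)
v = non_contig_last_dim(B, H, S, D)
assert q.stride()[-1] != 1  # last dimension is not contiguous

# Before this commit, the Inductor lowering asserted
# "Query must be contiguous in the last dimension"; now this compiles.
out = torch.compile(flex_attention, fullgraph=True)(q, k, v)

The new unit test below exercises exactly this layout through both the eager and compiled paths.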

File tree: 2 files changed, +73 −5 lines


test/inductor/test_flex_attention.py

Lines changed: 71 additions & 0 deletions
@@ -2658,6 +2658,7 @@ def test_strided_backwards(self):
             (1, 0, 2, 3),  # Reverse order
             (0, 2, 1, 3),  # Mixed order
             (2, 0, 1, 3),  # Another mixed order
+            (0, 1, 3, 2)  # Non contiguous last dim
         ],
     )
     @common_utils.parametrize("shape", [(2, 1, 128, 16), (4, 2, 64, 16)])
@@ -2712,6 +2713,7 @@ def test_flex_attention_stride_ordering(self, device, mode, permute_order, shape
             (1, 0, 2, 3),
             (0, 2, 1, 3),
             (2, 0, 1, 3),
+            (0, 1, 3, 2)
         ],
     )
     @common_utils.parametrize("shape", [(2, 5, 128, 16), (4, 2, 64, 16)])
@@ -2754,6 +2756,75 @@ def test_flex_attention_backward_stride_ordering(self, mode, permute_order, shap
                 f"Mode: {mode}, Stride order mismatch for {name}: grad {input_stride_order}, input {orig_stride_order}.",
             )

+    @supported_platform
+    def test_non_contiguous_last_dim(self, device):
+        """Test flex_attention with tensors having non contiguous last dimension."""
+        B, H, S, D = 4, 8, 128, 64
+        dtype = torch.float16 if device == "cuda" else torch.float32
+
+        def create_non_unit_stride_tensor():
+            tensor = torch.randn(
+                (B, H, S, D),
+                dtype=dtype,
+                device=device,
+            )
+            # Column major in last 2 dims
+            return tensor.transpose(-1, -2).contiguous().transpose(-1, -2)
+
+        # Create tensors with non-unit stride
+        q = create_non_unit_stride_tensor()
+        k = create_non_unit_stride_tensor()
+        v = create_non_unit_stride_tensor()
+
+        if not self.test_inference_only:
+            q.requires_grad_(True)
+            k.requires_grad_(True)
+            v.requires_grad_(True)
+
+        # Verify last dimension has non-unit stride
+        self.assertNotEqual(q.stride()[-1], 1)
+        self.assertNotEqual(k.stride()[-1], 1)
+        self.assertNotEqual(v.stride()[-1], 1)
+
+        # Create clones for different computation paths
+        q_ref, k_ref, v_ref = query_key_value_clones(q, k, v)
+        q_gold, k_gold, v_gold = query_key_value_clones(q, k, v, torch.float64)
+
+        # Run with different precisions and compilation
+        golden_out = flex_attention(q_gold, k_gold, v_gold)
+        ref_out = flex_attention(q_ref, k_ref, v_ref)
+
+        flex_compiled = torch.compile(flex_attention, fullgraph=True)
+        compiled_out = flex_compiled(q, k, v)
+
+        # Check forward pass correctness
+        print(compiled_out)
+        self._check_out(golden_out, ref_out, compiled_out)
+
+        if not self.test_inference_only:
+            # For backward pass testing
+            backward_grad = torch.randn_like(ref_out)
+
+            golden_out.backward(backward_grad.to(torch.float64))
+            ref_out.backward(backward_grad)
+            compiled_out.backward(backward_grad)
+
+            # Check backward pass correctness
+            self._check_out_and_grad(
+                golden_out,
+                ref_out,
+                compiled_out,
+                q_gold,
+                q_ref,
+                q,
+                k_gold,
+                k_ref,
+                k,
+                v_gold,
+                v_ref,
+                v,
+            )
+
     @supported_platform
     @common_utils.parametrize("compile", [True, False])
     def test_fully_masked_out_rows_0_check(self, device, compile: bool):
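
As a quick illustration (not part of the diff) of what the double-transpose trick in create_non_unit_stride_tensor produces, here are the strides for the test's 4×8×128×64 shape:

import torch

t = torch.randn(4, 8, 128, 64)
nc = t.transpose(-1, -2).contiguous().transpose(-1, -2)
print(nc.shape)            # torch.Size([4, 8, 128, 64])
print(nc.is_contiguous())  # False
print(nc.stride())         # (65536, 8192, 1, 128) -> last-dim stride is 128, not 1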

torch/_inductor/kernel/flex_attention.py

Lines changed: 2 additions & 5 deletions
@@ -1258,7 +1258,6 @@ def set_head_dim_values(
     )


-# TODO: We probably also need a layout constraint?
 @register_lowering(torch.ops.higher_order.flex_attention, type_promotion_kind=None)
 def flex_attention(
     query,
@@ -1413,11 +1412,9 @@ def flex_attention(
     else:
         kernel_options.setdefault("IS_DIVISIBLE", True)

-    # Reuse query strides for output layout despite different last dimension.
-    # This works because only the last dim differs and we check it is contiguous.
+    # NB it is okay that the v_head_dim is different
+    # We are using these to match fill order of the output.
     q_strides = query.get_stride()
-    assert q_strides[-1] == 1, "Query must be contiguous in the last dimension"
-
     # Construct output layout with strides matching the query.
     out_size = [B, Hq, seq_len_q, v_head_dim]
     out_strides = infer_dense_strides(out_size, q_strides)
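
The replacement comment notes that the query strides are only used to pick the fill order of the output, which is why a differing v_head_dim is fine. A rough sketch of that idea (my own illustration with a made-up helper, not Inductor's actual infer_dense_strides):

def dense_strides_like(size, reference_strides):
    # Build dense strides for `size` that fill memory in the same
    # dimension order as `reference_strides` (smallest stride first).
    order = sorted(range(len(size)), key=lambda d: reference_strides[d])
    strides = [0] * len(size)
    running = 1
    for d in order:
        strides[d] = running
        running *= size[d]
    return strides

# Query laid out "column major" in the last two dims, as in the new test:
q_strides = [65536, 8192, 1, 128]   # for a [4, 8, 128, 64] query
out_size = [4, 8, 128, 32]          # v_head_dim (32) differs from the query's 64
print(dense_strides_like(out_size, q_strides))  # [32768, 4096, 1, 128]

Because only the relative ordering of the query strides matters here, a query whose last-dim stride is not 1 is handled the same way, which is what the removed assert was blocking.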
