8000 [FlexAttention] Remove Old Constraint on lastdim strides by drisspg · Pull Request #151959 · pytorch/pytorch · GitHub
[go: up one dir, main page]

Skip to content

[FlexAttention] Remove Old Constraint on lastdim strides #151959

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 66 additions & 7 deletions test/inductor/test_flex_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def temp_float32_matmul_precision(precision: str):

def skip_on_cpu(test_func):
"""Decorator to skip tests that are not supported on CPU."""
decorated_func = skipCPUIf(True, "Not supported on CUDA")(test_func)
decorated_func = skipCPUIf(True, "Not supported on CPU")(test_func)
return decorated_func


Expand Down Expand Up @@ -2851,6 +2851,7 @@ def test_strided_backwards(self):
(1, 0, 2, 3), # Reverse order
(0, 2, 1, 3), # Mixed order
(2, 0, 1, 3), # Another mixed order
(0, 1, 3, 2), # Non contiguous last dim
],
)
@common_utils.parametrize("shape", [(2, 1, 128, 16), (4, 2, 64, 16)])
Expand Down Expand Up @@ -2899,12 +2900,7 @@ def test_flex_attention_stride_ordering(self, device, mode, permute_order, shape
@common_utils.parametrize("mode", ["eager", "inductor"])
@common_utils.parametrize(
"permute_order",
[
(0, 1, 2, 3),
(1, 0, 2, 3),
(0, 2, 1, 3),
(2, 0, 1, 3),
],
[(0, 1, 2, 3), (1, 0, 2, 3), (0, 2, 1, 3), (2, 0, 1, 3), (0, 1, 3, 2)],
)
@common_utils.parametrize("shape", [(2, 5, 128, 16), (4, 2, 64, 16)])
def test_flex_attention_backward_stride_ordering(
Expand Down Expand Up @@ -2948,6 +2944,69 @@ def test_flex_attention_backward_stride_ordering(
f"Mode: {mode}, Stride order mismatch for {name}: grad {input_stride_order}, input {orig_stride_order}.",
)

@supported_platform
def test_non_contiguous_last_dim(self, device):
"""Test flex_attention with tensors having non contiguous last dimension."""
B, H, D = 4, 8, 64
dtype = torch.float16 if device == "cuda" else torch.float32
for S in [16, 64]:

def column_major_tensor():
tensor = torch.randn(
(B, H, S, D),
dtype=dtype,
device=device,
)
# Column major in last 2 dims
return tensor.transpose(-1, -2).contiguous().transpose(-1, -2)

q = column_major_tensor()
k = column_major_tensor()
v = column_major_tensor()

requires_grad = device in DEVICE_SUPPORTS_BACKWARDS
if requires_grad:
q.requires_grad_(True)
k.requires_grad_(True)
v.requires_grad_(True)

self.assertNotEqual(q.stride()[-1], 1)
self.assertNotEqual(k.stride()[-1], 1)
self.assertNotEqual(v.stride()[-1], 1)

q_ref, k_ref, 10000 v_ref = query_key_value_clones(q, k, v)
q_gold, k_gold, v_gold = query_key_value_clones(q, k, v, torch.float64)

golden_out = flex_attention(q_gold, k_gold, v_gold)
ref_out = flex_attention(q_ref, k_ref, v_ref)

flex_compiled = torch.compile(flex_attention, fullgraph=True, dynamic=True)
compiled_out = flex_compiled(q, k, v)

self._check_out(golden_out, ref_out, compiled_out)

if requires_grad:
backward_grad = torch.randn_like(ref_out)

golden_out.backward(backward_grad.to(torch.float64))
ref_out.backward(backward_grad)
compiled_out.backward(backward_grad)

self._check_out_and_grad(
golden_out,
ref_out,
compiled_out,
q_gold,
q_ref,
q,
k_gold,
k_ref,
k,
v_gold,
v_ref,
v,
)

@supported_platform
@common_utils.parametrize("compile", [True, False])
def test_fully_masked_out_rows_0_check(self, device, compile: bool):
Expand Down
19 changes: 14 additions & 5 deletions torch/_inductor/kernel/flex_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -930,6 +930,15 @@ def check_cpu_supported():
return supported


def contiguous_last_dim(x):
"""Ensure that realized IR node has a contigous stride in the last dimension."""
strides = x.maybe_get_stride()
if strides and strides[-1] != 1:
contiguous_stride_order = list(reversed(range(len(x.get_size()))))
return ExternKernel.require_stride_order(x, contiguous_stride_order)
return x


def lower_cpu(
query,
key,
Expand Down Expand Up @@ -1092,6 +1101,9 @@ def convert_mask_graph_module(mask_graph):
if isinstance(item, TensorBox):
fake_buffers.append(item.data.data) # type: ignore[attr-defined]

# CPU kernel requires last dim to be contiguous
query, key, value = map(contiguous_last_dim, [query, key, value])

(
query,
key,
Expand Down Expand Up @@ -1258,7 +1270,6 @@ def set_head_dim_values(
)


# TODO: We probably also need a layout constraint?
@register_lowering(torch.ops.higher_order.flex_attention, type_promotion_kind=None)
def flex_attention(
query,
Expand Down Expand Up @@ -1413,11 +1424,9 @@ def flex_attention(
else:
kernel_options.setdefault("IS_DIVISIBLE", True)

# Reuse query strides for output layout despite different last dimension.
# This works because only the last dim differs and we check it is contiguous.
# NB it is okay that the v_head_dim is different
# We are using these to match fill order of the output.
q_strides = query.get_stride()
assert q_strides[-1] == 1, "Query must be contiguous in the last dimension"

# Construct output layout with strides matching the query.
out_size = [B, Hq, seq_len_q, v_head_dim]
out_strides = infer_dense_strides(out_size, q_strides)
Expand Down
Loading
0