[Intel GPU] Enable SDPA on XPU by DDEle · Pull Request #147614 · pytorch/pytorch · GitHub
[Intel GPU] Enable SDPA on XPU #147614

Closed
wants to merge 5 commits into from
2 changes: 2 additions & 0 deletions aten/src/ATen/native/mkldnn/xpu/Attention.cpp
@@ -227,4 +227,6 @@ _scaled_dot_product_fused_attention_overrideable_xpu(
philox_offset,
debug_attn_mask);
}

REGISTER_XPU_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_xpu);
} // namespace at::native
2 changes: 2 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -14842,6 +14842,7 @@
Meta: _fused_sdp_choice_meta
CPU, NestedTensorCPU: _fused_sdp_choice_cpp
CUDA, NestedTensorCUDA: _fused_sdp_choice_cuda
XPU: _fused_sdp_choice_xpu
tags: nondeterministic_seeded

- func: _scaled_dot_product_attention_math(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, Tensor? dropout_mask=None, *, float? scale=None, bool enable_gqa=False) -> (Tensor, Tensor)
@@ -14867,6 +14868,7 @@
- func: _scaled_dot_product_fused_attention_overrideable(Tensor query, Tensor key, Tensor value, Tensor? attn_bias=None, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
dispatch:
CompositeExplicitAutograd: _scaled_dot_product_fused_attention_overrideable
XPU: _scaled_dot_product_fused_attention_overrideable_xpu
Collaborator:

@drisspg is this the right place to override it?

Contributor:

I don't think so, at least not from within core. This op is meant as a generic entry point that other backends can register against through PrivateUse1.
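
(For readers following along, a minimal sketch of what "register through PrivateUse1" means here; the stub below, its name, and its math-reference body are illustrative assumptions, not part of this PR, and running it end to end would require an actual PrivateUse1 device backend.)

```python
import torch

# Hypothetical out-of-tree registration of the generic overrideable SDPA op
# under the PrivateUse1 dispatch key. A real backend would call its fused
# kernel; this stub reuses the math reference only to show the expected
# output layout: (output, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k,
# philox_seed, philox_offset, debug_attn_mask).
@torch.library.impl("aten::_scaled_dot_product_fused_attention_overrideable", "PrivateUse1")
def sdpa_overrideable_stub(query, key, value, attn_bias=None, dropout_p=0.0,
                           is_causal=False, return_debug_mask=False, *, scale=None):
    out, _ = torch.ops.aten._scaled_dot_product_attention_math(
        query, key, value, attn_mask=attn_bias, dropout_p=dropout_p,
        is_causal=is_causal, scale=scale)
    B, H, S_Q, S_KV = query.size(0), query.size(1), query.size(2), key.size(2)
    logsumexp = query.new_empty((B, H, S_Q), dtype=torch.float)
    seed = torch.empty((), dtype=torch.long)
    offset = torch.empty((), dtype=torch.long)
    empty = query.new_empty(0)
    return out, logsumexp, empty, empty, S_Q, S_KV, seed, offset, empty
```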

Collaborator:

@albanD, @drisspg, the original idea was to avoid adding a new op. Would you be okay with adding an operation like _scaled_dot_product_mkldnn_attention, similar to _scaled_dot_product_cudnn_attention?

Contributor:

Confirmed things with Alban: it is okay to add the XPU dispatch to any of the available SDPA APIs. Whichever one most closely aligns with what you need for forward/backward makes the most sense.

In that context this change seems fine.
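
(For context, a rough usage sketch of what this dispatch addition enables; it assumes an XPU-enabled build with an available xpu device and mirrors the new tests below: with zero dropout and fp16/bf16/fp32 inputs, the fused-SDP choice on XPU resolves to SDPBackend.OVERRIDEABLE.)

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# Assumes an XPU-enabled PyTorch build with an "xpu" device present.
q = torch.randn(2, 8, 128, 64, device="xpu", dtype=torch.half)
k = torch.randn(2, 8, 128, 64, device="xpu", dtype=torch.half)
v = torch.randn(2, 8, 128, 64, device="xpu", dtype=torch.half)

# With this change, the fused-SDP choice on XPU picks the overrideable
# backend for supported dtypes with zero dropout, else falls back to math.
assert torch._fused_sdp_choice(q, k, v) == SDPBackend.OVERRIDEABLE.value

with sdpa_kernel(backends=[SDPBackend.OVERRIDEABLE]):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```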

tags: nondeterministic_seeded

- func: _scaled_dot_product_flash_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, Tensor philox_seed, Tensor philox_offset, *, float? scale=None) -> (Tensor grad_query, Tensor grad_key, Tensor grad_value)
171 changes: 171 additions & 0 deletions test/test_transformers.py
@@ -3838,6 +3838,176 @@ def rand_nt(sequence_list, num_heads, head_dim):
}
)

class TestSDPAXpuOnly(NNTestCase):
""" Used to test XPU only functionality of scaled_dot_product_attention
Mostly migrate from TestSDPACudaOnly in test/test_transformers.py

Note that as SDPBackend.OVERRIDEABLE is not managed by sdpa_kernel so that
math ref has to be called explicitly via torch.ops.aten._scaled_dot_product_attention_math.
"""

@parametrize("type", ["dense"])
@parametrize("dropout", [0.0, 0.7])
@parametrize("dtype", [torch.float64, torch.float32, torch.bfloat16, torch.half])
@skipIfTorchDynamo()
def test_fused_sdp_choice_xpu(self, device, type: str, dropout: float, dtype: torch.dtype):
# Migrate from test_fused_sdp_choice_cpu
make_tensor = partial(rand_sdpa_tensor, type=type, device=device, dtype=dtype)
size = SdpaShape(2, 8, 128, 64)
q, k, v = make_tensor(size), make_tensor(size), make_tensor(size)
if dropout > 0.0 or dtype not in [torch.float32, torch.bfloat16, torch.float16]:
assert torch._fused_sdp_choice(q, k, v, dropout_p=dropout) == SDPBackend.MATH.value
else:
assert torch._fused_sdp_choice(q, k, v, dropout_p=dropout) == SDPBackend.OVERRIDEABLE.value

def test_fused_attention_different_dk_dv(self, device):
dtype = torch.bfloat16
make_tensor = partial(torch.rand, device=device, dtype=dtype, requires_grad=True)
batch, num_heads, head_dim_k, head_dim_v = 32, 16, 128, 64
q_shape = SdpaShape(batch, num_heads, 1, head_dim_k)
k_shape = SdpaShape(batch, num_heads, 2, head_dim_k)
v_shape = SdpaShape(batch, num_heads, 2, head_dim_v)
query, key, value = make_tensor(q_shape), make_tensor(k_shape), make_tensor(v_shape)

# test that we do not dispatch to onednn for an unsupported case
actual = F.scaled_dot_product_attention(
query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False)

math_ref = torch.ops.aten._scaled_dot_product_attention_math(
query.float(), key.float(), value.float(), attn_mask=None, dropout_p=0.0, is_causal=False)[0]

self.assertEqual(actual.contiguous(), math_ref.contiguous().to(dtype), atol=1e-3, rtol=1e-2)

def test_onednn_attention_fail_d256(self, device):
# Test that onednn graph attention dispatching correctly bails out on d > 256
b, h = 1, 2
s_q, s_kv = 128, 128
d_qk, d_v = 512, 512

q = torch.randn(b, h, s_q, d_qk, device=device, dtype=torch.bfloat16)
k = torch.randn(b, h, s_kv, d_qk, device=device, dtype=torch.bfloat16)
v = torch.randn(b, h, s_kv, d_v, device=device, dtype=torch.bfloat16)

with sdpa_kernel(backends=[SDPBackend.OVERRIDEABLE]):
with self.assertRaisesRegex(RuntimeError, "No available kernel."):
_ = F.scaled_dot_product_attention(q, k, v)

@parametrize("type", ["dense"])
@parametrize("is_contiguous", [True, False])
def test_scaled_dot_product_attention_fused_kernels_packed(self, device, type: str, is_contiguous: bool):
make_tensor = partial(rand_sdpa_tensor, type=type, device=device, dtype=torch.float16, packed=True)

batch_size, seq_len, num_heads, head_dim = 32, 64, 16, 64
shape = SdpaShape(batch_size, num_heads, seq_len, head_dim)

# Test Packed
qkv = make_tensor(shape)
query, key, value = qkv.chunk(3, dim=-1)

query = query.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)

if is_contiguous:
query = query.contiguous()
key = key.contiguous()
value = value.contiguous()

with sdpa_kernel(backends=[SDPBackend.OVERRIDEABLE]):
actual = torch.nn.functional.scaled_dot_product_attention(
query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False)
math_ref = torch.ops.aten._scaled_dot_product_attention_math(
query.contiguous(), key.contiguous(), value.contiguous(), attn_mask=None, dropout_p=0.0, is_causal=False)[0]

self.assertEqual(actual.contiguous(), math_ref.contiguous(), atol=2e-3, rtol=1e-2)

@parametrize("fused_kernel", [SDPBackend.MATH, SDPBackend.OVERRIDEABLE])
@parametrize("dtype", [torch.half, torch.bfloat16, torch.float32])
@parametrize("batch_size,n_head,q_size,kv_size,head_dim", [
(2, 5, 9216, 9216, 64),
(2, 5, 9216, 77, 64),
(2, 10, 2304, 2304, 64),
(2, 10, 2304, 77, 64),
(2, 20, 576, 576, 64),
(2, 20, 576, 77, 64),
(2, 20, 144, 144, 64),
(2, 20, 144, 77, 64),
(1, 32, 1, 32, 128),
(4, 32, 1, 32, 128),
(1, 32, 32, 32, 128),
(4, 32, 32, 32, 128),
(1, 32, 2016, 2016, 128),
(4, 32, 2016, 2016, 128),
])
@parametrize("mask_type", ["float", "causal"])
@parametrize("train", [False])
def test_scaled_dot_product_fused_attention_mask_vs_math(
self,
device,
fused_kernel,
dtype,
batch_size,
q_size,
kv_size,
n_head,
head_dim,
mask_type,
train,
):
# Migrate from TestSDPACpuOnly
tol = Tolerances(1e-5, 5e-6)
if dtype is torch.bfloat16:
tol = Tolerances(5e-2, 5e-2)
if dtype is torch.float16:
tol = Tolerances(1e-2, 1e-2)
mask_shape = [batch_size, 1, 1, kv_size]
make_tensor = partial(rand_sdpa_tensor, type="dense", device=device, dtype=dtype, requires_grad=False)
q_shape = SdpaShape(batch_size, n_head, q_size, head_dim)
kv_shape = SdpaShape(batch_size, n_head, kv_size, head_dim)
q = make_tensor(q_shape)
k = make_tensor(kv_shape)
v = make_tensor(kv_shape)
q2, k2, v2 = q.clone(), k.clone(), v.clone()

if train:
q.requires_grad_(True)
k.requires_grad_(True)
v.requires_grad_(True)
q2.requires_grad_(True)
k2.requires_grad_(True)
v2.requires_grad_(True)

# (B, nh, T, hs)
q = q.view(batch_size, q_size, n_head, head_dim).transpose(1, 2)
k = k.view(batch_size, kv_size, n_head, head_dim).transpose(1, 2)
v = v.view(batch_size, kv_size, n_head, head_dim).transpose(1, 2)
attn_mask = None
is_causal = False
if mask_type == "bool":
attn_mask = torch.randint(0, 2, size=mask_shape, dtype=torch.bool, device=device)
elif mask_type == "float":
attn_mask = torch.randn(mask_shape, dtype=dtype, device=device)
elif mask_type == "causal":
is_causal = True

q2, k2, v2 = q2.float(), k2.float(), v2.float()
q2 = q2.view(batch_size, q_size, n_head, head_dim).transpose(1, 2)
k2 = k2.view(batch_size, kv_size, n_head, head_dim).transpose(1, 2)
v2 = v2.view(batch_size, kv_size, n_head, head_dim).transpose(1, 2)
attn_mask2 = attn_mask.float() if attn_mask is not None else None

if fused_kernel == SDPBackend.MATH:
actual = torch.ops.aten._scaled_dot_product_attention_math(
q, k, v, attn_mask=attn_mask, dropout_p=0.0, is_causal=is_causal)[0]
elif fused_kernel == SDPBackend.OVERRIDEABLE:
actual = torch.ops.aten._scaled_dot_product_fused_attention_overrideable(
q, k, v, attn_bias=attn_mask, dropout_p=0.0, is_causal=is_causal)[0]

math_ref = torch.ops.aten._scaled_dot_product_attention_math(
q2, k2, v2, attn_mask=attn_mask2, dropout_p=0.0, is_causal=is_causal)[0]

self.assertEqual(actual.float(), math_ref, atol=tol.atol, rtol=tol.rtol)


class TestAttnBias(NNTestCase):

@@ -4014,6 +4184,7 @@ def test_is_causal_and_mask_fails(self, device):
instantiate_device_type_tests(TestSDPACudaOnly, globals(), only_for=("cuda"))
instantiate_device_type_tests(TestSDPACpuOnly, globals(), only_for=("cpu"))
instantiate_device_type_tests(TestAttnBias, globals(), only_for=device_types)
instantiate_device_type_tests(TestSDPAXpuOnly, globals(), only_for="xpu", allow_xpu=True)

if __name__ == '__main__':
run_tests()
10 changes: 10 additions & 0 deletions torch/_inductor/lowering.py
@@ -2775,6 +2775,16 @@ def is_aligned(x):
sdpa_constraint,
warn=False,
)
make_fallback(
aten._scaled_dot_product_fused_attention_overrideable.default,
sdpa_constraint,
warn=False,
)
make_fallback(
aten._scaled_dot_product_fused_attention_overrideable_backward.default,
sdpa_constraint,
warn=False,
)
make_fallback(aten._flash_attention_forward.default, sdpa_constraint)
make_fallback(aten._flash_attention_backward.default, sdpa_constraint)
make_fallback(aten._efficient_attention_forward.default, sdpa_constraint)
41 changes: 41 additions & 0 deletions torch/_meta_registrations.py
@@ -5607,6 +5607,47 @@ def meta__scaled_dot_product_cudnn_attention(
)


@register_meta([aten._scaled_dot_product_fused_attention_overrideable])
def meta__scaled_dot_product_fused_attention_overrideable(
query: Tensor,
key: Tensor,
value: Tensor,
attn_bias: Optional[Tensor] = None,
dropout_p: float = 0.0,
is_causal: bool = False,
return_debug_mask: bool = False,
scale: Optional[float] = None,
):
B = query.size(0)
H = query.size(1)
S_Q = query.size(2)
S_KV = key.size(2)
D_V = value.size(-1)

res = torch.empty((B, H, S_Q, D_V), dtype=query.dtype, device=query.device)
logsum_exp = torch.empty(
(B, H, S_Q),
dtype=torch.float,
device=query.device,
)

# See Note [Seed and Offset]
seed = torch.empty((), dtype=torch.long, device="meta")
offset = torch.empty((), dtype=torch.long, device="meta")

return (
res,
logsum_exp,
None,
None,
S_Q,
S_KV,
seed,
offset,
None,
)


@register_meta(
[
aten._scaled_dot_product_flash_attention_backward,
2 changes: 2 additions & 0 deletions torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.h
@@ -35,6 +35,8 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__pdist_backward(AtenTensorHandle
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__pdist_forward(AtenTensorHandle self, double p, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__scaled_dot_product_flash_attention_for_cpu(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, double dropout_p, int32_t is_causal, AtenTensorHandle* attn_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__scaled_dot_product_flash_attention_for_cpu_backward(AtenTensorHandle grad_out, AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle out, AtenTensorHandle logsumexp, double dropout_p, int32_t is_causal, AtenTensorHandle* attn_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__scaled_dot_product_fused_attention_overrideable(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* attn_bias, double dropout_p, int32_t is_causal, int32_t return_debug_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, int64_t* ret4, int64_t* ret5, AtenTensorHandle* ret6, AtenTensorHandle* ret7, AtenTensorHandle* ret8);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__scaled_dot_product_fused_attention_overrideable_backward(AtenTensorHandle grad_out, AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle attn_bias, const int32_t* grad_input_mask, int64_t grad_input_mask_len_, AtenTensorHandle out, AtenTensorHandle logsumexp, AtenTensorHandle cum_seq_q, AtenTensorHandle cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, int32_t is_causal, AtenTensorHandle philox_seed, AtenTensorHandle philox_offset, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__segment_reduce_backward(AtenTensorHandle grad, AtenTensorHandle output, AtenTensorHandle data, const char* reduce, AtenTensorHandle* lengths, AtenTensorHandle* offsets, int64_t axis, double* initial, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__to_sparse(AtenTensorHandle self, int32_t* layout, const int64_t** blocksize, int64_t blocksize_len_, int64_t* dense_dim, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__trilinear(AtenTensorHandle i1, AtenTensorHandle i2, AtenTensorHandle i3, const int64_t* expand1, int64_t expand1_len_, const int64_t* expand2, int64_t expand2_len_, const int64_t* expand3, int64_t expand3_len_, const int64_t* sumdim, int64_t sumdim_len_, int64_t unroll_dim, AtenTensorHandle* ret0);
2 changes: 2 additions & 0 deletions torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h
@@ -41,6 +41,8 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__scaled_dot_product_efficient_a
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__scaled_dot_product_efficient_attention_backward(AtenTensorHandle grad_out_, AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle attn_bias, AtenTensorHandle out, AtenTensorHandle logsumexp, AtenTensorHandle philox_seed, AtenTensorHandle philox_offset, double dropout_p, const int32_t* grad_input_mask, int64_t grad_input_mask_len_, int32_t is_causal, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__scaled_dot_product_flash_attention(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, double dropout_p, int32_t is_causal, int32_t return_debug_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, int64_t* ret4, int64_t* ret5, AtenTensorHandle* ret6, AtenTensorHandle* ret7, AtenTensorHandle* ret8);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__scaled_dot_product_flash_attention_backward(AtenTensorHandle grad_out, AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle out, AtenTensorHandle logsumexp, AtenTensorHandle cum_seq_q, AtenTensorHandle cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, int32_t is_causal, AtenTensorHandle philox_seed, AtenTensorHandle philox_offset, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__scaled_dot_product_fused_attention_overrideable(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* attn_bias, double dropout_p, int32_t is_causal, int32_t return_debug_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, int64_t* ret4, int64_t* ret5, AtenTensorHandle* ret6, AtenTensorHandle* ret7, AtenTensorHandle* ret8);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__scaled_dot_product_fused_attention_overrideable_backward(AtenTensorHandle grad_out, AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle attn_bias, const int32_t* grad_input_mask, int64_t grad_input_mask_len_, AtenTensorHandle out, AtenTensorHandle logsumexp, AtenTensorHandle cum_seq_q, AtenTensorHandle cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, int32_t is_causal, AtenTensorHandle philox_seed, AtenTensorHandle philox_offset, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__scaled_mm(AtenTensorHandle self, AtenTensorHandle mat2, AtenTensorHandle scale_a, AtenTensorHandle scale_b, AtenTensorHandle* bias, AtenTensorHandle* scale_result, int32_t* out_dtype, int32_t use_fast_accum, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__scaled_mm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle mat2, AtenTensorHandle scale_a, AtenTensorHandle scale_b, AtenTensorHandle* bias, AtenTensorHandle* scale_result, int32_t* out_dtype, int32_t use_fast_accum);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__segment_reduce_backward(AtenTensorHandle grad, AtenTensorHandle output, AtenTensorHandle data, const char* reduce, AtenTensorHandle* lengths, AtenTensorHandle* offsets, int64_t axis, double* initial, AtenTensorHandle* ret0);
Expand Down