[Intel GPU] Enable SDPA on XPU (#147614) · pytorch/pytorch@c21dc11 · GitHub

Commit c21dc11

DDEle authored and pytorchmergebot committed
[Intel GPU] Enable SDPA on XPU (#147614)
Motivation: This PR is part of the OneDNN upstreaming plan described in #114848 [(comment)](#114848 (comment)). SDPA support is provided through the overrideable variant on the XPU backend. Besides the added `Attention.cpp` file, `Graph.h` is added to hold utilities for the OneDNN graph, including those for kernel/compiled-graph caching. In addition, a selection of test cases from `test/test_transformers.py` is copied into the new `test/xpu/test_transformers.py` and modified accordingly to provide additional coverage beyond `./third_party/torch-xpu-ops/test/xpu/test_ops_xpu.py`. Depends on the OneDNN v3.7 upgrade in #147498. Depends on the BUILD_GRAPH switch in #147608. Pull Request resolved: #147614 Approved by: https://github.com/jansel, https://github.com/EikanWang
1 parent b17f522 commit c21dc11
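
As a usage sketch (not part of the diff below), the overrideable backend can be exercised the same way the new tests do. This assumes an XPU-enabled PyTorch build with the OneDNN graph backend and an available Intel GPU device:

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

device = "xpu"  # assumes an Intel GPU and an XPU-enabled build
q = torch.randn(2, 8, 128, 64, device=device, dtype=torch.float16)
k = torch.randn(2, 8, 128, 64, device=device, dtype=torch.float16)
v = torch.randn(2, 8, 128, 64, device=device, dtype=torch.float16)

# Route SDPA through the overrideable (OneDNN graph) backend added by this PR.
with sdpa_kernel(backends=[SDPBackend.OVERRIDEABLE]):
    out = F.scaled_dot_product_attention(q, k, v)

# Compare against the math reference, called explicitly as in the new tests.
ref = torch.ops.aten._scaled_dot_product_attention_math(
    q.float(), k.float(), v.float())[0]
torch.testing.assert_close(out.float(), ref, atol=2e-3, rtol=1e-2)

The comparison pattern mirrors TestSDPAXpuOnly below, which calls the math reference directly through torch.ops.aten._scaled_dot_product_attention_math.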

File tree

10 files changed: +235 −0 lines changed

aten/src/ATen/native/mkldnn/xpu/Attention.cpp

Lines changed: 2 additions & 0 deletions
@@ -227,4 +227,6 @@ _scaled_dot_product_fused_attention_overrideable_xpu(
 philox_offset,
 debug_attn_mask);
 }
+
+REGISTER_XPU_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_xpu);
 } // namespace at::native

aten/src/ATen/native/native_functions.yaml

Lines changed: 2 additions & 0 deletions
@@ -14841,6 +14841,7 @@
     Meta: _fused_sdp_choice_meta
     CPU, NestedTensorCPU: _fused_sdp_choice_cpp
     CUDA, NestedTensorCUDA: _fused_sdp_choice_cuda
+    XPU: _fused_sdp_choice_xpu
   tags: nondeterministic_seeded

 - func: _scaled_dot_product_attention_math(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, Tensor? dropout_mask=None, *, float? scale=None, bool enable_gqa=False) -> (Tensor, Tensor)
@@ -14866,6 +14867,7 @@
 - func: _scaled_dot_product_fused_attention_overrideable(Tensor query, Tensor key, Tensor value, Tensor? attn_bias=None, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
   dispatch:
     CompositeExplicitAutograd: _scaled_dot_product_fused_attention_overrideable
+    XPU: _scaled_dot_product_fused_attention_overrideable_xpu
   tags: nondeterministic_seeded

 - func: _scaled_dot_product_flash_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, Tensor philox_seed, Tensor philox_offset, *, float? scale=None) -> (Tensor grad_query, Tensor grad_key, Tensor grad_value)
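
With the XPU dispatch entries above, torch._fused_sdp_choice reports which SDPA backend would be selected for XPU tensors, as the new test_fused_sdp_choice_xpu checks. A minimal sketch (not part of this commit), assuming an XPU device is available:

import torch
from torch.nn.attention import SDPBackend

# A supported dtype without dropout is expected to pick the overrideable backend.
q = torch.randn(2, 8, 128, 64, device="xpu", dtype=torch.float16)
k = torch.randn(2, 8, 128, 64, device="xpu", dtype=torch.float16)
v = torch.randn(2, 8, 128, 64, device="xpu", dtype=torch.float16)

choice = torch._fused_sdp_choice(q, k, v, dropout_p=0.0)
print(choice == SDPBackend.OVERRIDEABLE.value)  # expected True per the new test

choice = torch._fused_sdp_choice(q, k, v, dropout_p=0.7)
print(choice == SDPBackend.MATH.value)  # dropout falls back to the math backend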

test/test_transformers.py

Lines changed: 171 additions & 0 deletions
@@ -3838,6 +3838,176 @@ def rand_nt(sequence_list, num_heads, head_dim):
 }
 )

+class TestSDPAXpuOnly(NNTestCase):
+    """ Used to test XPU only functionality of scaled_dot_product_attention
+    Mostly migrate from TestSDPACudaOnly in test/test_transformers.py
+
+    Note that as SDPBackend.OVERRIDEABLE is not managed by sdpa_kernel so that
+    math ref has to be called explicitly via torch.ops.aten._scaled_dot_product_attention_math.
+    """
+
+    @parametrize("type", ["dense"])
+    @parametrize("dropout", [0.0, 0.7])
+    @parametrize("dtype", [torch.float64, torch.float32, torch.bfloat16, torch.half])
+    @skipIfTorchDynamo()
+    def test_fused_sdp_choice_xpu(self, device, type: str, dropout: float, dtype: torch.dtype):
+        # Migrate from test_fused_sdp_choice_cpu
+        make_tensor = partial(rand_sdpa_tensor, type=type, device=device, dtype=dtype)
+        size = SdpaShape(2, 8, 128, 64)
+        q, k, v = make_tensor(size), make_tensor(size), make_tensor(size)
+        if dropout > 0.0 or dtype not in [torch.float32, torch.bfloat16, torch.float16]:
+            assert torch._fused_sdp_choice(q, k, v, dropout_p=dropout) == SDPBackend.MATH.value
+        else:
+            assert torch._fused_sdp_choice(q, k, v, dropout_p=dropout) == SDPBackend.OVERRIDEABLE.value
+
+    def test_fused_attention_different_dk_dv(self, device):
+        dtype = torch.bfloat16
+        make_tensor = partial(torch.rand, device=device, dtype=dtype, requires_grad=True)
+        batch, num_heads, head_dim_k, head_dim_v = 32, 16, 128, 64
+        q_shape = SdpaShape(batch, num_heads, 1, head_dim_k)
+        k_shape = SdpaShape(batch, num_heads, 2, head_dim_k)
+        v_shape = SdpaShape(batch, num_heads, 2, head_dim_v)
+        query, key, value = make_tensor(q_shape), make_tensor(k_shape), make_tensor(v_shape)
+
+        # test that we do not dispatch to onednn for an unsupported case
+        actual = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False)
+
+        math_ref = torch.ops.aten._scaled_dot_product_attention_math(
+            query.float(), key.float(), value.float(), attn_mask=None, dropout_p=0.0, is_causal=False)[0]
+
+        self.assertEqual(actual.contiguous(), math_ref.contiguous().to(dtype), atol=1e-3, rtol=1e-2)
+
+    def test_onednn_attention_fail_d256(self, device):
+        # Test that onednn graph attention dispatching correctly bails out on d > 256
+        b, h = 1, 2
+        s_q, s_kv = 128, 128
+        d_qk, d_v = 512, 512
+
+        q = torch.randn(b, h, s_q, d_qk, device=device, dtype=torch.bfloat16)
+        k = torch.randn(b, h, s_kv, d_qk, device=device, dtype=torch.bfloat16)
+        v = torch.randn(b, h, s_kv, d_v, device=device, dtype=torch.bfloat16)
+
+        with sdpa_kernel(backends=[SDPBackend.OVERRIDEABLE]):
+            with self.assertRaisesRegex(RuntimeError, "No available kernel."):
+                _ = F.scaled_dot_product_attention(q, k, v)
+
+    @parametrize("type", ["dense"])
+    @parametrize("is_contiguous", [True, False])
+    def test_scaled_dot_product_attention_fused_kernels_packed(self, device, type: str, is_contiguous: bool):
+        make_tensor = partial(rand_sdpa_tensor, type=type, device=device, dtype=torch.float16, packed=True)
+
+        batch_size, seq_len, num_heads, head_dim = 32, 64, 16, 64
+        shape = SdpaShape(batch_size, num_heads, seq_len, head_dim)
+
+        # Test Packed
+        qkv = make_tensor(shape)
+        query, key, value = qkv.chunk(3, dim=-1)
+
+        query = query.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
+
+        if is_contiguous:
+            query = query.contiguous()
+            key = key.contiguous()
+            value = value.contiguous()
+
+        with sdpa_kernel(backends=[SDPBackend.OVERRIDEABLE]):
+            actual = torch.nn.functional.scaled_dot_product_attention(
+                query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False)
+        math_ref = torch.ops.aten._scaled_dot_product_attention_math(
+            query.contiguous(), key.contiguous(), value.contiguous(), attn_mask=None, dropout_p=0.0, is_causal=False)[0]

+        self.assertEqual(actual.contiguous(), math_ref.contiguous(), atol=2e-3, rtol=1e-2)
+
+    @parametrize("fused_kernel", [SDPBackend.MATH, SDPBackend.OVERRIDEABLE])
+    @parametrize("dtype", [torch.half, torch.bfloat16, torch.float32])
+    @parametrize("batch_size,n_head,q_size,kv_size,head_dim", [
+        (2, 5, 9216, 9216, 64),
+        (2, 5, 9216, 77, 64),
+        (2, 10, 2304, 2304, 64),
+        (2, 10, 2304, 77, 64),
+        (2, 20, 576, 576, 64),
+        (2, 20, 576, 77, 64),
+        (2, 20, 144, 144, 64),
+        (2, 20, 144, 77, 64),
+        (1, 32, 1, 32, 128),
+        (4, 32, 1, 32, 128),
+        (1, 32, 32, 32, 128),
+        (4, 32, 32, 32, 128),
+        (1, 32, 2016, 2016, 128),
+        (4, 32, 2016, 2016, 128),
+    ])
+    @parametrize("mask_type", ["float", "causal"])
+    @parametrize("train", [False])
+    def test_scaled_dot_product_fused_attention_mask_vs_math(
+        self,
+        device,
+        fused_kernel,
+        dtype,
+        batch_size,
+        q_size,
+        kv_size,
+        n_head,
+        head_dim,
+        mask_type,
+        train,
+    ):
+        # Migrate from TestSDPACpuOnly
+        tol = Tolerances(1e-5, 5e-6)
+        if dtype is torch.bfloat16:
+            tol = Tolerances(5e-2, 5e-2)
+        if dtype is torch.float16:
+            tol = Tolerances(1e-2, 1e-2)
+        mask_shape = [batch_size, 1, 1, kv_size]
+        make_tensor = partial(rand_sdpa_tensor, type="dense", device=device, dtype=dtype, requires_grad=False)
+        q_shape = SdpaShape(batch_size, n_head, q_size, head_dim)
+        kv_shape = SdpaShape(batch_size, n_head, kv_size, head_dim)
+        q = make_tensor(q_shape)
+        k = make_tensor(kv_shape)
+        v = make_tensor(kv_shape)
+        q2, k2, v2 = q.clone(), k.clone(), v.clone()
+
+        if train:
+            q.requires_grad_(True)
+            k.requires_grad_(True)
+            v.requires_grad_(True)
+            q2.requires_grad_(True)
+            k2.requires_grad_(True)
+            v2.requires_grad_(True)
+
+        # (B, nh, T, hs)
+        q = q.view(batch_size, q_size, n_head, head_dim).transpose(1, 2)
+        k = k.view(batch_size, kv_size, n_head, head_dim).transpose(1, 2)
+        v = v.view(batch_size, kv_size, n_head, head_dim).transpose(1, 2)
+        attn_mask = None
+        is_causal = False
+        if mask_type == "bool":
+            attn_mask = torch.randint(0, 2, size=mask_shape, dtype=torch.bool, device=device)
+        elif mask_type == "float":
+            attn_mask = torch.randn(mask_shape, dtype=dtype, device=device)
+        elif mask_type == "causal":
+            is_causal = True
+
+        q2, k2, v2 = q2.float(), k2.float(), v2.float()
+        q2 = q2.view(batch_size, q_size, n_head, head_dim).transpose(1, 2)
+        k2 = k2.view(batch_size, kv_size, n_head, head_dim).transpose(1, 2)
+        v2 = v2.view(batch_size, kv_size, n_head, head_dim).transpose(1, 2)
+        attn_mask2 = attn_mask.float() if attn_mask is not None else None
+
+        if fused_kernel == SDPBackend.MATH:
+            actual = torch.ops.aten._scaled_dot_product_attention_math(
+                q, k, v, attn_mask=attn_mask, dropout_p=0.0, is_causal=is_causal)[0]
+        elif fused_kernel == SDPBackend.OVERRIDEABLE:
+            actual = torch.ops.aten._scaled_dot_product_fused_attention_overrideable(
+                q, k, v, attn_bias=attn_mask, dropout_p=0.0, is_causal=is_causal)[0]
+
+        math_ref = torch.ops.aten._scaled_dot_product_attention_math(
+            q2, k2, v2, attn_mask=attn_mask2, dropout_p=0.0, is_causal=is_causal)[0]
+
+        self.assertEqual(actual.float(), math_ref, atol=tol.atol, rtol=tol.rtol)
+

 class TestAttnBias(NNTestCase):

@@ -4014,6 +4184,7 @@ def test_is_causal_and_mask_fails(self, device):
 instantiate_device_type_tests(TestSDPACudaOnly, globals(), only_for=("cuda"))
 instantiate_device_type_tests(TestSDPACpuOnly, globals(), only_for=("cpu"))
 instantiate_device_type_tests(TestAttnBias, globals(), only_for=device_types)
+instantiate_device_type_tests(TestSDPAXpuOnly, globals(), only_for="xpu", allow_xpu=True)

 if __name__ == '__main__':
     run_tests()

torch/_inductor/lowering.py

Lines changed: 10 additions & 0 deletions
@@ -2779,6 +2779,16 @@ def is_aligned(x):
     sdpa_constraint,
     warn=False,
 )
+make_fallback(
+    aten._scaled_dot_product_fused_attention_overrideable.default,
+    sdpa_constraint,
+    warn=False,
+)
+make_fallback(
+    aten._scaled_dot_product_fused_attention_overrideable_backward.default,
+    sdpa_constraint,
+    warn=False,
+)
 make_fallback(aten._flash_attention_forward.default, sdpa_constraint)
 make_fallback(aten._flash_attention_backward.default, sdpa_constraint)
 make_fallback(aten._efficient_attention_forward.default, sdpa_constraint)
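
These make_fallback registrations keep the overrideable SDPA ops as extern kernel calls under torch.compile instead of having Inductor try to lower them. A rough sketch (not part of this commit) of how that path would be hit, assuming an XPU-enabled build and device:

import torch
import torch.nn.functional as F

def attn(q, k, v):
    return F.scaled_dot_product_attention(q, k, v)

# Inductor keeps the fused SDPA op as a fallback (extern kernel call)
# rather than generating device code for it.
compiled_attn = torch.compile(attn)

q, k, v = (torch.randn(2, 8, 128, 64, device="xpu", dtype=torch.float16) for _ in range(3))
out = compiled_attn(q, k, v)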

torch/_meta_registrations.py

Lines changed: 41 additions & 0 deletions
@@ -5607,6 +5607,47 @@ def meta__scaled_dot_product_cudnn_attention(
     )


+@register_meta([aten._scaled_dot_product_fused_attention_overrideable])
+def meta__scaled_dot_product_fused_attention_overrideable(
+    query: Tensor,
+    key: Tensor,
+    value: Tensor,
+    attn_bias: Optional[Tensor] = None,
+    dropout_p: float = 0.0,
+    is_causal: bool = False,
+    return_debug_mask: bool = False,
+    scale: Optional[float] = None,
+):
+    B = query.size(0)
+    H = query.size(1)
+    S_Q = query.size(2)
+    S_KV = key.size(2)
+    D_V = value.size(-1)
+
+    res = torch.empty((B, H, S_Q, D_V), dtype=query.dtype, device=query.device)
+    logsum_exp = torch.empty(
+        (B, H, S_Q),
+        dtype=torch.float,
+        device=query.device,
+    )
+
+    # See Note [Seed and Offset]
+    seed = torch.empty((), dtype=torch.long, device="meta")
+    offset = torch.empty((), dtype=torch.long, device="meta")
+
+    return (
+        res,
+        logsum_exp,
+        None,
+        None,
+        S_Q,
+        S_KV,
+        seed,
+        offset,
+        None,
+    )
+
+
 @register_meta(
     [
         aten._scaled_dot_product_flash_attention_backward,
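
The meta registration above supplies the output shapes and dtypes needed for tracing with fake/meta tensors: the output is (B, H, S_Q, D_V) in the query dtype and the logsumexp is (B, H, S_Q) in float32. A small sketch (not part of this commit) of that contract, assuming the op can be invoked directly on meta tensors:

import torch

B, H, S_Q, S_KV, D_QK, D_V = 2, 8, 128, 128, 64, 64
q = torch.empty(B, H, S_Q, D_QK, device="meta", dtype=torch.float16)
k = torch.empty(B, H, S_KV, D_QK, device="meta", dtype=torch.float16)
v = torch.empty(B, H, S_KV, D_V, device="meta", dtype=torch.float16)

out, logsumexp, *_ = torch.ops.aten._scaled_dot_product_fused_attention_overrideable(q, k, v)
print(out.shape, out.dtype)              # (B, H, S_Q, D_V), query dtype
print(logsumexp.shape, logsumexp.dtype)  # (B, H, S_Q), float32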

torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.h

Lines changed: 2 additions & 0 deletions
@@ -35,6 +35,8 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__pdist_backward(AtenTensorHandle
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__pdist_forward(AtenTensorHandle self, double p, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__scaled_dot_product_flash_attention_for_cpu(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, double dropout_p, int32_t is_causal, AtenTensorHandle* attn_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__scaled_dot_product_flash_attention_for_cpu_backward(AtenTensorHandle grad_out, AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle out, AtenTensorHandle logsumexp, double dropout_p, int32_t is_causal, AtenTensorHandle* attn_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2);
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__scaled_dot_product_fused_attention_overrideable(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* attn_bias, double dropout_p, int32_t is_causal, int32_t return_debug_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, int64_t* ret4, int64_t* ret5, AtenTensorHandle* ret6, AtenTensorHandle* ret7, AtenTensorHandle* ret8);
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__scaled_dot_product_fused_attention_overrideable_backward(AtenTensorHandle grad_out, AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle attn_bias, const int32_t* grad_input_mask, int64_t grad_input_mask_len_, AtenTensorHandle out, AtenTensorHandle logsumexp, AtenTensorHandle cum_seq_q, AtenTensorHandle cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, int32_t is_causal, AtenTensorHandle philox_seed, AtenTensorHandle philox_offset, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__segment_reduce_backward(AtenTensorHandle grad, AtenTensorHandle output, AtenTensorHandle data, const char* reduce, AtenTensorHandle* lengths, AtenTensorHandle* offsets, int64_t axis, double* initial, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__to_sparse(AtenTensorHandle self, int32_t* layout, const int64_t** blocksize, int64_t blocksize_len_, int64_t* dense_dim, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__trilinear(AtenTensorHandle i1, AtenTensorHandle i2, AtenTensorHandle i3, const int64_t* expand1, int64_t expand1_len_, const int64_t* expand2, int64_t expand2_len_, const int64_t* expand3, int64_t expand3_len_, const int64_t* sumdim, int64_t sumdim_len_, int64_t unroll_dim, AtenTensorHandle* ret0);

torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h

Lines changed: 2 additions & 0 deletions
@@ -41,6 +41,8 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__scaled_dot_product_efficient_a
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__scaled_dot_product_efficient_attention_backward(AtenTensorHandle grad_out_, AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle attn_bias, AtenTensorHandle out, AtenTensorHandle logsumexp, AtenTensorHandle philox_seed, AtenTensorHandle philox_offset, double dropout_p, const int32_t* grad_input_mask, int64_t grad_input_mask_len_, int32_t is_causal, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__scaled_dot_product_flash_attention(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, double dropout_p, int32_t is_causal, int32_t return_debug_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, int64_t* ret4, int64_t* ret5, AtenTensorHandle* ret6, AtenTensorHandle* ret7, AtenTensorHandle* ret8);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__scaled_dot_product_flash_attention_backward(AtenTensorHandle grad_out, AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle out, AtenTensorHandle logsumexp, AtenTensorHandle cum_seq_q, AtenTensorHandle cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, int32_t is_causal, AtenTensorHandle philox_seed, AtenTensorHandle philox_offset, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2);
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__scaled_dot_product_fused_attention_overrideable(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* attn_bias, double dropout_p, int32_t is_causal, int32_t return_debug_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, int64_t* ret4, int64_t* ret5, AtenTensorHandle* ret6, AtenTensorHandle* ret7, AtenTensorHandle* ret8);
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__scaled_dot_product_fused_attention_overrideable_backward(AtenTensorHandle grad_out, AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle attn_bias, const int32_t* grad_input_mask, int64_t grad_input_mask_len_, AtenTensorHandle out, AtenTensorHandle logsumexp, AtenTensorHandle cum_seq_q, AtenTensorHandle cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, int32_t is_causal, AtenTensorHandle philox_seed, AtenTensorHandle philox_offset, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__scaled_mm(AtenTensorHandle self, AtenTensorHandle mat2, AtenTensorHandle scale_a, AtenTensorHandle scale_b, AtenTensorHandle* bias, AtenTensorHandle* scale_result, int32_t* out_dtype, int32_t use_fast_accum, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__scaled_mm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle mat2, AtenTensorHandle scale_a, AtenTensorHandle scale_b, AtenTensorHandle* bias, AtenTensorHandle* scale_result, int32_t* out_dtype, int32_t use_fast_accum);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__segment_reduce_backward(AtenTensorHandle grad, AtenTensorHandle output, AtenTensorHandle data, const char* reduce, AtenTensorHandle* lengths, AtenTensorHandle* offsets, int64_t axis, double* initial, AtenTensorHandle* ret0);

0 commit comments