Triton Kernel Rejects NamedTupleVariable Arguments #148289

cora-codes opened this issue Mar 2, 2025 · 10 comments
Labels
dynamo-triage-jan2025 module: dynamo module: fx module: user triton related to ability to directly torch.compile triton kernels oncall: pt2 triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module

cora-codes commented Mar 2, 2025

🚀 The feature, motivation and pitch

PyTorch's TorchDynamo fails when passing NamedTupleVariable to Triton kernels, raising "Unexpected argument type for a Triton kernel". It would be nice to support named tuple arguments since it makes writing Triton kernels far cleaner.

import typing

import torch
import triton

# Five levels of nested NamedTuples; every leaf defaults to None.
class T1(typing.NamedTuple):
    foo: None = None
    bar: None = None
class T2(typing.NamedTuple):
    foo: T1 = T1()
    bar: T1 = T1()
class T3(typing.NamedTuple):
    foo: T2 = T2()
    bar: T2 = T2()
class T4(typing.NamedTuple):
    foo: T3 = T3()
    bar: T3 = T3()
class T5(typing.NamedTuple):
    foo: T4 = T4()
    bar: T4 = T4()

@triton.jit
def test(t5: T5):
    # Empty kernel: the error is raised by TorchDynamo when the kernel is called.
    pass

if __name__ == "__main__":
    t5 = T5()

    # fullgraph=True turns the unsupported argument into a hard error
    # instead of a graph break.
    @torch.compile(mode="max-autotune-no-cudagraphs", fullgraph=True)
    def main():
        for i in range(100):
            # Raises "Unexpected argument type for a Triton kernel".
            test[(1,)](t5)
    main()

Alternatives

No response

Additional context

No response

cc @ezyang @SherlockNoMad @EikanWang @jgong5 @wenzhe-nrv @chauhang @penguinwu @voznesenskym @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @jiayisunx @chenyang78 @kadeng @amjames @oulgen @aakhundov @davidberard98

@cora-codes
Author

Looks like schemas cannot support tuples, which is a pretty frustrating restriction here:

f"We do not support Tuple inputs in schema. As a workaround, please try to use List instead. "
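
One way to act on the workaround that error message suggests is to flatten the nested named tuple into a plain list before it reaches the kernel launch. The sketch below is only an illustration in plain Python: `Inner`, `Outer`, and `flatten` are hypothetical names, not part of PyTorch or Triton, and it assumes every leaf is a non-tuple value.

```python
import typing


# Hypothetical nested named tuples standing in for kernel arguments.
class Inner(typing.NamedTuple):
    foo: int
    bar: int


class Outer(typing.NamedTuple):
    left: Inner
    right: Inner


def flatten(value):
    """Recursively flatten nested tuples (including named tuples) into a list of leaves."""
    if isinstance(value, tuple):
        leaves = []
        for item in value:
            leaves.extend(flatten(item))
        return leaves
    # Anything that is not a tuple is treated as a leaf.
    return [value]


nested = Outer(Inner(1, 2), Inner(3, 4))
flat_args = flatten(nested)
print(flat_args)
```

The flat list loses the field names, so a kernel taking it would need one positional parameter per leaf, which is the verbosity named tuples were meant to avoid.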

@janeyx99 janeyx99 added oncall: pt2 module: user triton related to ability to directly torch.compile triton kernels labels Mar 3, 2025
@oulgen
Contributor
oulgen commented Mar 3, 2025

@cora-codes does Triton support named tuple arguments without torch.compile? It did not use to; not sure whether this has changed.

@zou3519 ^

@cora-codes
Author

Yes, Triton supports named tuple arguments now. See: triton-lang/triton#5518

@oulgen
Contributor
oulgen commented Mar 3, 2025

> Yes, Triton supports named tuple arguments now. See: triton-lang/triton#5518

Yep, this is very new and is in fact newer than the PyTorch compatible Triton version: https://github.com/pytorch/pytorch/blob/main/.ci/docker/triton_version.txt

@cora-codes
Author

Understood. This is why I'm asking for it as a feature rather than saying it should be treated like a bug 😅. I think this addition makes Triton kernels so much cleaner: you can pass in strides, shapes, etc. far more easily than before.
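
To make the "strides and shapes" point concrete, here is a minimal plain-Python sketch of the kind of bundle a named tuple allows. `TensorMeta`, `contiguous_meta`, and `linear_offset` are hypothetical names made up for this illustration; nothing here calls Triton or PyTorch, and the row-major layout is an assumption.

```python
import typing


# Hypothetical bundle of the metadata a kernel usually takes as separate arguments.
class TensorMeta(typing.NamedTuple):
    shape: typing.Tuple[int, ...]
    strides: typing.Tuple[int, ...]


def contiguous_meta(shape):
    """Build the metadata for a row-major (contiguous) layout of the given shape."""
    strides = []
    running = 1
    # Walk dimensions from innermost to outermost, accumulating the element count.
    for dim in reversed(shape):
        strides.append(running)
        running *= dim
    return TensorMeta(tuple(shape), tuple(reversed(strides)))


def linear_offset(meta, index):
    """Map a multi-dimensional index to a flat element offset using the strides."""
    return sum(i * s for i, s in zip(index, meta.strides))


meta = contiguous_meta((2, 3, 4))
print(meta.strides)
print(linear_offset(meta, (1, 2, 3)))
```

With a bundle like this, a function takes one `meta` argument and reads `meta.shape` and `meta.strides` by name instead of threading each dimension and stride through as its own parameter.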

@zou3519
Contributor
zou3519 commented Mar 3, 2025

named tuples in Dynamo/fx are a little wild (cc @StrongerXi) but the request has been heard. @StrongerXi we should bring back the namedtuple PR

@StrongerXi
Contributor

@cora-codes could you try #147145 and see if it fixes the issue?

@cora-codes
Author

@StrongerXi I tried it with #147145 (2.7.0a0+git7be4215) and it fails with the same exception.

@desertfire desertfire added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label Mar 4, 2025
@zou3519
Contributor
zou3519 commented Mar 5, 2025

cc @davidberard98 too

@cora-codes
Author

Would dataclasses be easier to support than namedtuples?
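
For context on that question, the two differ structurally at the Python level: a `typing.NamedTuple` instance is a real `tuple`, while a dataclass instance is an ordinary object with fields. The sketch below only demonstrates that difference using the standard library; `PointTuple` and `PointData` are hypothetical names, and whether either is easier for Dynamo or Triton to support is not answered in this thread.

```python
import dataclasses
import typing


class PointTuple(typing.NamedTuple):
    x: int
    y: int


@dataclasses.dataclass(frozen=True)
class PointData:
    x: int
    y: int


as_tuple = PointTuple(1, 2)
as_data = PointData(1, 2)

# A named tuple is a tuple subclass, so tuple-specific handling applies to it.
print(isinstance(as_tuple, tuple))
# A dataclass instance is not a tuple; its fields must be read via the dataclasses API.
print(isinstance(as_data, tuple))
print(dataclasses.astuple(as_data))
```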

6 participants