10000 Refactor instance_descriptor for new triton version (#119636) · pytorch/pytorch@957f376 · GitHub
[go: up one dir, main page]

Skip to content

Commit 957f376

Browse files
bhackpytorchmergebot
authored andcommitted
Refactor instance_descriptor for new triton version (#119636)
Check #119457 (comment) Pull Request resolved: #119636 Approved by: https://github.com/shunting314
1 parent 8464654 commit 957f376

File tree

2 files changed

+38
-5
lines changed

2 files changed

+38
-5
lines changed

torch/_inductor/codegen/triton_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Dict, List
1+
from typing import Any, Dict, List
22

33
import torch
44

@@ -59,7 +59,7 @@ def signature_to_meta(
5959
}
6060

6161

62-
def config_of(args: List[KernelArgType]) -> instance_descriptor:
62+
def config_of(args: List[KernelArgType]) -> Any:
6363
def is_aligned(x: KernelArgType, alignment: int, include_tensor: bool) -> bool:
6464
"""
6565
Roughly follow triton code here:

torch/_inductor/utils.py

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import textwrap
2121
import time
2222
import unittest
23+
from dataclasses import fields
2324
from datetime import datetime
2425
from io import StringIO
2526
from typing import (
@@ -634,11 +635,43 @@ def has_incompatible_cudagraph_ops(gm):
634635
return False
635636

636637

638+
# Attempt to import AttrsDescriptor from Triton
637639
try:
638-
from triton.compiler.compiler import AttrsDescriptor as instance_descriptor
640+
from triton.compiler.compiler import AttrsDescriptor
641+
642+
attrs_descriptor_available = True
643+
# Determine if 'ids_of_folded_args' is a valid field for AttrsDescriptor
644+
ids_of_folded_args_available = "ids_of_folded_args" in [
645+
f.name for f in fields(AttrsDescriptor)
646+
]
639647
except ImportError:
640-
# To support older version of triton which does not have AttrsDescriptor
641-
# class
648+
attrs_descriptor_available = False
649+
650+
# Define `instance_descriptor` function with clear conditional handling
651+
if attrs_descriptor_available:
652+
653+
def instance_descriptor(
654+
divisible_by_16=None,
655+
equal_to_1=None,
656+
ids_of_folded_args=None,
657+
divisible_by_8=None,
658+
):
659+
# Prepare the arguments for AttrsDescriptor
660+
kwargs = {
661+
"divisible_by_16": divisible_by_16,
662+
"equal_to_1": equal_to_1,
663+
"divisible_by_8": divisible_by_8,
664+
}
665+
666+
# Conditionally add 'ids_of_folded_args' if it's available in AttrsDescriptor
667+
if ids_of_folded_args_available:
668+
kwargs["ids_of_folded_args"] = ids_of_folded_args
669+
670+
# Instantiate AttrsDescriptor with the prepared arguments
671+
return AttrsDescriptor(**kwargs)
672+
673+
else:
674+
# Define a namedtuple as a fallback when AttrsDescriptor is not available
642675
instance_descriptor = collections.namedtuple( # type: ignore[no-redef]
643676
"instance_descriptor",
644677
["divisible_by_16", "equal_to_1", "ids_of_folded_args", "divisible_by_8"],

0 commit comments

Comments
 (0)
0