|
20 | 20 | import textwrap
|
21 | 21 | import time
|
22 | 22 | import unittest
|
| 23 | +from dataclasses import fields |
23 | 24 | from datetime import datetime
|
24 | 25 | from io import StringIO
|
25 | 26 | from typing import (
|
@@ -634,11 +635,43 @@ def has_incompatible_cudagraph_ops(gm):
|
634 | 635 | return False
|
635 | 636 |
|
636 | 637 |
|
| 638 | +# Attempt to import AttrsDescriptor from Triton |
637 | 639 | try:
|
638 |
| - from triton.compiler.compiler import AttrsDescriptor as instance_descriptor |
| 640 | + from triton.compiler.compiler import AttrsDescriptor |
| 641 | + |
| 642 | + attrs_descriptor_available = True |
| 643 | + # Determine if 'ids_of_folded_args' is a valid field for AttrsDescriptor |
| 644 | + ids_of_folded_args_available = "ids_of_folded_args" in [ |
| 645 | + f.name for f in fields(AttrsDescriptor) |
| 646 | + ] |
639 | 647 | except ImportError:
|
640 |
| - # To support older version of triton which does not have AttrsDescriptor |
641 |
| - # class |
| 648 | + attrs_descriptor_available = False |
| 649 | + |
| 650 | +# Define `instance_descriptor` function with clear conditional handling |
| 651 | +if attrs_descriptor_available: |
| 652 | + |
| 653 | + def instance_descriptor( |
| 654 | + divisible_by_16=None, |
| 655 | + equal_to_1=None, |
| 656 | + ids_of_folded_args=None, |
| 657 | + divisible_by_8=None, |
| 658 | + ): |
| 659 | + # Prepare the arguments for AttrsDescriptor |
| 660 | + kwargs = { |
| 661 | + "divisible_by_16": divisible_by_16, |
| 662 | + "equal_to_1": equal_to_1, |
| 663 | + "divisible_by_8": divisible_by_8, |
| 664 | + } |
| 665 | + |
| 666 | + # Conditionally add 'ids_of_folded_args' if it's available in AttrsDescriptor |
| 667 | + if ids_of_folded_args_available: |
| 668 | + kwargs["ids_of_folded_args"] = ids_of_folded_args |
| 669 | + |
| 670 | + # Instantiate AttrsDescriptor with the prepared arguments |
| 671 | + return AttrsDescriptor(**kwargs) |
| 672 | + |
| 673 | +else: |
| 674 | + # Define a namedtuple as a fallback when AttrsDescriptor is not available |
642 | 675 | instance_descriptor = collections.namedtuple( # type: ignore[no-redef]
|
643 | 676 | "instance_descriptor",
|
644 | 677 | ["divisible_by_16", "equal_to_1", "ids_of_folded_args", "divisible_by_8"],
|
|
0 commit comments