|
# Jetson-family devices report compute capability (7, 2) (Xavier) or
# (8, 7) (Orin); Thor is covered separately by the IS_THOR flag.
IS_JETSON = LazyVal(
    lambda: torch.cuda.is_available()
    and (torch.cuda.get_device_capability() in ((7, 2), (8, 7)) or IS_THOR)
)

# True exactly on SM 8.9 (Ada-generation) GPUs.
IS_SM89 = LazyVal(
    lambda: torch.cuda.is_available()
    and torch.cuda.get_device_capability() == (8, 9)
)
def evaluate_gfx_arch_within(arch_list):
    """Report whether the current GPU's gcnArchName matches any arch in *arch_list*.

    The environment variable PYTORCH_DEBUG_FLASH_ATTENTION_GCN_ARCH_OVERRIDE,
    when set, takes the place of the detected arch string.  Returns False
    when no CUDA/ROCm device is available.
    """
    if not torch.cuda.is_available():
        return False
    detected = torch.cuda.get_device_properties('cuda').gcnArchName
    effective = os.environ.get('PYTORCH_DEBUG_FLASH_ATTENTION_GCN_ARCH_OVERRIDE', detected)
    # gcnArchName can be a decorated string such as "gfx90a:sramecc+:xnack-",
    # so each candidate arch is substring-matched against the effective name
    # instead of being compared for exact equality.
    return any(candidate in effective for candidate in arch_list)
def CDNA2OrLater():
    """Return True when the active ROCm GPU is CDNA2 or newer (gfx90a / gfx942)."""
    cdna2_plus_archs = ["gfx90a", "gfx942"]
    return evaluate_gfx_arch_within(cdna2_plus_archs)
def evaluate_platform_supports_flash_attention():
    """Report whether this platform/GPU combination can run flash attention.

    ROCm: supported on gfx90a / gfx942 / gfx1100; gfx1201 and gfx950 are
    opt-in via TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL.
    CUDA: requires SM80+ and a non-Windows platform.
    """
    if TEST_WITH_ROCM:
        supported = ["gfx90a", "gfx942", "gfx1100"]
        experimental = os.environ.get("TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL", "0")
        if experimental != "0":
            supported.extend(["gfx1201", "gfx950"])
        return evaluate_gfx_arch_within(supported)
    if TEST_CUDA:
        return not IS_WINDOWS and SM80OrLater
    return False
|
65 | 64 | def evaluate_platform_supports_efficient_attention():
|
66 | 65 | if TEST_WITH_ROCM:
|
67 |
| - return evaluate_gfx_arch_exact('gfx90a:sramecc+:xnack-') or evaluate_gfx_arch_exact('gfx942:sramecc+:xnack-') |
| 66 | + arch_list = ["gfx90a", "gfx942", "gfx1100"] |
| 67 | + if os.environ.get("TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL", "0") != "0": |
| 68 | + arch_list += ["gfx1201", "gfx950"] |
| 69 | + return evaluate_gfx_arch_within(arch_list) |
68 | 70 | if TEST_CUDA:
|
69 | 71 | return True
|
70 | 72 | return False
|
|
0 commit comments