8000 [ROCm] testing: enable MEFF/FA unittests for gfx1100 (#148911) · pytorch/pytorch@2a011ca · GitHub
[go: up one dir, main page]

Skip to content

Commit 2a011ca

Browse files
xinyazhangpytorchmergebot
authored andcommitted
[ROCm] testing: enable MEFF/FA unittests for gfx1100 (#148911)
Include gfx1100, and optionally enable gfx1201/gfx950 according to env var TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL Pull Request resolved: #148911 Approved by: https://github.com/jeffdaily
1 parent 9d37b50 commit 2a011ca

File tree

1 file changed

+15
-13
lines changed

1 file changed

+15
-13
lines changed

torch/testing/_internal/common_cuda.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -39,32 +39,34 @@
3939
IS_JETSON = LazyVal(lambda: torch.cuda.is_available() and (torch.cuda.get_device_capability() in [(7, 2), (8, 7)] or IS_THOR))
4040
IS_SM89 = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() == (8, 9))
4141

42-
def CDNA2OrLater():
43-
if TEST_WITH_ROCM:
44-
gcn_arch_name = torch.cuda.get_device_properties('cuda').gcnArchName
45-
return any(arch in gcn_arch_name for arch in {"gfx90a", "gfx942"})
46-
return False
47-
48-
def evaluate_gfx_arch_exact(matching_arch):
42+
def evaluate_gfx_arch_within(arch_list):
4943
if not torch.cuda.is_available():
5044
return False
5145
gcn_arch_name = torch.cuda.get_device_properties('cuda').gcnArchName
52-
arch = os.environ.get('PYTORCH_DEBUG_FLASH_ATTENTION_GCN_ARCH_OVERRIDE', gcn_arch_name)
53-
return arch == matching_arch
46+
effective_arch = os.environ.get('PYTORCH_DEBUG_FLASH_ATTENTION_GCN_ARCH_OVERRIDE', gcn_arch_name)
47+
# gcnArchName can be complicated strings like gfx90a:sramecc+:xnack-
48+
# Hence the matching should be done reversely
49+
return any(arch in effective_arch for arch in arch_list)
5450

55-
GFX90A_Exact = LazyVal(lambda: evaluate_gfx_arch_exact('gfx90a:sramecc+:xnack-'))
56-
GFX942_Exact = LazyVal(lambda: evaluate_gfx_arch_exact('gfx942:sramecc+:xnack-'))
51+
def CDNA2OrLater():
52+
return evaluate_gfx_arch_within(["gfx90a", "gfx942"])
5753

5854
def evaluate_platform_supports_flash_attention():
5955
if TEST_WITH_ROCM:
60-
return evaluate_gfx_arch_exact('gfx90a:sramecc+:xnack-') or evaluate_gfx_arch_exact('gfx942:sramecc+:xnack-')
56+
arch_list = ["gfx90a", "gfx942", "gfx1100"]
57+
if os.environ.get("TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL", "0") != "0":
58+
arch_list += ["gfx1201", "gfx950"]
59+
return evaluate_gfx_arch_within(arch_list)
6160
if TEST_CUDA:
6261
return not IS_WINDOWS and SM80OrLater
6362
return False
6463

6564
def evaluate_platform_supports_efficient_attention():
6665
if TEST_WITH_ROCM:
67-
return evaluate_gfx_arch_exact('gfx90a:sramecc+:xnack-') or evaluate_gfx_arch_exact('gfx942:sramecc+:xnack-')
66+
arch_list = ["gfx90a", "gfx942", "gfx1100"]
67+
if os.environ.get("TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL", "0") != "0":
68+
arch_list += ["gfx1201", "gfx950"]
69+
return evaluate_gfx_arch_within(arch_list)
6870
if TEST_CUDA:
6971
return True
7072
return False

0 commit comments

Comments
 (0)
0