[ROCm] logsumexp on ROCm needs scaling back to natural base. (#156903) · pytorch/pytorch@823e223 · GitHub

Commit 823e223

xinyazhang authored and jeffdaily committed
[ROCm] logsumexp on ROCm needs scaling back to natural base. (#156903)
Fixes #156012

This is a temporary solution that makes context parallelism work until the logsumexp behavior changes land in AOTriton. After discussion, we are not going to release an AOTriton 0.10.1 to fix this, because:

* Even if the interface is unchanged, changing the behavior of the returned logsumexp tensor should still be considered an ABI break. Such changes do not fall into the "ABI compatible" category and should be postponed to the next release.
* AOTriton 0.11 is scheduled to be released before the end of July, which is less than five weeks away.

Pull Request resolved: #156903
Approved by: https://github.com/jeffdaily
Co-authored-by: Jeff Daily <jeff.daily@amd.com>
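For context on the constant used in the patch: 0.6931471805599453 is ln 2. A logsumexp computed in base 2, log2(sum(exp(x))), relates to the natural-base value by log2(y) = ln(y) / ln(2), so multiplying the tensor returned by the ROCm backend by ln 2 restores the natural base the caller expects. A minimal sketch of the identity (standalone illustration, not part of the commit; the "base-2 kernel output" is simulated here):

    import torch

    x = torch.randn(8)

    # Natural-base logsumexp: the convention the SDPA caller expects.
    lse_e = torch.logsumexp(x, dim=0)

    # Base-2 variant, as apparently produced by the ROCm/AOTriton kernel.
    lse_2 = torch.log2(torch.exp(x).sum())

    # Multiplying by ln(2) = 0.6931471805599453 recovers the natural base.
    assert torch.allclose(lse_2 * 0.6931471805599453, lse_e)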
1 parent 6499420 commit 823e223

File tree

1 file changed: +21 −0 lines

torch/distributed/tensor/experimental/_attention.py

Lines changed: 21 additions & 0 deletions
@@ -44,6 +44,25 @@ class _RotateMethod(Enum):
 logger = logging.getLogger(__name__)
 
 
+def _need_scaling() -> bool:
+    if hasattr(torch.version, "hip") and torch.version.hip is not None:
+        gcn_arch_name = torch.cuda.get_device_properties("cuda").gcnArchName
+        _is_ck_supported = False
+        for arch in ["gfx942", "gfx950"]:
+            if arch in gcn_arch_name:
+                _is_ck_supported = True
+        # Check the function exists
+        _preferred_rocm_fa_library = torch.backends.cuda.preferred_rocm_fa_library
+        _CK_BACKEND = torch.backends.cuda._ROCmFABackends["ck"]
+        # Note: it is possible that CK is selected but not compiled in the binary.
+        if _is_ck_supported and _preferred_rocm_fa_library() == _CK_BACKEND:
+            # Unsure about CK's behavior, keep logsumexp untouched
+            return False
+        return True
+    else:
+        return False
+
+
 class _DispatchMode(Enum):
     MONKEY_PATCH = auto()
     TORCH_FUNCTION = auto()
@@ -446,6 +465,8 @@ def _templated_ring_attention(
             is_causal=is_causal_behavior.value,
             **kwargs,
         )
+        if _need_scaling():
+            logsumexp *= 0.6931471805599453
         sdpa_merger.step(out, logsumexp, partial)
 
     return *sdpa_merger.results(), *rest
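Why the base matters at this call site: _templated_ring_attention accumulates per-step partial attention results through sdpa_merger.step(out, logsumexp, partial), and that merge is only correct when every logsumexp is in the same (natural) base. A simplified, hypothetical sketch of the standard logsumexp-based merge — illustrative only, not PyTorch's actual _SDPAMerger implementation:

    import torch

    def merge_partials(out1, lse1, out2, lse2):
        # Combine two partial attention outputs computed over disjoint
        # key/value shards. lse1/lse2 must be natural-base logsumexp
        # values of shape [..., seq]; outputs are [..., seq, head_dim].
        lse = torch.logaddexp(lse1, lse2)  # log(exp(lse1) + exp(lse2))
        w1 = torch.exp(lse1 - lse).unsqueeze(-1)
        w2 = torch.exp(lse2 - lse).unsqueeze(-1)
        return out1 * w1 + out2 * w2, lse

A base-2 logsumexp fed into this formula would skew the exp() weights, which is exactly the mismatch the ln(2) rescale above corrects.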
