pytorch/pytorch · Commit a99cc48
ROCm SDPA: Ensure attn_mask has the same dtype with q (#144398)
ROCm SDPA: Ensure attn_mask has the same dtype with q (#143242)

This is required by the current AOTriton backend. Fixes NaN outputs from the SDPA memory-efficient (ME) backend when `q.dtype() != attn_mask.dtype()`, observed when training llama2 using transformers+deepspeed+pytorch.

The corresponding CUDA check appears to be here:
https://github.com/pytorch/pytorch/blob/708ce3c0082d670d9eaff84bc3c43cad4554a75d/aten/src/ATen/native/transformers/cuda/attention.cu#L1331-L1336

Pull Request resolved: #143242
Approved by: https://github.com/jeffdaily

(cherry picked from commit 3068ce0)
Co-authored-by: Xinya Zhang <Xinya.Zhang@amd.com>
1 parent 4d9de27 commit a99cc48
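
To make the failure mode concrete, here is a minimal user-level sketch in Python. The shapes, sizes, and the all-zeros additive mask are illustrative assumptions, not taken from the original bug report: a floating-point attn_mask whose dtype differs from q could previously yield NaN from the ROCm ME backend; casting the mask to q's dtype, or passing a boolean mask, satisfies the new check.

import torch
import torch.nn.functional as F

# Illustrative shapes: (batch, heads, seq_len, head_dim); not from the report.
q = torch.randn(1, 4, 8, 64, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# An fp32 additive mask against fp16 q/k/v: the mismatch this commit guards.
mask = torch.zeros(1, 4, 8, 8, device="cuda", dtype=torch.float32)

# Either workaround passes the new ROCm eligibility check:
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask.to(q.dtype))

# Boolean masks are also accepted (True = position may be attended).
bool_mask = torch.ones(1, 4, 8, 8, device="cuda", dtype=torch.bool)
out2 = F.scaled_dot_product_attention(q, k, v, attn_mask=bool_mask)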

File tree

1 file changed: +8 −0 lines

aten/src/ATen/native/transformers/cuda/sdp_utils.cpp

Lines changed: 8 additions & 0 deletions
@@ -705,6 +705,14 @@ bool can_use_mem_efficient_attention(sdp_params const& params, bool debug) {
   }

 #ifdef USE_ROCM
+  if (params.attn_mask.has_value()) {
+    const auto q_dtype = params.query.dtype();
+    const auto bias_dtype = params.attn_mask.value().dtype();
+    if (bias_dtype != at::kBool && bias_dtype != q_dtype) {
+      TORCH_WARN("Efficient attention on ROCM requires attn_mask be boolean, or has the same datatype as of q,k,v");
+      return false;
+    }
+  }
   return check_tensor_dtype(params, aotriton_mem_efficient_dtypes, debug);
 #else
   auto dprop = at::cuda::getCurrentDeviceProperties();
