ROCm SDPA: Ensure attn_mask has the same dtype with q · pytorch/pytorch@857b9c3 · GitHub
[go: up one dir, main page]

Skip to content

Commit 857b9c3

Browse files
xinyazhang authored and pytorchmergebot committed
ROCm SDPA: Ensure attn_mask has the same dtype with q
This is required by the current AOTriton backend.
1 parent a1ae8fa commit 857b9c3

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

aten/src/ATen/native/transformers/cuda/sdp_utils.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -707,6 +707,12 @@ bool can_use_mem_efficient_attention(sdp_params const& params, bool debug) {
707707
}
708708

709709
#ifdef USE_ROCM
710+
if (params.attn_mask.has_value()) {
711+
if (params.attn_mask.value().dtype() != params.query.dtype()) {
712+
TORCH_WARN("Efficient attention on ROCM requires attn_mask has the same datatype as of q,k,v");
713+
return false;
714+
}
715+
}
710716
return check_tensor_dtype(params, aotriton_mem_efficient_dtypes, debug);
711717
#else
712718
auto dprop = at::cuda::getCurrentDeviceProperties();

0 commit comments

Comments
 (0)
0