8000 ROCm SDPA: Ensure attn_mask has the same dtype with q · pytorch/pytorch@60b555e · GitHub
[go: up one dir, main page]

Skip to content

Commit 60b555e

Browse files
committed
ROCm SDPA: Ensure attn_mask has the same dtype with q
This is required by current AOTriton's backend.
1 parent 4e0de50 commit 60b555e

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

aten/src/ATen/native/transformers/cuda/sdp_utils.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -705,6 +705,12 @@ bool can_use_mem_efficient_attention(sdp_params const& params, bool debug) {
705705
}
706706

707707
#ifdef USE_ROCM
708+
if (params.attn_mask.has_value()) {
709+
if (params.attn_mask.value().dtype() != params.query.dtype()) {
710+
TORCH_WARN("Efficient attention on ROCM requires attn_mask has the same datatype as of q,k,v");
711+
return false;
712+
}
713+
}
708714
return check_tensor_dtype(params, aotriton_mem_efficient_dtypes, debug);
709715
#else
710716
auto dprop = at::cuda::getCurrentDeviceProperties();

0 commit comments

Comments
 (0)
0