10000 Allows boolean attn_mask. · pytorch/pytorch@1489d35 · GitHub
[go: up one dir, main page]

Skip to content

Commit 1489d35

Browse files
committed
Allows boolean attn_mask.
1 parent 857b9c3 commit 1489d35

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

aten/src/ATen/native/transformers/cuda/sdp_utils.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -708,8 +708,10 @@ bool can_use_mem_efficient_attention(sdp_params const& params, bool debug) {
708708

709709
#ifdef USE_ROCM
710710
if (params.attn_mask.has_value()) {
711-
if (params.attn_mask.value().dtype() != params.query.dtype()) {
712-
TORCH_WARN("Efficient attention on ROCM requires attn_mask has the same datatype as of q,k,v");
711+
const auto q_dtype = params.query.dtype();
712+
const auto bias_dtype = params.attn_mask.value().dtype();
713+
if (bias_dtype != at::kBool && bias_dtype != q_dtype) {
714+
TORCH_WARN("Efficient attention on ROCM requires attn_mask be boolean, or has the same datatype as of q,k,v");
713715
return false;
714716
}
715717
}

0 commit comments

Comments
 (0)
0