@@ -662,7 +662,7 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
662
662
}
663
663
664
664
static __device__ __forceinline__ half warp_reduce_sum (half x) {
665
- #ifdef __CUDA_ARCH__ >= CC_VOLTA
665
+ #if __CUDA_ARCH__ >= CC_VOLTA
666
666
#pragma unroll
667
667
for (int mask = 16 ; mask > 0 ; mask >>= 1 ) {
668
668
x = __hadd (__shfl_xor_sync (0xffffffff , x, mask, 32 ), x);
@@ -6601,8 +6601,8 @@ static __global__ void flash_attn_ext_f16(
6601
6601
smax = warp_reduce_max (__hmax (smax, s));
6602
6602
M[j] = warp_reduce_max (__hmax (M[j], s));
6603
6603
6604
- const half ms = __hisinf (m) ? __float2half (0 .0f ) : hexp (m - M[j]);
6605
- const half vs = __hisinf (s) ? __float2half (0 .0f ) : hexp (s - M[j]);
6604
+ const half ms = __hisinf (m) == - 1 ? __float2half (0 .0f ) : hexp (m - M[j]);
6605
+ const half vs = __hisinf (s) == - 1 ? __float2half (0 .0f ) : hexp (s - M[j]);
6606
6606
6607
6607
S[j] = S[j]*ms + warp_reduce_sum (vs);
6608
6608
@@ -6628,7 +6628,7 @@ static __global__ void flash_attn_ext_f16(
6628
6628
smax = warp_reduce_max (smax);
6629
6629
M[j] = warp_reduce_max (M[j]);
6630
6630
6631
- const half ms = __hisinf (m) ? __float2half (0 .0f ) : hexp (m - M[j]);
6631
+ const half ms = __hisinf (m) == - 1 ? __float2half (0 .0f ) : hexp (m - M[j]);
6632
6632
6633
6633
// create a QxQ diagonal matrix for rescaling the output
6634
6634
if (lane_id == j) {
@@ -6641,7 +6641,7 @@ static __global__ void flash_attn_ext_f16(
6641
6641
for (int64_t p = lane_id; p < C; p += NW) {
6642
6642
const half s = ss[j*T + p];
6643
6643
6644
- const half vs = __hisinf (s) ? __float2half (0 .0f ) : hexp (s - M[j]);
6644
+ const half vs = __hisinf (s) == - 1 ? __float2half (0 .0f ) : hexp (s - M[j]);
6645
6645
6646
6646
ls += vs;
6647
6647
@@ -6654,7 +6654,7 @@ static __global__ void flash_attn_ext_f16(
6654
6654
}
6655
6655
6656
6656
// skip -INF blocks
6657
- if (__hisinf (smax)) {
6657
+ if (__hisinf (smax) == - 1 ) {
6658
6658
continue ;
6659
6659
}
6660
6660
@@ -6740,8 +6740,8 @@ static __global__ void flash_attn_ext_f16(
6740
6740
6741
6741
M = __hmax (M0, M1);
6742
6742
6743
- const half ms0 = __hisinf (M0) ? __float2half (0 .0f ) : hexp (M0 - M);
6744
- const half ms1 = __hisinf (M1) ? __float2half (0 .0f ) : hexp (M1 - M);
6743
+ const half ms0 = __hisinf (M0) == - 1 ? __float2half (0 .0f ) : hexp (M0 - M);
6744
+ const half ms1 = __hisinf (M1) == - 1 ? __float2half (0 .0f ) : hexp (M1 - M);
6745
6745
6746
6746
S = S0*ms0 + S1*ms1;
6747
6747
0 commit comments