#ifdef -> #if + fix check -inf · Pints-AI/llama.cpp@19e0b8e · GitHub

Commit 19e0b8e

#ifdef -> #if + fix check -inf
1 parent 8d7a606 · commit 19e0b8e

2 files changed: +10 -9 lines

CMakeLists.txt

Lines changed: 2 additions & 1 deletion
@@ -371,6 +371,7 @@ if (LLAMA_CUBLAS)
                 #set(CMAKE_CUDA_ARCHITECTURES "") # use this to compile much faster, but only F16 models work
             endif()
         endif()
+        set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS};-lineinfo")
         message(STATUS "Using CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}")
 
     else()
@@ -729,7 +730,7 @@ endif()
 set(CUDA_CXX_FLAGS "")
 
 if (LLAMA_CUBLAS)
-    set(CUDA_FLAGS ${CXX_FLAGS} -use_fast_math)
+    set(CUDA_FLAGS ${CXX_FLAGS} -use_fast_math -lineinfo)
     if (NOT MSVC)
         list(APPEND CUDA_FLAGS -Wno-pedantic)
     endif()
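
For context on the CMake change: `-lineinfo` is nvcc's flag for embedding source-line information in the generated device code, so profilers such as Nsight Compute (and tools like compute-sanitizer) can map SASS/PTX back to CUDA source lines; unlike `-G` it does not disable device optimizations. It was presumably added here to make the flash-attention kernel below easier to profile.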

ggml-cuda.cu

Lines changed: 8 additions & 8 deletions
@@ -662,7 +662,7 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
 }
 
 static __device__ __forceinline__ half warp_reduce_sum(half x) {
-#ifdef __CUDA_ARCH__ >= CC_VOLTA
+#if __CUDA_ARCH__ >= CC_VOLTA
 #pragma unroll
     for (int mask = 16; mask > 0; mask >>= 1) {
         x = __hadd(__shfl_xor_sync(0xffffffff, x, mask, 32), x);
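
The directive change above is the "#ifdef -> #if" half of the commit: `#ifdef` only tests whether `__CUDA_ARCH__` is defined at all, and the trailing `>= CC_VOLTA` tokens are ignored (compilers typically just warn about extra tokens), so the Volta-only shuffle reduction was being compiled for every device architecture. With `#if`, the preprocessor actually evaluates the comparison. Below is a minimal standalone sketch of the pitfall, not part of the patch; the `CC_VOLTA` value of 700 is an assumption meant to match ggml-cuda.cu's definition.

```cuda
// ifdef_pitfall.cu -- illustrative only, not part of the repository.
// Compile with e.g.: nvcc -arch=sm_61 ifdef_pitfall.cu && ./a.out
#include <cstdio>
#include <cuda_runtime.h>

#define CC_VOLTA 700 // assumption: matches ggml-cuda.cu's CC_VOLTA (compute capability 7.0)

__global__ void which_branch() {
#ifdef __CUDA_ARCH__ >= CC_VOLTA
    // Emitted for *any* device architecture: #ifdef only asks whether
    // __CUDA_ARCH__ is defined; the ">= CC_VOLTA" tokens are ignored
    // (the compiler merely warns about extra tokens after the directive).
    printf("#ifdef branch compiled (even on sm_61)\n");
#endif

#if __CUDA_ARCH__ >= CC_VOLTA
    // Emitted only when the comparison really holds, i.e. Volta (sm_70) or newer.
    printf("#if branch compiled (Volta+ only)\n");
#endif
}

int main() {
    which_branch<<<1, 1>>>();
    cudaDeviceSynchronize(); // flush device-side printf output
    return 0;
}
```
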
@@ -6601,8 +6601,8 @@ static __global__ void flash_attn_ext_f16(
                 smax = warp_reduce_max(__hmax(smax, s));
                 M[j] = warp_reduce_max(__hmax(M[j], s));
 
-                const half ms = __hisinf(m) ? __float2half(0.0f) : hexp(m - M[j]);
-                const half vs = __hisinf(s) ? __float2half(0.0f) : hexp(s - M[j]);
+                const half ms = __hisinf(m) == -1 ? __float2half(0.0f) : hexp(m - M[j]);
+                const half vs = __hisinf(s) == -1 ? __float2half(0.0f) : hexp(s - M[j]);
 
                 S[j] = S[j]*ms + warp_reduce_sum(vs);
 
@@ -6628,7 +6628,7 @@ static __global__ void flash_attn_ext_f16(
                 smax = warp_reduce_max(smax);
                 M[j] = warp_reduce_max(M[j]);
 
-                const half ms = __hisinf(m) ? __float2half(0.0f) : hexp(m - M[j]);
+                const half ms = __hisinf(m) == -1 ? __float2half(0.0f) : hexp(m - M[j]);
 
                 // create a QxQ diagonal matrix for rescaling the output
                 if (lane_id == j) {
@@ -6641,7 +6641,7 @@ static __global__ void flash_attn_ext_f16(
                 for (int64_t p = lane_id; p < C; p += NW) {
                     const half s = ss[j*T + p];
 
-                    const half vs = __hisinf(s) ? __float2half(0.0f) : hexp(s - M[j]);
+                    const half vs = __hisinf(s) == -1 ? __float2half(0.0f) : hexp(s - M[j]);
 
                     ls += vs;
 
@@ -6654,7 +6654,7 @@ static __global__ void flash_attn_ext_f16(
             }
 
             // skip -INF blocks
-            if (__hisinf(smax)) {
+            if (__hisinf(smax) == -1) {
                 continue;
             }
 
@@ -6740,8 +6740,8 @@ static __global__ void flash_attn_ext_f16(
 
                 M = __hmax(M0, M1);
 
-                const half ms0 = __hisinf(M0) ? __float2half(0.0f) : hexp(M0 - M);
-                const half ms1 = __hisinf(M1) ? __float2half(0.0f) : hexp(M1 - M);
+                const half ms0 = __hisinf(M0) == -1 ? __float2half(0.0f) : hexp(M0 - M);
+                const half ms1 = __hisinf(M1) == -1 ? __float2half(0.0f) : hexp(M1 - M);
 
                 S = S0*ms0 + S1*ms1;
 
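
The repeated `== -1` edits above are the "fix check -inf" half of the commit: `__hisinf()` is a tri-state classifier returning an `int` that is -1 for negative infinity, +1 for positive infinity and 0 otherwise, so the old truthiness test also zeroed the rescaling factor whenever a value happened to be +inf. The online-softmax update only wants to drop contributions whose score or running maximum is -inf (masked-out positions), hence the explicit comparison against -1. A small standalone sketch of the convention, not from the patch (the `M`/`m` values are hypothetical):

```cuda
// hisinf_demo.cu -- illustrative only, not part of the repository.
// Compile with e.g.: nvcc -arch=sm_70 hisinf_demo.cu && ./a.out
#include <cmath>
#include <cstdio>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

__global__ void hisinf_demo() {
    const half neg_inf = __float2half(-INFINITY);
    const half pos_inf = __float2half( INFINITY);
    const half finite  = __float2half(1.5f);

    // __hisinf() is a tri-state classifier, not a boolean.
    printf("__hisinf(-inf) = %d\n", __hisinf(neg_inf)); // prints -1
    printf("__hisinf(+inf) = %d\n", __hisinf(pos_inf)); // prints  1
    printf("__hisinf(1.5)  = %d\n", __hisinf(finite));  // prints  0

    // Corrected idiom mirroring the kernel: drop the contribution only when
    // the previous maximum m is -inf (a fully masked position), otherwise
    // rescale by exp(m - M) as in the usual online-softmax update.
    const half M  = __float2half(0.0f); // hypothetical running maximum
    const half m  = neg_inf;            // hypothetical previous maximum
    const half ms = __hisinf(m) == -1 ? __float2half(0.0f) : hexp(m - M);
    printf("ms = %f\n", __half2float(ms)); // 0.000000 -- masked block contributes nothing
}

int main() {
    hisinf_demo<<<1, 1>>>();
    cudaDeviceSynchronize(); // flush device-side printf output
    return 0;
}
```
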
