add `v` check (#3307) · InternLM/lmdeploy@6fe9371 · GitHub
[go: up one dir, main page]

Skip to content

Commit 6fe9371

Browse files
authored
add v check (#3307)
1 parent 213faf2 commit 6fe9371

File tree

1 file changed

+25
-16
lines changed

1 file changed

+25
-16
lines changed

lmdeploy/pytorch/kernels/cuda/flashattention.py

Lines changed: 25 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -16,14 +16,9 @@
1616
assert TRITON_VERSION >= version.parse('2.2.0')
1717

1818
# TODO: fast op might not work on non-nv device
19-
if TRITON_VERSION >= VERSION_300:
20-
tanh = tl.extra.cuda.libdevice.tanh
21-
tl_log2 = tl.log2
22-
tl_exp2 = tl.exp2
23-
else:
24-
tanh = tl.math.tanh
25-
tl_log2 = tl.math.log2
26-
tl_exp2 = tl.math.exp2
19+
tanh = tl.extra.cuda.libdevice.tanh
20+
tl_log2 = tl.log2
21+
tl_exp2 = tl.exp2
2722

2823

2924
def _get_block_d(head_dim_k, head_dim_v):
@@ -48,9 +43,9 @@ def softcapping(qk, logit_softcapping: tl.constexpr):
4843

4944

5045
@triton.jit
51-
def _load_kv(ptrs, causal_mask: tl.constexpr, boundary_check: tl.constexpr):
46+
def _load_kv(ptrs, boundary_check: tl.constexpr):
5247
"""load kv."""
53-
if causal_mask:
48+
if boundary_check is not None:
5449
return tl.load(ptrs, boundary_check=boundary_check, padding_option='zero')
5550
else:
5651
return tl.load(ptrs)
@@ -59,7 +54,8 @@ def _load_kv(ptrs, causal_mask: tl.constexpr, boundary_check: tl.constexpr):
5954
@triton.jit
6055
def _prefill_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, q1, k1_ptrs, loop_start, loop_end, sm_scale, history_mask,
6156
kv_min_loc, causal_mask: tl.constexpr, window_size: tl.constexpr,
62-
logit_softcapping: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_DK1: tl.constexpr):
57+
logit_softcapping: tl.constexpr, k_bound: tl.constexpr, v_bound: tl.constexpr,
58+
BLOCK_N: tl.constexpr, BLOCK_DK1: tl.constexpr):
6359
k_ptrs = tl.advance(k_ptrs, (0, loop_start))
6460
v_ptrs = tl.advance(v_ptrs, (loop_start, 0))
6561
if BLOCK_DK1:
@@ -69,11 +65,11 @@ def _prefill_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, q1, k1_ptrs, loop_start
6965
for start_n in range(loop_start, loop_end, BLOCK_N):
7066
start_n = tl.multiple_of(start_n, BLOCK_N)
7167

72-
k = _load_kv(k_ptrs, causal_mask, boundary_check=(1, ))
68+
k = _load_kv(k_ptrs, boundary_check=k_bound)
7369
qk = tl.dot(q, k)
7470

7571
if BLOCK_DK1 != 0:
76-
k1 = _load_kv(k1_ptrs, causal_mask, boundary_check=(1, ))
72+
k1 = _load_kv(k1_ptrs, boundary_check=k_bound)
7773
qk += tl.dot(q1, k1)
7874

7975
if causal_mask:
@@ -122,7 +118,7 @@ def _prefill_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, q1, k1_ptrs, loop_start
122118
acc = acc * alpha[:, None]
123119

124120
# update acc
125-
v = _load_kv(v_ptrs, causal_mask, boundary_check=(0, ))
121+
v = _load_kv(v_ptrs, boundary_check=v_bound)
126122
p = p.to(v.dtype)
127123
acc += tl.dot(p, v)
128124
# update m_i and l_i
@@ -175,8 +171,8 @@ def _flash_prefill_fwd_kernel(
175171
stride_oh: tl.constexpr,
176172
stride_od: tl.constexpr,
177173
kv_group_num,
178-
head_dim_k,
179-
head_dim_v,
174+
head_dim_k: tl.constexpr,
175+
head_dim_v: tl.constexpr,
180176
causal: tl.constexpr,
181177
window_size: tl.constexpr,
182178
logit_softcapping: tl.constexpr,
@@ -237,6 +233,15 @@ def _flash_prefill_fwd_kernel(
237233
order=(1, 0),
238234
)
239235

236+
k_bound0: tl.constexpr = None
237+
k_bound1: tl.constexpr = (1, )
238+
if head_dim_v == BLOCK_DV:
239+
v_bound0: tl.constexpr = None
240+
v_bound1: tl.constexpr = (0, )
241+
else:
242+
v_bound0: tl.constexpr = (1, )
243+
v_bound1: tl.constexpr = (0, 1)
244+
240245
if BLOCK_DK1 != 0:
241246
offs_dk1 = BLOCK_DK + tl.arange(0, BLOCK_DK1)
242247
mask_dk1 = offs_dk1 < head_dim_k
@@ -283,6 +288,8 @@ def _flash_prefill_fwd_kernel(
283288
causal_mask=False,
284289
window_size=window_size,
285290
logit_softcapping=logit_softcapping,
291+
k_bound=k_bound0,
292+
v_bound=v_bound0,
286293
BLOCK_N=BLOCK_N,
287294
BLOCK_DK1=BLOCK_DK1)
288295

@@ -307,6 +314,8 @@ def _flash_prefill_fwd_kernel(
307314
causal_mask=True,
308315
window_size=window_size,
309316
logit_softcapping=logit_softcapping,
317+
k_bound=k_bound1,
318+
v_bound=v_bound1,
310319
BLOCK_N=BLOCK_N,
311320
BLOCK_DK1=BLOCK_DK1)
312321
# epilogue
<div class="d-flex flex-items-center flex-justify-between">

0 commit comments

Comments (0)