16
16
assert TRITON_VERSION >= version .parse ('2.2.0' )
17
17
18
18
# TODO: the fast ops below might not work on non-NVIDIA devices
19
- if TRITON_VERSION >= VERSION_300 :
20
- tanh = tl .extra .cuda .libdevice .tanh
21
- tl_log2 = tl .log2
22
- tl_exp2 = tl .exp2
23
- else :
24
- tanh = tl .math .tanh
25
- tl_log2 = tl .math .log2
26
- tl_exp2 = tl .math .exp2
19
+ tanh = tl .extra .cuda .libdevice .tanh
20
+ tl_log2 = tl .log2
21
+ tl_exp2 = tl .exp2
27
22
28
23
29
24
def _get_block_d (head_dim_k , head_dim_v ):
@@ -48,9 +43,9 @@ def softcapping(qk, logit_softcapping: tl.constexpr):
48
43
49
44
50
45
@triton .jit
51
- def _load_kv (ptrs , causal_mask : tl . constexpr , boundary_check : tl .constexpr ):
46
+ def _load_kv (ptrs , boundary_check : tl .constexpr ):
52
47
"""load kv."""
53
- if causal_mask :
48
+ if boundary_check is not None :
54
49
return tl .load (ptrs , boundary_check = boundary_check , padding_option = 'zero' )
55
50
else :
56
51
return tl .load (ptrs )
@@ -59,7 +54,8 @@ def _load_kv(ptrs, causal_mask: tl.constexpr, boundary_check: tl.constexpr):
59
54
@triton .jit
60
55
def _prefill_fwd_inner (acc , l_i , m_i , q , k_ptrs , v_ptrs , q1 , k1_ptrs , loop_start , loop_end , sm_scale , history_mask ,
61
56
kv_min_loc , causal_mask : tl .constexpr , window_size : tl .constexpr ,
62
- logit_softcapping : tl .constexpr , BLOCK_N : tl .constexpr , BLOCK_DK1 : tl .constexpr ):
57
+ logit_softcapping : tl .constexpr , k_bound : tl .constexpr , v_bound : tl .constexpr ,
58
+ BLOCK_N : tl .constexpr , BLOCK_DK1 : tl .constexpr ):
63
59
k_ptrs = tl .advance (k_ptrs , (0 , loop_start ))
64
60
v_ptrs = tl .advance (v_ptrs , (loop_start , 0 ))
65
61
if BLOCK_DK1 :
@@ -69,11 +65,11 @@ def _prefill_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, q1, k1_ptrs, loop_start
69
65
for start_n in range (loop_start , loop_end , BLOCK_N ):
70
66
start_n = tl .multiple_of (start_n , BLOCK_N )
71
67
72
- k = _load_kv (k_ptrs , causal_mask , boundary_check = ( 1 , ) )
68
+ k = _load_kv (k_ptrs , boundary_check = k_bound )
73
69
qk = tl .dot (q , k )
74
70
75
71
if BLOCK_DK1 != 0 :
76
- k1 = _load_kv (k1_ptrs , causal_mask , boundary_check = ( 1 , ) )
72
+ k1 = _load_kv (k1_ptrs , boundary_check = k_bound )
77
73
qk += tl .dot (q1 , k1 )
78
74
79
75
if causal_mask :
@@ -122,7 +118,7 @@ def _prefill_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, q1, k1_ptrs, loop_start
122
118
acc = acc * alpha [:, None ]
123
119
124
120
# update acc
125
- v = _load_kv (v_ptrs , causal_mask , boundary_check = ( 0 , ) )
121
+ v = _load_kv (v_ptrs , boundary_check = v_bound )
126
122
p = p .to (v .dtype )
127
123
acc += tl .dot (p , v )
128
124
# update m_i and l_i
@@ -175,8 +171,8 @@ def _flash_prefill_fwd_kernel(
175
171
stride_oh : tl .constexpr ,
176
172
stride_od : tl .constexpr ,
177
173
kv_group_num ,
178
- head_dim_k ,
179
- head_dim_v ,
174
+ head_dim_k : tl . constexpr ,
175
+ head_dim_v : tl . constexpr ,
180
176
causal : tl .constexpr ,
181
177
window_size : tl .constexpr ,
182
178
logit_softcapping : tl .constexpr ,
@@ -237,6 +233,15 @@ def _flash_prefill_fwd_kernel(
237
233
order = (1 , 0 ),
238
234
)
239
235
236
+ k_bound0 : tl .constexpr = None
237
+ k_bound1 : tl .constexpr = (1 , )
238
+ if head_dim_v == BLOCK_DV :
239
+ v_bound0 : tl .constexpr = None
240
+ v_bound1 : tl .constexpr = (0 , )
241
+ else :
242
+ v_bound0 : tl .constexpr = (1 , )
243
+ v_bound1 : tl .constexpr = (0 , 1 )
244
+
240
245
if BLOCK_DK1 != 0 :
241
246
offs_dk1 = BLOCK_DK + tl .arange (0 , BLOCK_DK1 )
242
247
mask_dk1 = offs_dk1 < head_dim_k
@@ -283,6 +288,8 @@ def _flash_prefill_fwd_kernel(
283
288
causal_mask = False ,
284
289
window_size = window_size ,
285
290
logit_softcapping = logit_softcapping ,
291
+ k_bound = k_bound0 ,
292
+ v_bound = v_bound0 ,
286
293
BLOCK_N = BLOCK_N ,
287
294
BLOCK_DK1 = BLOCK_DK1 )
288
295
@@ -307,6 +314,8 @@ def _flash_prefill_fwd_kernel(
307
314
causal_mask = True ,
308
315
window_size = window_size ,
309
316
logit_softcapping = logit_softcapping ,
317
+ k_bound = k_bound1 ,
318
+ v_bound = v_bound1 ,
310
319
BLOCK_N = BLOCK_N ,
311
320
BLOCK_DK1 = BLOCK_DK1 )
312
321
# epilogue
0 commit comments