2 files changed in lmdeploy/pytorch/kernels/cuda: +8 / -3 lines.
First file:

 assert triton.__version__ >= '2.1.0'

-LOG2 = math.log(2)
+LOG2: tl.constexpr = math.log(2)


 @triton.jit
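For reference, here is a minimal sketch of the pattern this change enables: a module-level constant annotated as tl.constexpr can be captured inside a @triton.jit kernel as a compile-time value. The kernel below is illustrative only; its name, signature and launch configuration are assumptions, not part of the patch.

import math

import torch
import triton
import triton.language as tl

# Same annotation as in the patch: the constant is treated as a compile-time
# value when it is captured by a jitted kernel.
LOG2: tl.constexpr = math.log(2)


@triton.jit
def _scale_by_log2_kernel(x_ptr, out_ptr, n_elements, BLOCK: tl.constexpr):
    # Illustrative kernel: out[i] = x[i] * LOG2 for each element.
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    x = tl.load(x_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, x * LOG2, mask=mask)


# Usage sketch: launch one program per 256-element block.
x = torch.randn(1024, device='cuda')
out = torch.empty_like(x)
grid = (triton.cdiv(x.numel(), 256), )
_scale_by_log2_kernel[grid](x, out, x.numel(), BLOCK=256)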
Second file:

@@ -621,6 +621,7 @@ def convert_pv(p, v):


 _convert_pv = None
+_nv_cap = None


 # TODO: how to support inplace autotune?

@@ -1099,9 +1100,10 @@ def paged_attention_fwd(
         max_seqlen (int): The max input length.
         BLOCK (int): The kernel block size.
     """
-    global _convert_pv
+    global _convert_pv, _nv_cap
     if _convert_pv is None:
         nv_cap = torch.cuda.get_device_capability()
+        _nv_cap = nv_cap
         _convert_pv = _get_convert_pv(nv_cap)

     if kv_layout == 'bshd':

@@ -1150,7 +1152,10 @@ def _get_block_d(Lk):
     is_decoding = q.shape[-3] == q_seqlens.size(0)
     if not is_decoding:
         BLOCK_DMODEL, BLOCK_DMODEL1, BLOCK_DV = _get_block_d(Lq)
-        BLOCK_M = max(16, min(BLOCK, 16384 // BLOCK_DMODEL))
+        if _nv_cap[0] < 8:
+            BLOCK_M = max(16, min(BLOCK, 8192 // BLOCK_DMODEL))
+        else:
+            BLOCK_M = max(16, min(BLOCK, 16384 // BLOCK_DMODEL))
         num_warps = 4
         num_stages = 2
         kv_head = k.shape[h_dim]
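These hunks cache the CUDA compute capability in a module-level global and shrink the prefill BLOCK_M tile on devices whose major capability is below 8 (pre-Ampere), presumably because those GPUs have less shared memory per SM. Below is a standalone sketch of the same selection logic; the 8192/16384 budgets and the capability check mirror the patch, while the helper name and everything else are illustrative assumptions.

import torch

# Lazily cached device capability, mirroring the _nv_cap global in the patch.
_nv_cap = None


def _choose_block_m(block: int, block_dmodel: int) -> int:
    # Pick the M-tile size for the non-decoding (prefill) path.
    global _nv_cap
    if _nv_cap is None:
        # Returns (major, minor), e.g. (7, 5) for a T4, (8, 0) for an A100.
        _nv_cap = torch.cuda.get_device_capability()
    # Halve the per-tile budget on pre-Ampere GPUs (major capability < 8).
    budget = 8192 if _nv_cap[0] < 8 else 16384
    return max(16, min(block, budget // block_dmodel))


# Example: with block=128 and block_dmodel=128, an A100 (8, 0) yields
# BLOCK_M=128, while a T4 (7, 5) yields BLOCK_M=64.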