small block_m for sm7.x (#2626) · InferenceNexus/lmdeploy@a50555b · GitHub

Commit a50555b

small block_m for sm7.x (InternLM#2626)
* small block_m for sm7.x
* fix alibi
1 parent 1bb7a9e commit a50555b

File tree

2 files changed: +8 -3 lines changed

lmdeploy/pytorch/kernels/cuda/alibi_pagedattention.py

Lines changed: 1 addition & 1 deletion
@@ -12,7 +12,7 @@
 
 assert triton.__version__ >= '2.1.0'
 
-LOG2 = math.log(2)
+LOG2: tl.constexpr = math.log(2)
 
 
 @triton.jit

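For context: the alibi kernel reads LOG2 from module scope inside an @triton.jit function, and annotating it as tl.constexpr marks it explicitly as a compile-time constant that gets folded into the compiled kernel. A minimal standalone sketch of that pattern follows; the kernel name and launch below are hypothetical and not part of lmdeploy, assuming Triton >= 2.1 and a CUDA device:

import math

import torch
import triton
import triton.language as tl

# Module-level constant captured by the kernel; the tl.constexpr annotation
# marks it as a compile-time constant rather than a runtime global.
LOG2: tl.constexpr = math.log(2)


@triton.jit
def _mul_log2_kernel(x_ptr, out_ptr, N: tl.constexpr):
    offs = tl.arange(0, N)
    x = tl.load(x_ptr + offs)
    # LOG2 is baked into the compiled kernel at compile time.
    tl.store(out_ptr + offs, x * LOG2)


if __name__ == '__main__' and torch.cuda.is_available():
    x = torch.randn(8, device='cuda')
    out = torch.empty_like(x)
    _mul_log2_kernel[(1,)](x, out, N=8)
    assert torch.allclose(out, x * math.log(2))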
lmdeploy/pytorch/kernels/cuda/pagedattention.py

Lines changed: 7 additions & 2 deletions
@@ -621,6 +621,7 @@ def convert_pv(p, v):
 
 
 _convert_pv = None
+_nv_cap = None
 
 
 # TODO: how to support inplace autotune?
@@ -1099,9 +1100,10 @@ def paged_attention_fwd(
         max_seqlen (int): The max input length.
         BLOCK (int): The kernel block size.
     """
-    global _convert_pv
+    global _convert_pv, _nv_cap
     if _convert_pv is None:
         nv_cap = torch.cuda.get_device_capability()
+        _nv_cap = nv_cap
         _convert_pv = _get_convert_pv(nv_cap)
 
     if kv_layout == 'bshd':
@@ -1150,7 +1152,10 @@ def _get_block_d(Lk):
     is_decoding = q.shape[-3] == q_seqlens.size(0)
     if not is_decoding:
         BLOCK_DMODEL, BLOCK_DMODEL1, BLOCK_DV = _get_block_d(Lq)
-        BLOCK_M = max(16, min(BLOCK, 16384 // BLOCK_DMODEL))
+        if _nv_cap[0] < 8:
+            BLOCK_M = max(16, min(BLOCK, 8192 // BLOCK_DMODEL))
+        else:
+            BLOCK_M = max(16, min(BLOCK, 16384 // BLOCK_DMODEL))
         num_warps = 4
         num_stages = 2
         kv_head = k.shape[h_dim]

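In the prefill (non-decoding) path above, BLOCK_M is capped by a per-architecture budget: the device capability is cached once in the module global _nv_cap, and devices with compute capability below 8 (sm7.x, i.e. Volta/Turing) get a budget of 8192 elements per head-dim tile instead of 16384, presumably to fit their tighter on-chip memory limits. A standalone sketch of that selection logic follows; the helper name is hypothetical and not part of lmdeploy's API:

import torch


def choose_block_m(block: int, block_dmodel: int, nv_cap=None) -> int:
    """Pick the prefill BLOCK_M tile size for a given head-dim block.

    Mirrors the diff above: compute capability < 8 gets roughly half
    the tile budget of sm80+ parts.
    """
    if nv_cap is None:
        # Same call the real code caches in _nv_cap on first use.
        nv_cap = torch.cuda.get_device_capability()
    budget = 8192 if nv_cap[0] < 8 else 16384
    return max(16, min(block, budget // block_dmodel))


# Example with BLOCK=128 and BLOCK_DMODEL=128:
#   sm75 (Turing): max(16, min(128, 8192 // 128))  -> 64
#   sm80 (Ampere): max(16, min(128, 16384 // 128)) -> 128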