[ascend] support paged_prefill_attn when batch > 1 (#2612) · InferenceNexus/lmdeploy@48dcd21 · GitHub

Commit 48dcd21

[ascend] support paged_prefill_attn when batch > 1 (InternLM#2612)
1 parent 77be205 commit 48dcd21

File tree

1 file changed (+48, -25 lines)


lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py

Lines changed: 48 additions & 25 deletions
@@ -50,50 +50,73 @@ def update_step_context(cls, step_context):
         device = step_context.block_offsets.device
 
         is_unpaged_prefill = False
-        q_start_loc_cpu = step_context.q_start_loc.cpu()
-        q_seqlens_cpu = step_context.q_seqlens.cpu()
-        kv_seqlens_cpu = step_context.kv_seqlens.cpu()
-        max_q_seq_len = torch.max(q_seqlens_cpu).item()
-        max_kv_seq_len = torch.max(kv_seqlens_cpu).item()
-
         if not step_context.is_decoding:
             is_unpaged_prefill = \
                 all((step_context.q_seqlens ==
                      step_context.kv_seqlens).tolist())
-            if is_unpaged_prefill:
-                single_attention_mask = torch.logical_not(
-                    torch.tril(
-                        torch.ones(max_q_seq_len,
-                                   max_kv_seq_len,
-                                   dtype=torch.bool).cuda(),
-                        diagonal=max_kv_seq_len - max_q_seq_len,
-                    ))
-                attention_mask.append(single_attention_mask)
+
         total_slots = torch.arange(block_num * block_size,
                                    dtype=torch.long,
                                    device=device)
         total_slots = total_slots.view(block_num, block_size)
+
+        q_seqlens_list = step_context.q_seqlens.tolist()
+        kv_seqlens_list = step_context.kv_seqlens.tolist()
+        max_q_seq_len = max(q_seqlens_list)
+        max_kv_seq_len = max(kv_seqlens_list)
+
         for i in range(step_context.q_start_loc.size(0)):
-            q_seq_len = int(step_context.q_seqlens[i])
-            kv_seq_len = int(step_context.kv_seqlens[i])
+            q_seq_len = q_seqlens_list[i]
+            kv_seq_len = kv_seqlens_list[i]
+
+            # collect kv start indices.
+            history_length = kv_seq_len - q_seq_len
+            slot_tables = total_slots[step_context.block_offsets[i]].flatten()
+            slot_indices = [p for p in range(history_length, kv_seq_len)]
+            slots = slot_tables[slot_indices].reshape((-1, 1))
+            kv_start_indices.append(slots)
+
+            # collect attention mask of paged_prefill attention stage.
             if not (step_context.is_decoding or is_unpaged_prefill):
                 single_attention_mask = torch.logical_not(
                     torch.tril(
-                        torch.ones(step_context.q_seqlens[i],
+                        torch.ones(q_seq_len,
                                    step_context.block_offsets.shape[1] *
                                    block_size,
                                    dtype=torch.bool).cuda(),
-                        diagonal=step_context.kv_seqlens[i] -
-                        step_context.q_seqlens[i],
+                        diagonal=kv_seq_len - q_seq_len,
                     ))
                 attention_mask.append(single_attention_mask)
-            history_length = kv_seq_len - q_seq_len
-            slot_tables = total_slots[step_context.block_offsets[i]].flatten()
-            slot_indices = [p for p in range(history_length, kv_seq_len)]
-            slots = slot_tables[slot_indices].reshape((-1, 1))
-            kv_start_indices.append(slots)
+
         kv_start_indices = torch.cat(kv_start_indices)
 
+        if step_context.is_decoding:
+            # prepare some params of paged_decode attention stage.
+            q_start_loc_cpu, q_seqlens_cpu = None, None
+            kv_seqlens_cpu = step_context.kv_seqlens.cpu()
+        elif is_unpaged_prefill:
+            # prepare some params of unpaged_prefill attention stage.
+            q_start_loc_cpu, kv_seqlens_cpu = None, None
+            q_seqlens_cpu = step_context.q_seqlens.cpu()
+            single_attention_mask = torch.logical_not(
+                torch.tril(
+                    torch.ones(max_q_seq_len, max_kv_seq_len,
+                               dtype=torch.bool).cuda(),
+                    diagonal=max_kv_seq_len - max_q_seq_len,
+                ))
+            attention_mask.append(single_attention_mask)
+        else:
+            # prepare some params of paged_prefill attention stage.
+            q_start_loc_cpu, q_seqlens_cpu = None, None
+            kv_seqlens_cpu = step_context.kv_seqlens.repeat_interleave(
+                step_context.q_seqlens, 0).cpu()
+            block_offsets_int32 = step_context.block_offsets.to(torch.int32)
+            step_context.block_offsets = block_offsets_int32.repeat_interleave(
+                step_context.q_seqlens, 0)
+            attention_mask = [
+                torch.cat([mask for mask in attention_mask]).unsqueeze(1)
+            ]
+
         attn_meta_cls = cls.get_attention_metadata_cls()
         attn_metadata = attn_meta_cls(
             step_context.is_decoding,
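The core of the batch > 1 support is the new paged_prefill branch at the end of the hunk: per-request kv lengths and block-table rows are expanded to one entry per query token with repeat_interleave, and the per-request masks are concatenated into a single tensor. The following is a minimal, self-contained sketch of that reshaping, not lmdeploy source; the block table, sequence lengths, and cache sizes are toy values invented for illustration, and the masks are built on CPU instead of with .cuda().

import torch

# Toy cache layout and a batch of two prefill requests (invented values).
block_size, block_num = 4, 8
q_seqlens = torch.tensor([2, 3])      # new tokens per request
kv_seqlens = torch.tensor([5, 4])     # cached + new tokens per request
block_offsets = torch.tensor([[0, 1, 3, 0],   # padded block table,
                              [2, 4, 0, 0]])  # one row per request

# Per-token slot ids of the whole cache: slot = block_id * block_size + offset.
total_slots = torch.arange(block_num * block_size).view(block_num, block_size)

kv_start_indices, attention_mask = [], []
for i in range(q_seqlens.size(0)):
    q_len, kv_len = int(q_seqlens[i]), int(kv_seqlens[i])
    history = kv_len - q_len
    # Slots of the tokens written in this step: the last q_len kv positions
    # of this request's flattened block table.
    slots = total_slots[block_offsets[i]].flatten()[history:kv_len]
    kv_start_indices.append(slots.reshape(-1, 1))
    # Causal mask over the padded kv width, shifted by the history length
    # (the paged_prefill case: q_len < kv_len, more than one request).
    mask = torch.logical_not(
        torch.tril(
            torch.ones(q_len, block_offsets.shape[1] * block_size,
                       dtype=torch.bool),
            diagonal=kv_len - q_len))
    attention_mask.append(mask)

kv_start_indices = torch.cat(kv_start_indices)

# paged_prefill with batch > 1: expand per-request metadata so every query
# token carries its own kv length and block-table row, then stack the masks.
kv_seqlens_cpu = kv_seqlens.repeat_interleave(q_seqlens, 0)
block_offsets_rep = block_offsets.to(torch.int32).repeat_interleave(q_seqlens, 0)
attention_mask = [torch.cat(attention_mask).unsqueeze(1)]

print(kv_seqlens_cpu)            # tensor([5, 5, 4, 4, 4])
print(block_offsets_rep.shape)   # torch.Size([5, 4])
print(attention_mask[0].shape)   # torch.Size([5, 1, 16])

After the expansion every query row has its own kv length, block-table row, and mask row, which is why the per-request masks can be merged into one [total_q, 1, padded_kv] tensor even when several prefill requests share a batch.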

0 commit comments
