@@ -50,50 +50,73 @@ def update_step_context(cls, step_context):
         device = step_context.block_offsets.device
 
         is_unpaged_prefill = False
-        q_start_loc_cpu = step_context.q_start_loc.cpu()
-        q_seqlens_cpu = step_context.q_seqlens.cpu()
-        kv_seqlens_cpu = step_context.kv_seqlens.cpu()
-        max_q_seq_len = torch.max(q_seqlens_cpu).item()
-        max_kv_seq_len = torch.max(kv_seqlens_cpu).item()
-
         if not step_context.is_decoding:
             is_unpaged_prefill = \
                 all((step_context.q_seqlens ==
                      step_context.kv_seqlens).tolist())
-            if is_unpaged_prefill:
-                single_attention_mask = torch.logical_not(
-                    torch.tril(
-                        torch.ones(max_q_seq_len,
-                                   max_kv_seq_len,
-                                   dtype=torch.bool).cuda(),
-                        diagonal=max_kv_seq_len - max_q_seq_len,
-                    ))
-                attention_mask.append(single_attention_mask)
+
         total_slots = torch.arange(block_num * block_size,
                                    dtype=torch.long,
                                    device=device)
         total_slots = total_slots.view(block_num, block_size)
+
+        q_seqlens_list = step_context.q_seqlens.tolist()
+        kv_seqlens_list = step_context.kv_seqlens.tolist()
+        max_q_seq_len = max(q_seqlens_list)
+        max_kv_seq_len = max(kv_seqlens_list)
+
         for i in range(step_context.q_start_loc.size(0)):
-            q_seq_len = int(step_context.q_seqlens[i])
-            kv_seq_len = int(step_context.kv_seqlens[i])
+            q_seq_len = q_seqlens_list[i]
+            kv_seq_len = kv_seqlens_list[i]
+
+            # collect kv start indices.
+            history_length = kv_seq_len - q_seq_len
+            slot_tables = total_slots[step_context.block_offsets[i]].flatten()
+            slot_indices = [p for p in range(history_length, kv_seq_len)]
+            slots = slot_tables[slot_indices].reshape((-1, 1))
+            kv_start_indices.append(slots)
+
+            # collect attention mask of paged_prefill attention stage.
             if not (step_context.is_decoding or is_unpaged_prefill):
                 single_attention_mask = torch.logical_not(
                     torch.tril(
-                        torch.ones(step_context.q_seqlens[i],
+                        torch.ones(q_seq_len,
                                    step_context.block_offsets.shape[1] *
                                    block_size,
                                    dtype=torch.bool).cuda(),
-                        diagonal=step_context.kv_seqlens[i] -
-                        step_context.q_seqlens[i],
+                        diagonal=kv_seq_len - q_seq_len,
                     ))
                 attention_mask.append(single_attention_mask)
-            history_length = kv_seq_len - q_seq_len
-            slot_tables = total_slots[step_context.block_offsets[i]].flatten()
-            slot_indices = [p for p in range(history_length, kv_seq_len)]
-            slots = slot_tables[slot_indices].reshape((-1, 1))
-            kv_start_indices.append(slots)
+
         kv_start_indices = torch.cat(kv_start_indices)
 
+        if step_context.is_decoding:
+            # prepare some params of paged_decode attention stage.
+            q_start_loc_cpu, q_seqlens_cpu = None, None
+            kv_seqlens_cpu = step_context.kv_seqlens.cpu()
+        elif is_unpaged_prefill:
+            # prepare some params of unpaged_prefill attention stage.
+            q_start_loc_cpu, kv_seqlens_cpu = None, None
+            q_seqlens_cpu = step_context.q_seqlens.cpu()
+            single_attention_mask = torch.logical_not(
+                torch.tril(
+                    torch.ones(max_q_seq_len, max_kv_seq_len,
+                               dtype=torch.bool).cuda(),
+                    diagonal=max_kv_seq_len - max_q_seq_len,
+                ))
+            attention_mask.append(single_attention_mask)
+        else:
+            # prepare some params of paged_prefill attention stage.
+            q_start_loc_cpu, q_seqlens_cpu = None, None
+            kv_seqlens_cpu = step_context.kv_seqlens.repeat_interleave(
+                step_context.q_seqlens, 0).cpu()
+            block_offsets_int32 = step_context.block_offsets.to(torch.int32)
+            step_context.block_offsets = block_offsets_int32.repeat_interleave(
+                step_context.q_seqlens, 0)
+            attention_mask = [
+                torch.cat([mask for mask in attention_mask]).unsqueeze(1)
+            ]
+
         attn_meta_cls = cls.get_attention_metadata_cls()
         attn_metadata = attn_meta_cls(
             step_context.is_decoding,
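
Note on the kv_start_indices loop above: total_slots maps each (block, in-block offset) pair to a flat cache slot ID, and slicing from history_length keeps only the slots that this step's new tokens will occupy. A standalone sketch with toy sizes (the values and the single-sequence setup are illustrative, not from the patch):

import torch

# Toy paged KV cache: 4 blocks of 8 slots each.
block_num, block_size = 4, 8
total_slots = torch.arange(block_num * block_size, dtype=torch.long)
total_slots = total_slots.view(block_num, block_size)

# One sequence whose cache lives in blocks 2 and 0 (in that order),
# with 5 tokens of history and 3 new query tokens this step.
block_offsets = torch.tensor([2, 0])
q_seq_len, kv_seq_len = 3, 8
history_length = kv_seq_len - q_seq_len

slot_tables = total_slots[block_offsets].flatten()  # [16..23, 0..7]
slots = slot_tables[history_length:kv_seq_len].reshape(-1, 1)
print(slots.flatten().tolist())  # [21, 22, 23]: one slot per new token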
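
Note on the mask construction and the paged_prefill branch: torch.tril with diagonal=kv_seq_len - q_seq_len marks the history plus the causal triangle as visible, and logical_not flips that into a mask of forbidden positions; repeat_interleave then expands per-sequence kv_seqlens and block tables into one copy per query token. A runnable sketch under the same toy assumptions (CPU tensors stand in for the patch's .cuda() allocations):

import torch

# Causal mask with history: True marks positions a query row may NOT see.
q_seq_len, kv_seq_len = 3, 5
mask = torch.logical_not(
    torch.tril(torch.ones(q_seq_len, kv_seq_len, dtype=torch.bool),
               diagonal=kv_seq_len - q_seq_len))
# Row i masks keys after position (kv_seq_len - q_seq_len) + i:
# [[F, F, F, T, T],
#  [F, F, F, F, T],
#  [F, F, F, F, F]]

# Per-sequence -> per-query-token expansion (toy batch: two sequences
# with 3 and 1 new query tokens).
q_seqlens = torch.tensor([3, 1])
kv_seqlens = torch.tensor([5, 9])
block_offsets = torch.tensor([[2, 0], [1, 3]], dtype=torch.int32)
print(kv_seqlens.repeat_interleave(q_seqlens, 0))           # tensor([5, 5, 5, 9])
print(block_offsets.repeat_interleave(q_seqlens, 0).shape)  # torch.Size([4, 2])

Presumably the per-token expansion lets the paged_prefill kernel treat each query row as an independent cache lookup; that rationale is an inference, not stated in the patch.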
0 commit comments