5
5
#include < algorithm>
6
6
7
7
void llama_ubatch::update () {
8
+ if (equal_seqs) {
9
+ // TODO: for now don't compute min/max for recurrent batches since we don't need this.
10
+ // the batches will be refactored anyway, so we'll fix this later
11
+ return ;
12
+ }
13
+
8
14
for (uint32_t i = 0 ; i < n_tokens; ++i) {
9
15
const llama_seq_id s = seq_id[i][0 ];
10
16
@@ -24,26 +30,33 @@ llama_ubatch llama_sbatch::reserve_ubatch(size_t n_ubatch, bool has_embd) {
24
30
break ;
25
31
}
26
32
}
27
- ubatch_token.resize (!has_embd ? n_ubatch : 0 );
28
- ubatch_embd.resize (has_embd ? n_embd * n_ubatch : 0 );
29
- ubatch_pos.resize (n_ubatch);
30
- ubatch_n_seq_id.resize (n_ubatch);
31
- ubatch_seq_id.resize (n_ubatch);
32
- ubatch_output.resize (n_ubatch);
33
+
34
+ udatas.push_back ({});
35
+
8000
36
+ auto & udata = udatas.back ();
37
+
38
+ udata.token .resize (!has_embd ? n_ubatch : 0 );
39
+ udata.embd .resize (has_embd ? n_embd * n_ubatch : 0 );
40
+ udata.pos .resize (n_ubatch);
41
+ udata.n_seq_id .resize (n_ubatch);
42
+ udata.seq_id .resize (n_ubatch);
43
+ udata.output .resize (n_ubatch);
44
+
33
45
llama_ubatch ubatch = {
34
46
/* equal_seqs =*/ true ,
35
47
/* n_tokens =*/ 0 ,
36
48
/* n_seq_tokens =*/ 0 ,
37
49
/* n_seqs =*/ 0 ,
38
50
/* seq_pos_min =*/ {-1 },
39
51
/* seq_pos_max =*/ {-1 },
40
- /* token =*/ !has_embd ? ubatch_token .data () : nullptr ,
41
- /* embd =*/ has_embd ? ubatch_embd .data () : nullptr ,
42
- /* pos =*/ ubatch_pos .data (),
43
- /* n_seq_id =*/ ubatch_n_seq_id .data (),
44
- /* seq_id =*/ ubatch_seq_id .data (),
45
- /* output =*/ ubatch_output .data (),
52
+ /* token =*/ !has_embd ? udata. token .data () : nullptr ,
53
+ /* embd =*/ has_embd ? udata. embd .data () : nullptr ,
54
+ /* pos =*/ udata. pos .data (),
55
+ /* n_seq_id =*/ udata. n_seq_id .data (),
56
+ /* seq_id =*/ udata. seq_id .data (),
57
+ /* output =*/ udata. output .data (),
46
58
};
59
+
47
60
return ubatch;
48
61
}
49
62
0 commit comments