@@ -240,11 +240,11 @@ class llama_token_data_array(Structure):
240
240
# typedef struct llama_batch {
241
241
# int32_t n_tokens;
242
242
243
- # llama_token * token;
244
- # float * embd;
245
- # llama_pos * pos;
246
- # llama_seq_id * seq_id;
247
- # int8_t * logits;
243
+ # llama_token * token;
244
+ # float * embd;
245
+ # llama_pos * pos;
246
+ # llama_seq_id ** seq_id;
247
+ # int8_t * logits;
248
248
249
249
250
250
# // NOTE: helpers for smooth API transition - can be deprecated in the future
@@ -262,7 +262,7 @@ class llama_batch(Structure):
262
262
("token" , POINTER (llama_token )),
263
263
("embd" , c_float_p ),
264
264
("pos" , POINTER (llama_pos )),
265
- ("seq_id" , POINTER (llama_seq_id )),
265
+ ("seq_id" , POINTER (POINTER ( llama_seq_id ) )),
266
266
("logits" , POINTER (c_int8 )),
267
267
("all_pos_0" , llama_pos ),
268
268
("all_pos_1" , llama_pos ),
@@ -1069,22 +1069,26 @@ def llama_batch_get_one(
1069
1069
_lib .llama_batch_get_one .restype = llama_batch
1070
1070
1071
1071
1072
- # // Allocates a batch of tokens on the heap
1072
+ # // Allocates a batch of tokens on the heap that can hold a maximum of n_tokens
1073
+ # // Each token can be assigned up to n_seq_max sequence ids
1073
1074
# // The batch has to be freed with llama_batch_free()
1074
1075
# // If embd != 0, llama_batch.embd will be allocated with size of n_tokens * embd * sizeof(float)
1075
1076
# // Otherwise, llama_batch.token will be allocated to store n_tokens llama_token
1076
1077
# // The rest of the llama_batch members are allocated with size n_tokens
1077
1078
# // All members are left uninitialized
1078
1079
# LLAMA_API struct llama_batch llama_batch_init(
1079
1080
# int32_t n_tokens,
1080
- # int32_t embd);
1081
+ # int32_t embd,
1082
+ # int32_t n_seq_max);
1081
1083
def llama_batch_init (
1082
- n_tokens : Union [c_int , int ], embd : Union [c_int , int ]
1084
+ n_tokens : Union [c_int32 , int ],
1085
+ embd : Union [c_int32 , int ],
1086
+ n_seq_max : Union [c_int32 , int ],
1083
1087
) -> llama_batch :
1084
- return _lib .llama_batch_init (n_tokens , embd )
1088
+ return _lib .llama_batch_init (n_tokens , embd , n_seq_max )
1085
1089
1086
1090
1087
- _lib .llama_batch_init .argtypes = [c_int , c_int ]
1091
+ _lib .llama_batch_init .argtypes = [c_int32 , c_int32 , c_int32 ]
1088
1092
_lib .llama_batch_init .restype = llama_batch
1089
1093
1090
1094
@@ -1308,6 +1312,46 @@ def llama_tokenize(
1308
1312
_lib .llama_tokenize .restype = c_int
1309
1313
1310
1314
1315
+ # /// @details Convert the provided text into tokens.
1316
+ # /// @param tokens The tokens pointer must be large enough to hold the resulting tokens.
1317
+ # /// @return Returns the number of tokens on success, no more than n_max_tokens
1318
+ # /// @return Returns a negative number on failure - its absolute value is the number of tokens that would have been returned
1319
+ # /// @param special Allow tokenizing special and/or control tokens, which are otherwise not exposed and are treated as plaintext.
1320
+ # /// Does not insert a leading space.
1321
+ # LLAMA_API int llama_tokenize(
1322
+ # const struct llama_model * model,
1323
+ # const char * text,
1324
+ # int text_len,
1325
+ # llama_token * tokens,
1326
+ # int n_max_tokens,
1327
+ # bool add_bos,
1328
+ # bool special);
1329
+ def llama_tokenize (
1330
+ model : llama_model_p ,
1331
+ text : bytes ,
1332
+ text_len : Union [c_int , int ],
1333
+ tokens , # type: Array[llama_token]
1334
+ n_max_tokens : Union [c_int , int ],
1335
+ add_bos : Union [c_bool , bool ],
1336
+ special : Union [c_bool , bool ],
1337
+ ) -> int :
1338
+ return _lib .llama_tokenize (
1339
+ model , text , text_len , tokens , n_max_tokens , add_bos , special
1340
+ )
1341
+
1342
+
1343
+ _lib .llama_tokenize .argtypes = [
1344
+ llama_model_p ,
1345
+ c_char_p ,
1346
+ c_int ,
1347
+ llama_token_p ,
1348
+ c_int ,
1349
+ c_bool ,
1350
+ c_bool ,
1351
+ ]
1352
+ _lib .llama_tokenize .restype = c_int
1353
+
1354
+
1311
1355
# // Token Id -> Piece.
1312
1356
# // Uses the vocabulary in the provided context.
1313
1357
# // Does not write null terminator to the buffer.
0 commit comments