@@ -207,7 +207,6 @@ def __init__(
         n_ctx: int = 512,
         n_parts: int = -1,
         n_gpu_layers: int = 0,
-        tensor_split: list[float] = None,
         seed: int = 1337,
         f16_kv: bool = True,
         logits_all: bool = False,
@@ -221,6 +220,7 @@ def __init__(
         lora_base: Optional[str] = None,
         lora_path: Optional[str] = None,
         low_vram: bool = False,
+        tensor_split: Optional[List[float]] = None,
         verbose: bool = True,
     ):
         """Load a llama.cpp model from `model_path`.
@@ -241,6 +241,7 @@ def __init__(
             last_n_tokens_size: Maximum number of tokens to keep in the last_n_tokens deque.
             lora_base: Optional path to base model, useful if using a quantized base model and you want to apply LoRA to an f16 model.
             lora_path: Path to a LoRA file to apply to the model.
+            tensor_split: List of floats to split the model across multiple GPUs. If None, the model is not split.
             verbose: Print verbose output to stderr.
 
         Raises:
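For callers, the new keyword is just a plain list of per-GPU split proportions. A hedged usage sketch (the model path, layer count, and split values below are illustrative placeholders, not from this commit):

```python
from llama_cpp import Llama

# Hypothetical values: tensor_split gives per-GPU proportions for splitting
# the model across devices; None (the default) means no split.
llm = Llama(
    model_path="./models/7B/ggml-model.bin",
    n_gpu_layers=40,
    tensor_split=[0.6, 0.4],  # roughly 60% on GPU 0, 40% on GPU 1
)
```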
@@ -249,20 +250,13 @@ def __init__(
         Returns:
             A Llama instance.
         """
-        if tensor_split is None:
-            tensor_split = [0.0] * llama_cpp.LLAMA_MAX_DEVICES.value
-
-        #Type conversion and expand the list to the length of LLAMA_MAX_DEVICES
-        FloatArray = ctypes.c_float * llama_cpp.LLAMA_MAX_DEVICES.value
-        c_tensor_split = FloatArray(*tensor_split)
 
         self.verbose = verbose
         self.model_path = model_path
 
         self.params = llama_cpp.llama_context_default_params()
         self.params.n_ctx = n_ctx
         self.params.n_gpu_layers = n_gpu_layers
-        self.params.tensor_split = c_tensor_split
         self.params.seed = seed
         self.params.f16_kv = f16_kv
         self.params.logits_all = logits_all
@@ -272,6 +266,15 @@ def __init__(
         self.params.embedding = embedding
         self.params.low_vram = low_vram
 
+        self.tensor_split = tensor_split
+        self._c_tensor_split = None
+
+        if self.tensor_split is not None:
+            #Type conversion and expand the list to the length of LLAMA_MAX_DEVICES
+            FloatArray = ctypes.c_float * llama_cpp.LLAMA_MAX_DEVICES.value
+            self._c_tensor_split = FloatArray(*tensor_split)  # keep a reference to the array so it is not gc'd
+            self.params.tensor_split = self._c_tensor_split
+
         self.last_n_tokens_size = last_n_tokens_size
         self.n_batch = min(n_ctx, n_batch)
 
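The point of the refactor above is ownership: `llama_context_params.tensor_split` only holds a C pointer into the ctypes array, so if the array is a local (as in the removed code), nothing keeps it alive after `__init__` returns and the garbage collector may free memory the C side still reads. Storing the array on the instance as `self._c_tensor_split` ties its lifetime to the `Llama` object. A minimal sketch of the pattern, with `MAX_DEVICES` standing in for `llama_cpp.LLAMA_MAX_DEVICES.value`:

```python
import ctypes

MAX_DEVICES = 4  # stand-in for llama_cpp.LLAMA_MAX_DEVICES.value

class Holder:
    """Minimal sketch of the keep-alive pattern used in the diff above."""
    def __init__(self, tensor_split=None):
        self.tensor_split = tensor_split  # plain list: user-facing, picklable
        self._c_tensor_split = None       # owning reference to the C array
        if tensor_split is not None:
            FloatArray = ctypes.c_float * MAX_DEVICES
            # ctypes zero-fills slots beyond len(tensor_split), matching the
            # old [0.0] * LLAMA_MAX_DEVICES padding without building a list.
            self._c_tensor_split = FloatArray(*tensor_split)

h = Holder([0.5, 0.5])
print(list(h._c_tensor_split))  # [0.5, 0.5, 0.0, 0.0]
```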
@@ -1499,7 +1502,6 @@ def __getstate__(self):
             model_path=self.model_path,
             n_ctx=self.params.n_ctx,
             n_gpu_layers=self.params.n_gpu_layers,
-            tensor_split=self.params.tensor_split,
             seed=self.params.seed,
             f16_kv=self.params.f16_kv,
             logits_all=self.params.logits_all,
@@ -1513,6 +1515,7 @@ def __getstate__(self):
             n_threads=self.n_threads,
             lora_base=self.lora_base,
             lora_path=self.lora_path,
+            tensor_split=self.tensor_split,
             ### DEPRECATED ###
             n_parts=self.n_parts,
             ### DEPRECATED ###
@@ -1524,7 +1527,6 @@ def __setstate__(self, state):
             n_ctx=state["n_ctx"],
             n_parts=state["n_parts"],
             n_gpu_layers=state["n_gpu_layers"],
-            tensor_split=state["tensor_split"],
             seed=state["seed"],
             f16_kv=state["f16_kv"],
             logits_all=state["logits_all"],
@@ -1538,6 +1540,7 @@ def __setstate__(self, state):
             last_n_tokens_size=state["last_n_tokens_size"],
             lora_base=state["lora_base"],
             lora_path=state["lora_path"],
+            tensor_split=state["tensor_split"],
             verbose=state["verbose"],
         )
 
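Pickling is why `__getstate__` now stores the plain `self.tensor_split` list rather than `self.params.tensor_split`: the struct field is pointer-like ctypes data, and CPython refuses to pickle ctypes objects that contain pointers, while the list round-trips cleanly and is rebuilt into a fresh array when `__setstate__` feeds it back into `__init__`. A small sketch of the failure mode, assuming the struct field behaves like a `POINTER(c_float)`:

```python
import ctypes
import pickle

arr = (ctypes.c_float * 4)(0.5, 0.5)
ptr = ctypes.cast(arr, ctypes.POINTER(ctypes.c_float))  # pointer-like struct field

try:
    pickle.dumps(ptr)
except ValueError as exc:
    print(exc)  # ctypes objects containing pointers cannot be pickled

# The plain list survives the round trip; __setstate__ passes it back into
# __init__, which rebuilds the ctypes array on the restored instance.
print(pickle.loads(pickle.dumps([0.5, 0.5])))
```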