 from collections import deque, OrderedDict

 import diskcache
+import ctypes

 from . import llama_cpp
 from .llama_types import *

 import numpy as np
 import numpy.typing as npt

-
 class BaseLlamaCache(ABC):
     """Base cache class for a llama.cpp model."""

@@ -207,6 +207,7 @@ def __init__(
         n_ctx: int = 512,
         n_parts: int = -1,
         n_gpu_layers: int = 0,
+        tensor_split: list[float] = None,
         seed: int = 1337,
         f16_kv: bool = True,
         logits_all: bool = False,
@@ -248,12 +249,20 @@ def __init__(
         Returns:
             A Llama instance.
         """
+        if tensor_split is None:
+            tensor_split = [0.0] * llama_cpp.LLAMA_MAX_DEVICES.value
+
+        # Convert the list to a ctypes float array sized to LLAMA_MAX_DEVICES
+        # (entries beyond the given list default to 0.0)
+        FloatArray = ctypes.c_float * llama_cpp.LLAMA_MAX_DEVICES.value
+        c_tensor_split = FloatArray(*tensor_split)
+
         self.verbose = verbose
         self.model_path = model_path

         self.params = llama_cpp.llama_context_default_params()
         self.params.n_ctx = n_ctx
         self.params.n_gpu_layers = n_gpu_layers
+        self.params.tensor_split = c_tensor_split
         self.params.seed = seed
         self.params.f16_kv = f16_kv
         self.params.logits_all = logits_all
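
For reference, the list-to-ctypes conversion in the hunk above can be exercised on its own. This is a minimal sketch, not part of the patch; MAX_DEVICES = 2 is an assumed stand-in for llama_cpp.LLAMA_MAX_DEVICES.value:

    import ctypes

    # Assumed stand-in for llama_cpp.LLAMA_MAX_DEVICES.value
    MAX_DEVICES = 2
    # Proportions llama.cpp uses to split tensors across devices
    tensor_split = [0.7, 0.3]
    FloatArray = ctypes.c_float * MAX_DEVICES
    c_tensor_split = FloatArray(*tensor_split)  # unfilled slots default to 0.0
    print(list(c_tensor_split))  # values come back as single-precision floats
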
@@ -1490,6 +1499,7 @@ def __getstate__(self):
             model_path=self.model_path,
             n_ctx=self.params.n_ctx,
             n_gpu_layers=self.params.n_gpu_layers,
+            tensor_split=self.params.tensor_split,
             seed=self.params.seed,
             f16_kv=self.params.f16_kv,
             logits_all=self.params.logits_all,
@@ -1514,6 +1524,7 @@ def __setstate__(self, state):
             n_ctx=state["n_ctx"],
             n_parts=state["n_parts"],
             n_gpu_layers=state["n_gpu_layers"],
+            tensor_split=state["tensor_split"],
             seed=state["seed"],
             f16_kv=state["f16_kv"],
             logits_all=state["logits_all"],
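
A usage sketch of the new parameter (the model path, layer count, and split ratios below are placeholders; actual splitting depends on how llama.cpp was built and how many GPUs are visible):

    from llama_cpp import Llama

    # Placeholder path; tensor_split asks llama.cpp to distribute
    # tensors roughly 60/40 across two devices.
    llm = Llama(
        model_path="./models/7B/ggml-model.bin",
        n_gpu_layers=40,
        tensor_split=[0.6, 0.4],
    )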