Minor fix to tensor_split parameter · MobinX/llama-cpp-python@25b3494

Commit 25b3494

Minor fix to tensor_split parameter
1 parent: e6c67c8 · commit: 25b3494

File tree

1 file changed: +13 −10 lines changed

llama_cpp/llama.py

Lines changed: 13 additions & 10 deletions
@@ -207,7 +207,6 @@ def __init__(
         n_ctx: int = 512,
         n_parts: int = -1,
         n_gpu_layers: int = 0,
-        tensor_split: list[float] = None,
         seed: int = 1337,
         f16_kv: bool = True,
         logits_all: bool = False,
@@ -221,6 +220,7 @@ def __init__(
         lora_base: Optional[str] = None,
         lora_path: Optional[str] = None,
         low_vram: bool = False,
+        tensor_split: Optional[List[float]] = None,
         verbose: bool = True,
     ):
         """Load a llama.cpp model from `model_path`.
@@ -241,6 +241,7 @@ def __init__(
             last_n_tokens_size: Maximum number of tokens to keep in the last_n_tokens deque.
             lora_base: Optional path to base model, useful if using a quantized base model and you want to apply LoRA to an f16 model.
             lora_path: Path to a LoRA file to apply to the model.
+            tensor_split: List of floats to split the model across multiple GPUs. If None, the model is not split.
             verbose: Print verbose output to stderr.
 
         Raises:
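
For orientation before the implementation hunks, a minimal usage sketch of the new keyword as documented above; the model path and split ratios are hypothetical, and the underlying llama.cpp build must be compiled with multi-GPU (e.g. CUDA) support for the split to have any effect:

    from llama_cpp import Llama

    # Hypothetical two-GPU split: roughly 70% of the offloaded layers on
    # device 0 and 30% on device 1. tensor_split=None (the default) means
    # the model is not split.
    llm = Llama(
        model_path="./models/7B/ggml-model.bin",  # hypothetical path
        n_gpu_layers=40,
        tensor_split=[0.7, 0.3],
    )

Per the conversion code in the hunks below, the list is zero-padded out to LLAMA_MAX_DEVICES entries before being handed to the C API.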
@@ -249,20 +250,13 @@ def __init__(
         Returns:
             A Llama instance.
         """
-        if tensor_split is None:
-            tensor_split = [0.0] * llama_cpp.LLAMA_MAX_DEVICES.value
-
-        #Type conversion and expand the list to the length of LLAMA_MAX_DEVICES
-        FloatArray = ctypes.c_float * llama_cpp.LLAMA_MAX_DEVICES.value
-        c_tensor_split = FloatArray(*tensor_split)
 
         self.verbose = verbose
         self.model_path = model_path
 
         self.params = llama_cpp.llama_context_default_params()
         self.params.n_ctx = n_ctx
         self.params.n_gpu_layers = n_gpu_layers
-        self.params.tensor_split = c_tensor_split
         self.params.seed = seed
         self.params.f16_kv = f16_kv
         self.params.logits_all = logits_all
@@ -272,6 +266,15 @@ def __init__(
         self.params.embedding = embedding
         self.params.low_vram = low_vram
 
+        self.tensor_split = tensor_split
+        self._c_tensor_split = None
+
+        if self.tensor_split is not None:
+            #Type conversion and expand the list to the length of LLAMA_MAX_DEVICES
+            FloatArray = ctypes.c_float * llama_cpp.LLAMA_MAX_DEVICES.value
+            self._c_tensor_split = FloatArray(*tensor_split)  # keep a reference to the array so it is not gc'd
+            self.params.tensor_split = self._c_tensor_split
+
         self.last_n_tokens_size = last_n_tokens_size
         self.n_batch = min(n_ctx, n_batch)
 
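
The new block converts the Python list into a fixed-length C float array only when a split was actually requested, and stashes the array on self. Whether the extra self._c_tensor_split reference is strictly required depends on how the struct field is declared in the bindings, but holding it guards against the array being collected while the C side still points at its memory, as the inline comment notes. A standalone sketch of just the conversion step, with a stand-in constant in place of llama_cpp.LLAMA_MAX_DEVICES:

    import ctypes

    MAX_DEVICES = 4  # stand-in for llama_cpp.LLAMA_MAX_DEVICES.value

    tensor_split = [0.7, 0.3]

    # ctypes builds array types by multiplying the element type by a length;
    # constructing an instance from fewer values zero-fills the tail.
    FloatArray = ctypes.c_float * MAX_DEVICES
    c_tensor_split = FloatArray(*tensor_split)

    print(list(c_tensor_split))
    # -> [0.699999988079071, 0.30000001192092896, 0.0, 0.0]
    # (c_float is single precision, hence the rounding)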
@@ -1499,7 +1502,6 @@ def __getstate__(self):
             model_path=self.model_path,
             n_ctx=self.params.n_ctx,
             n_gpu_layers=self.params.n_gpu_layers,
-            tensor_split=self.params.tensor_split,
             seed=self.params.seed,
             f16_kv=self.params.f16_kv,
             logits_all=self.params.logits_all,
@@ -1513,6 +1515,7 @@ def __getstate__(self):
             n_threads=self.n_threads,
             lora_base=self.lora_base,
             lora_path=self.lora_path,
+            tensor_split=self.tensor_split,
             ### DEPRECATED ###
             n_parts=self.n_parts,
             ### DEPRECATED ###
@@ -1524,7 +1527,6 @@ def __setstate__(self, state):
             n_ctx=state["n_ctx"],
             n_parts=state["n_parts"],
             n_gpu_layers=state["n_gpu_layers"],
-            tensor_split=state["tensor_split"],
             seed=state["seed"],
             f16_kv=state["f16_kv"],
             logits_all=state["logits_all"],
@@ -1538,6 +1540,7 @@ def __setstate__(self, state):
             last_n_tokens_size=state["last_n_tokens_size"],
             lora_base=state["lora_base"],
             lora_path=state["lora_path"],
+            tensor_split=state["tensor_split"],
             verbose=state["verbose"],
         )
 
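
The four __getstate__/__setstate__ hunks are the other half of the fix: the pickled state now carries the plain Python list the caller supplied (self.tensor_split) rather than the ctypes array held in self.params, so the value round-trips through __init__ in exactly the form the constructor expects. A toy sketch of the same pattern; the class and its single field are illustrative, not the library's API:

    import pickle

    # Toy stand-in for Llama's pickling pattern after this commit.
    class Model:
        def __init__(self, tensor_split=None):
            self.tensor_split = tensor_split  # keep the plain list for pickling
            # (the ctypes conversion would happen here and stays out of
            # the pickled state)

        def __getstate__(self):
            return dict(tensor_split=self.tensor_split)

        def __setstate__(self, state):
            self.__init__(tensor_split=state["tensor_split"])

    m = pickle.loads(pickle.dumps(Model(tensor_split=[0.7, 0.3])))
    assert m.tensor_split == [0.7, 0.3]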