8000 Fix tensor_split cli option · csegura/llama-cpp-python@c4c440b · GitHub
[go: up one dir, main page]

Skip to content

Commit c4c440b

Browse files
committed
Fix tensor_split cli option
1 parent 203ede4 commit c4c440b

File tree

2 files changed

+16
-7
lines changed

2 files changed

+16
-7
lines changed

llama_cpp/llama.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -288,7 +288,7 @@ def __init__(
288288

289289
if self.tensor_split is not None:
290290
# Type conversion and expand the list to the length of LLAMA_MAX_DEVICES
291-
FloatArray = ctypes.c_float * llama_cpp.LLAMA_MAX_DEVICES.value
291+
FloatArray = ctypes.c_float * llama_cpp.LLAMA_MAX_DEVICES
292292
self._c_tensor_split = FloatArray(
293293
*tensor_split
294294
) # keep a reference to the array so it is not gc'd

llama_cpp/server/__main__.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,27 +23,35 @@
2323
"""
2424
import os
2525
import argparse
26-
from typing import Literal, Union
26+
from typing import List, Literal, Union
2727

2828
import uvicorn
2929

3030
from llama_cpp.server.app import create_app, Settings
3131

32-
def get_non_none_base_types(annotation):
33-
if not hasattr(annotation, "__args__"):
34-
return annotation
35-
return [arg for arg in annotation.__args__ if arg is not type(None)][0]
36-
3732
def get_base_type(annotation):
3833
if getattr(annotation, '__origin__', None) is Literal:
3934
return type(annotation.__args__[0])
4035
elif getattr(annotation, '__origin__', None) is Union:
4136
non_optional_args = [arg for arg in annotation.__args__ if arg is not type(None)]
4237
if non_optional_args:
4338
return get_base_type(non_optional_args[0])
39+
elif getattr(annotation, '__origin__', None) is list or getattr(annotation, '__origin__', None) is List:
40+
return get_base_type(annotation.__args__[0])
4441
else:
4542
return annotation
4643

44+
def contains_list_type(annotation) -> bool:
45+
origin = getattr(annotation, '__origin__', None)
46+
47+
if origin is list or origin is List:
48+
return True
49+
elif origin in (Literal, Union):
50+
return any(contains_list_type(arg) for arg in annotation.__args__)
51+
else:
52+
return False
53+
54+
4755
if __name__ == "__main__":
4856
parser = argparse.ArgumentParser()
4957
for name, field in Settings.model_fields.items():
@@ -53,6 +61,7 @@ def get_base_type(annotation):
5361
parser.add_argument(
5462
f"--{name}",
5563
dest=name,
64+
nargs="*" if contains_list_type(field.annotation) else None,
5665
type=get_base_type(field.annotation) if field.annotation is not None else str,
5766
help=description,
5867
)

0 commit comments

Comments
 (0)
0