23
23
"""
24
24
import os
25
25
import argparse
26
- from typing import Literal , Union
26
+ from typing import List , Literal , Union
27
27
28
28
import uvicorn
29
29
30
30
from llama_cpp .server .app import create_app , Settings
31
31
32
- def get_non_none_base_types (annotation ):
33
- if not hasattr (annotation , "__args__" ):
34
- return annotation
35
- return [arg for arg in annotation .__args__ if arg is not type (None )][0 ]
36
-
37
32
def get_base_type (annotation ):
38
33
if getattr (annotation , '__origin__' , None ) is Literal :
39
34
return type (annotation .__args__ [0 ])
40
35
elif getattr (annotation , '__origin__' , None ) is Union :
41
36
non_optional_args = [arg for arg in annotation .__args__ if arg is not type (None )]
42
37
if non_optional_args :
43
38
return get_base_type (non_optional_args [0 ])
39
+ elif getattr (annotation , '__origin__' , None ) is list or getattr (annotation , '__origin__' , None ) is List :
40
+ return get_base_type (annotation .__args__ [0 ])
44
41
else :
45
42
return annotation
46
43
44
+ def contains_list_type (annotation ) -> bool :
45
+ origin = getattr (annotation , '__origin__' , None )
46
+
47
+ if origin is list or origin is List :
48
+ return True
49
+ elif origin in (Literal , Union ):
50
+ return any (contains_list_type (arg ) for arg in annotation .__args__ )
51
+ else :
52
+ return False
53
+
54
+
47
55
if __name__ == "__main__" :
48
56
parser = argparse .ArgumentParser ()
49
57
for name , field in Settings .model_fields .items ():
@@ -53,6 +61,7 @@ def get_base_type(annotation):
53
61
parser .add_argument (
54
62
f"--{ name } " ,
55
63
dest = name ,
64
+ nargs = "*" if contains_list_type (field .annotation ) else None ,
56
65
type = get_base_type (field .annotation ) if field .annotation is not None else str ,
57
66
help = description ,
58
67
)
0 commit comments