1
1
import sys
2
2
import os
3
3
import ctypes
4
+ import functools
4
5
from ctypes import (
5
6
c_bool ,
6
7
c_char_p ,
13
14
Structure ,
14
15
)
15
16
import pathlib
16
- from typing import List , Union , NewType , Optional
17
+ from typing import List , Union , NewType , Optional , TypeVar , Callable , Any
17
18
18
19
import llama_cpp .llama_cpp as llama_cpp
19
20
@@ -76,6 +77,31 @@ def _load_shared_library(lib_base_name: str):
76
77
# Load the library
77
78
_libllava = _load_shared_library (_libllava_base_name )
78
79
80
+ # ctypes helper
81
+
82
+ F = TypeVar ("F" , bound = Callable [..., Any ])
83
+
84
+ def ctypes_function_for_shared_library (lib : ctypes .CDLL ):
85
+ def ctypes_function (
86
+ name : str , argtypes : List [Any ], restype : Any , enabled : bool = True
87
+ ):
88
+ def decorator (f : F ) -> F :
89
+ if enabled :
90
+ func = getattr (lib , name )
91
+ func .argtypes = argtypes
92
+ func .restype = restype
93
+ functools .wraps (f )(func )
94
+ return func
95
+ else :
96
+ return f
97
+
98
+ return decorator
99
+
100
+ return ctypes_function
101
+
102
+
103
+ ctypes_function = ctypes_function_for_shared_library (_libllava )
104
+
79
105
80
106
################################################
81
107
# llava.h
@@ -97,49 +123,35 @@ class llava_image_embed(Structure):
97
123
98
124
# /** sanity check for clip <-> llava embed size match */
99
125
# LLAVA_API bool llava_validate_embed_size(const llama_context * ctx_llama, const clip_ctx * ctx_clip);
126
+ @ctypes_function ("llava_validate_embed_size" , [llama_cpp .llama_context_p_ctypes , clip_ctx_p_ctypes ], c_bool )
100
127
def llava_validate_embed_size (ctx_llama : llama_cpp .llama_context_p , ctx_clip : clip_ctx_p , / ) -> bool :
101
128
...
102
129
103
- llava_validate_embed_size = _libllava .llava_validate_embed_size
104
- llava_validate_embed_size .argtypes = [llama_cpp .llama_context_p_ctypes , clip_ctx_p_ctypes ]
105
- llava_validate_embed_size .restype = c_bool
106
130
107
131
# /** build an image embed from image file bytes */
108
132
# LLAVA_API struct llava_image_embed * llava_image_embed_make_with_bytes(struct clip_ctx * ctx_clip, int n_threads, const unsigned char * image_bytes, int image_bytes_length);
133
+ @ctypes_function ("llava_image_embed_make_with_bytes" , [clip_ctx_p_ctypes , c_int , POINTER (c_uint8 ), c_int ], POINTER (llava_image_embed ))
109
134
def llava_image_embed_make_with_bytes (ctx_clip : clip_ctx_p , n_threads : Union [c_int , int ], image_bytes : bytes , image_bytes_length : Union [c_int , int ], / ) -> "_Pointer[llava_image_embed]" :
110
135
...
111
136
112
- llava_image_embed_make_with_bytes = _libllava .llava_image_embed_make_with_bytes
113
- llava_image_embed_make_with_bytes .argtypes = [clip_ctx_p_ctypes , c_int , POINTER (c_uint8 ), c_int ]
114
- llava_image_embed_make_with_bytes .restype = POINTER (llava_image_embed )
115
-
116
137
# /** build an image embed from a path to an image filename */
117
138
# LLAVA_API struct llava_image_embed * llava_image_embed_make_with_filename(struct clip_ctx * ctx_clip, int n_threads, const char * image_path);
139
+ @ctypes_function ("llava_image_embed_make_with_filename" , [clip_ctx_p_ctypes , c_int , c_char_p ], POINTER (llava_image_embed ))
118
140
def llava_image_embed_make_with_filename (ctx_clip : clip_ctx_p , n_threads : Union [c_int , int ], image_path : bytes , / ) -> "_Pointer[llava_image_embed]" :
119
141
...
120
142
121
- llava_image_embed_make_with_filename = _libllava .llava_image_embed_make_with_filename
122
- llava_image_embed_make_with_filename .argtypes = [clip_ctx_p_ctypes , c_int , c_char_p ]
123
- llava_image_embed_make_with_filename .restype = POINTER (llava_image_embed )
124
-
125
143
# LLAVA_API void llava_image_embed_free(struct llava_image_embed * embed);
126
144
# /** free an embedding made with llava_image_embed_make_* */
145
+ @ctypes_function ("llava_image_embed_free" , [POINTER (llava_image_embed )], None )
127
146
def llava_image_embed_free (embed : "_Pointer[llava_image_embed]" , / ):
128
147
...
129
148
130
- llava_image_embed_free = _libllava .llava_image_embed_free
131
- llava_image_embed_free .argtypes = [POINTER (llava_image_embed )]
132
- llava_image_embed_free .restype = None
133
-
134
149
# /** write the image represented by embed into the llama context with batch size n_batch, starting at context pos n_past. on completion, n_past points to the next position in the context after the image embed. */
135
150
# LLAVA_API bool llava_eval_image_embed(struct llama_context * ctx_llama, const struct llava_image_embed * embed, int n_batch, int * n_past);
151
+ @ctypes_function ("llava_eval_image_embed" , [llama_cpp .llama_context_p_ctypes , POINTER (llava_image_embed ), c_int , POINTER (c_int )], c_bool )
136
152
def llava_eval_image_embed (ctx_llama : llama_cpp .llama_context_p , embed : "_Pointer[llava_image_embed]" , n_batch : Union [c_int , int ], n_past : "_Pointer[c_int]" , / ) -> bool :
137
153
...
138
154
139
- llava_eval_image_embed = _libllava .llava_eval_image_embed
140
- llava_eval_image_embed .argtypes = [llama_cpp .llama_context_p_ctypes , POINTER (llava_image_embed ), c_int , POINTER (c_int )]
141
- llava_eval_image_embed .restype = c_bool
142
-
143
155
144
156
################################################
145
157
# clip.h
@@ -148,18 +160,12 @@ def llava_eval_image_embed(ctx_llama: llama_cpp.llama_context_p, embed: "_Pointe
148
160
149
161
# /** load mmproj model */
150
162
# CLIP_API struct clip_ctx * clip_model_load (const char * fname, int verbosity);
163
+ @ctypes_function ("clip_model_load" , [c_char_p , c_int ], clip_ctx_p_ctypes )
151
164
def clip_model_load (fname : bytes , verbosity : Union [c_int , int ], / ) -> Optional [clip_ctx_p ]:
152
165
...
153
166
154
- clip_model_load = _libllava .clip_model_load
155
- clip_model_load .argtypes = [c_char_p , c_int ]
156
- clip_model_load .restype = clip_ctx_p_ctypes
157
-
158
167
# /** free mmproj model */
159
168
# CLIP_API void clip_free(struct clip_ctx * ctx);
169
+ @ctypes_function ("clip_free" , [clip_ctx_p_ctypes ], None )
160
170
def clip_free (ctx : clip_ctx_p , / ):
161
171
...
162
-
163
- clip_free = _libllava .clip_free
164
- clip_free .argtypes = [clip_ctx_p_ctypes ]
165
- clip_free .restype = None
0 commit comments