Merge pull request #2 from kossum/main · bot08/llama-cpp-python@12bac61 · GitHub
[go: up one dir, main page]

Skip to content

Commit 12bac61

Browse files
authored
Merge pull request #2 from kossum/main
feat: Add Gemma3 chat handler
2 parents 2beea0d + 025e7fa commit 12bac61

File tree

2 files changed

+267
-18
lines changed

2 files changed

+267
-18
lines changed

llama_cpp/llama_chat_format.py

Lines changed: 154 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2835,24 +2835,7 @@ def __call__(
28352835
)
28362836
llama.eval(tokens)
28372837
else:
2838-
image_bytes = self.load_image(value)
2839-
embed = self._embed_image_bytes(image_bytes, llama.context_params.n_threads_batch)
2840-
if llama.n_tokens + embed.contents.n_image_pos > llama.n_ctx():
2841-
raise ValueError(
2842-
f"Prompt exceeds n_ctx: {llama.n_tokens + embed.contents.n_image_pos} > {llama.n_ctx()}"
2843-
)
2844-
n_past = ctypes.c_int(llama.n_tokens)
2845-
n_past_p = ctypes.pointer(n_past)
2846-
with suppress_stdout_stderr(disable=self.verbose):
2847-
self._llava_cpp.llava_eval_image_embed(
2848-
llama.ctx,
2849-
embed,
2850-
llama.n_batch,
2851-
n_past_p,
2852-
)
2853-
# Required to avoid issues with hf tokenizer
2854-
llama.input_ids[llama.n_tokens : n_past.value] = -1
2855-
llama.n_tokens = n_past.value
2838+
self.eval_image(llama, value)
28562839

28572840
# Get prompt tokens to avoid a cache miss
28582841
prompt = llama.input_ids[: llama.n_tokens].tolist()
@@ -2938,6 +2921,26 @@ def __call__(
29382921
)
29392922
return _convert_completion_to_chat(completion_or_chunks, stream=stream)
29402923

2924+
def eval_image(self, llama: llama.Llama, image_url: str):
2925+
image_bytes = self.load_image(image_url)
2926+
embed = self._embed_image_bytes(image_bytes, llama.context_params.n_threads_batch)
2927+
if llama.n_tokens + embed.contents.n_image_pos > llama.n_ctx():
2928+
raise ValueError(
2929+
f"Prompt exceeds n_ctx: {llama.n_tokens + embed.contents.n_image_pos} > {llama.n_ctx()}"
2930+
)
2931+
n_past = ctypes.c_int(llama.n_tokens)
2932+
n_past_p = ctypes.pointer(n_past)
2933+
with suppress_stdout_stderr(disable=self.verbose):
2934+
self._llava_cpp.llava_eval_image_embed(
2935+
llama.ctx,
2936+
embed,
2937+
llama.n_batch,
2938+
n_past_p,
2939+
)
2940+
# Required to avoid issues with hf tokenizer
2941+
llama.input_ids[llama.n_tokens : n_past.value] = -1
2942+
llama.n_tokens = n_past.value
2943+
29412944
@staticmethod
29422945
def _load_image(image_url: str) -> bytes:
29432946
# TODO: Add Pillow support for other image formats beyond (jpg, png)
@@ -3373,6 +3376,139 @@ class MiniCPMv26ChatHandler(Llava15ChatHandler):
33733376
)
33743377

33753378

3379+
class Gemma3ChatHandler(Llava15ChatHandler):
    """Chat handler for Gemma 3 multimodal models.

    Chat format:
    '<bos><start_of_turn>user\n{system_prompt}\n\n{prompt}<end_of_turn>\n<start_of_turn>model\n'
    """

    # Gemma has no dedicated system role: the template below folds a leading
    # system message into the first user turn, so no default is injected.
    DEFAULT_SYSTEM_MESSAGE = None

    CHAT_FORMAT = (
        "{{ '<bos>' }}"
        "{%- if messages[0]['role'] == 'system' -%}"
        "{%- if messages[0]['content'] is string -%}"
        "{%- set first_user_prefix = messages[0]['content'] + '\n\n' -%}"
        "{%- else -%}"
        "{%- set first_user_prefix = messages[0]['content'][0]['text'] + '\n\n' -%}"
        "{%- endif -%}"
        "{%- set loop_messages = messages[1:] -%}"
        "{%- else -%}"
        "{%- set first_user_prefix = \"\" -%}"
        "{%- set loop_messages = messages -%}"
        "{%- endif -%}"
        "{%- for message in loop_messages -%}"
        "{%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%}"
        "{{ raise_exception(\"Conversation roles must alternate user/assistant/user/assistant/...\") }}"
        "{%- endif -%}"
        "{%- if (message['role'] == 'assistant') -%}"
        "{%- set role = \"model\" -%}"
        "{%- else -%}"
        "{%- set role = message['role'] -%}"
        "{%- endif -%}"
        "{{ '<start_of_turn>' + role + '\n' + (first_user_prefix if loop.first else \"\") }}"
        "{%- if message['content'] is string -%}"
        "{{ message['content'] | trim }}"
        "{%- elif message['content'] is iterable -%}"
        "{%- for item in message['content'] -%}"
        "{%- if item['type'] == 'image_url' -%}"
        "{{ '<start_of_image>' }}"
        "{%- elif item['type'] == 'text' -%}"
        "{{ item['text'] | trim }}"
        "{%- endif -%}"
        "{%- endfor -%}"
        "{%- else -%}"
        "{{ raise_exception(\"Invalid content type\") }}"
        "{%- endif -%}"
        "{{ '<end_of_turn>\n' }}"
        "{%- endfor -%}"
        "{%- if add_generation_prompt -%}"
        "{{ '<start_of_turn>model\n' }}"
        "{%- endif -%}"
    )

    @staticmethod
    def split_text_on_image_urls(text: str, image_urls: List[str]):
        """Split rendered prompt text into ("text" | "image_url", value) chunks.

        Each '<start_of_image>' placeholder consumes the next URL from
        ``image_urls`` and is framed with the '\n\n<start_of_image>' /
        '<end_of_image>\n\n' markers the Gemma 3 vision encoder expects.
        The number of placeholders must match ``len(image_urls)``.
        """
        split_text: List[Tuple[Literal["text", "image_url"], str]] = []
        copied_urls = image_urls[:]
        remaining = text
        image_placeholder = "<start_of_image>"

        while remaining:
            # Find placeholder
            pos = remaining.find(image_placeholder)
            if pos != -1:
                assert len(copied_urls) > 0
                if pos > 0:
                    split_text.append(("text", remaining[:pos]))
                split_text.append(("text", "\n\n<start_of_image>"))
                split_text.append(("image_url", copied_urls.pop(0)))
                split_text.append(("text", "<end_of_image>\n\n"))
                remaining = remaining[pos + len(image_placeholder):]
            else:
                assert len(copied_urls) == 0
                split_text.append(("text", remaining))
                remaining = ""
        return split_text

    def eval_image(self, llama: llama.Llama, image_url: str):
        """Encode an image with the clip model and decode it into the context.

        Unlike the llava path, Gemma 3 evaluates the image embeddings with a
        raw ``llama_batch`` under non-causal attention.

        Raises:
            ValueError: if n_ctx would be exceeded, or if any native
                load/preprocess/encode/decode step fails.
        """
        import llama_cpp

        # Gemma 3 projects every image to a fixed number of embedding tokens.
        n_tokens = 256
        if llama.n_tokens + n_tokens > llama.n_ctx():
            raise ValueError(
                f"Prompt exceeds n_ctx: {llama.n_tokens + n_tokens} > {llama.n_ctx()}"
            )

        img_bytes = self.load_image(image_url)
        img_u8_p = self._llava_cpp.clip_image_u8_init()
        if not self._llava_cpp.clip_image_load_from_bytes(
            ctypes.create_string_buffer(img_bytes, len(img_bytes)),
            ctypes.c_size_t(len(img_bytes)),
            img_u8_p,
        ):
            self._llava_cpp.clip_image_u8_free(img_u8_p)
            raise ValueError("Failed to load image.")

        img_f32 = self._llava_cpp.clip_image_f32_batch()
        img_f32_p = ctypes.byref(img_f32)
        if not self._llava_cpp.clip_image_preprocess(self.clip_ctx, img_u8_p, img_f32_p):
            self._llava_cpp.clip_image_f32_batch_free(img_f32_p)
            self._llava_cpp.clip_image_u8_free(img_u8_p)
            raise ValueError("Failed to preprocess image.")

        n_embd = llama_cpp.llama_model_n_embd(llama._model.model)
        embed = (ctypes.c_float * (n_tokens * n_embd))()
        if not self._llava_cpp.clip_image_batch_encode(self.clip_ctx, llama.n_threads, img_f32_p, embed):
            self._llava_cpp.clip_image_f32_batch_free(img_f32_p)
            self._llava_cpp.clip_image_u8_free(img_u8_p)
            raise ValueError("Failed to encode image.")

        # Native image buffers are no longer needed once encoded.
        self._llava_cpp.clip_image_f32_batch_free(img_f32_p)
        self._llava_cpp.clip_image_u8_free(img_u8_p)

        # Image tokens attend to each other bidirectionally.
        llama_cpp.llama_set_causal_attn(llama.ctx, False)
        try:
            seq_id_0 = (ctypes.c_int32 * 1)()
            # One extra (NULL) slot; each token points at the same seq-id array.
            seq_ids = (ctypes.POINTER(ctypes.c_int32) * (n_tokens + 1))()
            for i in range(n_tokens):
                seq_ids[i] = seq_id_0

            batch = llama_cpp.llama_batch()
            batch.n_tokens = n_tokens
            batch.token = None  # embeddings are supplied instead of token ids
            batch.embd = embed
            batch.pos = (ctypes.c_int32 * n_tokens)(*[i + llama.n_tokens for i in range(n_tokens)])
            batch.seq_id = seq_ids
            batch.n_seq_id = (ctypes.c_int32 * n_tokens)(*([1] * n_tokens))
            batch.logits = (ctypes.c_int8 * n_tokens)()  # no logits requested

            if llama_cpp.llama_decode(llama.ctx, batch):
                raise ValueError("Failed to decode image.")
        finally:
            # Restore causal attention even if decoding fails; otherwise all
            # subsequent text generation would run non-causally.
            llama_cpp.llama_set_causal_attn(llama.ctx, True)
        # Required to avoid issues with hf tokenizer
        llama.input_ids[llama.n_tokens : llama.n_tokens + n_tokens] = -1
        llama.n_tokens += n_tokens
3510+
3511+
33763512
@register_chat_completion_handler("chatml-function-calling")
33773513
def chatml_function_calling(
33783514
llama: llama.Llama,

llama_cpp/llava_cpp.py

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
c_int,
88
c_uint8,
99
c_float,
10+
c_size_t,
1011
c_void_p,
1112
POINTER,
1213
_Pointer, # type: ignore
@@ -141,6 +142,28 @@ def llava_eval_image_embed(
141142
################################################
142143

143144

145+
# struct clip_image_u8_batch {
#     struct clip_image_u8 * data;
#     size_t size;
# };
class clip_image_u8_batch(Structure):
    # ctypes mirror of clip.h's clip_image_u8_batch: a native array of 8-bit
    # images plus its element count. `data` is kept as an opaque pointer;
    # only the native clip_* functions ever dereference it.
    _fields_ = [
        ("data", c_void_p),
        ("size", c_size_t),
    ]
154+
155+
156+
# struct clip_image_f32_batch {
#     struct clip_image_f32 * data;
#     size_t size;
# };
class clip_image_f32_batch(Structure):
    # ctypes mirror of clip.h's clip_image_f32_batch: a native array of
    # float32 (preprocessed) images plus its element count. `data` stays
    # opaque; it is allocated and freed by the native library
    # (see clip_image_f32_batch_free).
    _fields_ = [
        ("data", c_void_p),
        ("size", c_size_t),
    ]
165+
166+
144167
# /** load mmproj model */
145168
# CLIP_API struct clip_ctx * clip_model_load (const char * fname, int verbosity);
146169
@ctypes_function("clip_model_load", [c_char_p, c_int], clip_ctx_p_ctypes)
@@ -156,3 +179,93 @@ def clip_model_load(
156179
def clip_free(ctx: clip_ctx_p, /):
157180
...
158181

182+
183+
# CLIP_API struct clip_image_u8 * clip_image_u8_init ();
@ctypes_function("clip_image_u8_init", [], c_void_p)
def clip_image_u8_init() -> Optional[c_void_p]:
    """Allocate an empty native clip_image_u8.

    Returns an opaque pointer (possibly None/NULL on allocation failure —
    callers should check). Must be released with clip_image_u8_free.
    """
    ...
187+
188+
189+
# CLIP_API void clip_image_u8_free (struct clip_image_u8 * img);
@ctypes_function("clip_image_u8_free", [c_void_p], None)
def clip_image_u8_free(img: c_void_p, /):
    """Free a clip_image_u8 previously allocated by clip_image_u8_init."""
    ...
193+
194+
195+
# CLIP_API void clip_image_f32_free(struct clip_image_f32 * img);
@ctypes_function("clip_image_f32_free", [c_void_p], None)
def clip_image_f32_free(img: c_void_p, /):
    """Free a single native clip_image_f32."""
    ...
199+
200+
201+
# CLIP_API void clip_image_u8_batch_free (struct clip_image_u8_batch * batch);
@ctypes_function("clip_image_u8_batch_free", [POINTER(clip_image_u8_batch)], None)
def clip_image_u8_batch_free(batch: "_Pointer[clip_image_u8_batch]", /):
    """Free the images held inside a clip_image_u8_batch."""
    ...
205+
206+
207+
# CLIP_API void clip_image_f32_batch_free(struct clip_image_f32_batch * batch);
@ctypes_function("clip_image_f32_batch_free", [POINTER(clip_image_f32_batch)], None)
def clip_image_f32_batch_free(batch: "_Pointer[clip_image_f32_batch]", /):
    """Free the images held inside a clip_image_f32_batch (e.g. one filled
    by clip_image_preprocess)."""
    ...
211+
212+
213+
# /** preprocess img and store the result in res_imgs, pad_to_square may be overridden to false depending on model configuration */
# CLIP_API bool clip_image_preprocess(struct clip_ctx * ctx, const struct clip_image_u8 * img, struct clip_image_f32_batch * res_imgs );
@ctypes_function(
    "clip_image_preprocess",
    [
        clip_ctx_p_ctypes,
        c_void_p,
        POINTER(clip_image_f32_batch),
    ],
    c_bool,
)
def clip_image_preprocess(
    ctx: clip_ctx_p,
    img: c_void_p,
    res_imgs: "_Pointer[clip_image_f32_batch]",
    /,
) -> bool:
    """Preprocess a raw u8 image into model-ready f32 image(s).

    Fills `res_imgs` (caller must later free it with
    clip_image_f32_batch_free). Returns True on success.
    """
    ...
231+
232+
233+
# CLIP_API bool clip_image_batch_encode(struct clip_ctx * ctx, int n_threads, const struct clip_image_f32_batch * imgs, float * vec);
@ctypes_function(
    "clip_image_batch_encode",
    [
        clip_ctx_p_ctypes,
        c_int,
        POINTER(clip_image_f32_batch),
        POINTER(c_float),
    ],
    c_bool,
)
def clip_image_batch_encode(
    ctx: clip_ctx_p,
    n_threads: c_int,
    imgs: "_Pointer[clip_image_f32_batch]",
    vec: c_void_p,
    /,
) -> bool:
    """Encode a batch of preprocessed images into embeddings.

    `vec` must point to a caller-allocated float buffer large enough for the
    model's output (n_image_tokens * n_embd floats). Returns True on success.
    """
    ...
252+
253+
254+
# /** interpret bytes as an image file with length bytes_length, and use the result to populate img */
# CLIP_API bool clip_image_load_from_bytes(const unsigned char * bytes, size_t bytes_length, struct clip_image_u8 * img);
@ctypes_function(
    "clip_image_load_from_bytes",
    [
        c_void_p,
        c_size_t,
        c_void_p,
    ],
    c_bool,
)
def clip_image_load_from_bytes(
    bytes: c_void_p,
    bytes_length: c_size_t,
    img: c_void_p,
    /,
) -> bool:
    """Decode an in-memory image file (e.g. JPEG/PNG bytes) into `img`.

    `img` must be a clip_image_u8 from clip_image_u8_init. Returns True on
    success; on failure the caller is still responsible for freeing `img`.
    """
    ...

0 commit comments

Comments (0)