@@ -2835,24 +2835,7 @@ def __call__(
                     )
                 llama.eval(tokens)
             else:
-                image_bytes = self.load_image(value)
-                embed = self._embed_image_bytes(image_bytes, llama.context_params.n_threads_batch)
-                if llama.n_tokens + embed.contents.n_image_pos > llama.n_ctx():
-                    raise ValueError(
-                        f"Prompt exceeds n_ctx: {llama.n_tokens + embed.contents.n_image_pos} > {llama.n_ctx()}"
-                    )
-                n_past = ctypes.c_int(llama.n_tokens)
-                n_past_p = ctypes.pointer(n_past)
-                with suppress_stdout_stderr(disable=self.verbose):
-                    self._llava_cpp.llava_eval_image_embed(
-                        llama.ctx,
-                        embed,
-                        llama.n_batch,
-                        n_past_p,
-                    )
-                # Required to avoid issues with hf tokenizer
-                llama.input_ids[llama.n_tokens : n_past.value] = -1
-                llama.n_tokens = n_past.value
+                self.eval_image(llama, value)

         # Get prompt tokens to avoid a cache miss
         prompt = llama.input_ids[: llama.n_tokens].tolist()
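
Note: the loop shown above iterates over the (type, value) pairs produced by split_text_on_image_urls, evaluating text chunks as tokens and handing each image chunk to the new eval_image hook. A minimal sketch of that dispatch (illustrative only; names taken from the surrounding code):

    # Sketch of the dispatch inside __call__ (not the full method)
    for type_, value in split_text:
        if type_ == "text":
            tokens = llama.tokenize(value.encode("utf8"), add_bos=False, special=True)
            llama.eval(tokens)
        else:
            self.eval_image(llama, value)  # per-model image evaluation, now overridable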
@@ -2938,6 +2921,26 @@ def __call__(
            )
        return _convert_completion_to_chat(completion_or_chunks, stream=stream)

+    def eval_image(self, llama: llama.Llama, image_url: str):
+        image_bytes = self.load_image(image_url)
+        embed = self._embed_image_bytes(image_bytes, llama.context_params.n_threads_batch)
+        if llama.n_tokens + embed.contents.n_image_pos > llama.n_ctx():
+            raise ValueError(
+                f"Prompt exceeds n_ctx: {llama.n_tokens + embed.contents.n_image_pos} > {llama.n_ctx()}"
+            )
+        n_past = ctypes.c_int(llama.n_tokens)
+        n_past_p = ctypes.pointer(n_past)
+        with suppress_stdout_stderr(disable=self.verbose):
+            self._llava_cpp.llava_eval_image_embed(
+                llama.ctx,
+                embed,
+                llama.n_batch,
+                n_past_p,
+            )
+        # Required to avoid issues with hf tokenizer
+        llama.input_ids[llama.n_tokens : n_past.value] = -1
+        llama.n_tokens = n_past.value
+
     @staticmethod
     def _load_image(image_url: str) -> bytes:
         # TODO: Add Pillow support for other image formats beyond (jpg, png)
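
Note: extracting eval_image out of __call__ gives each multimodal handler a single override point for model-specific image evaluation. A hypothetical subclass sketch (illustrative only, built on the Llava15ChatHandler API shown above):

    class MyVisionChatHandler(Llava15ChatHandler):  # hypothetical example class
        def eval_image(self, llama, image_url):
            # Model-specific preprocessing or encoding would go here; this
            # sketch simply defers to the default LLaVA-style embedding path.
            super().eval_image(llama, image_url)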
@@ -3373,6 +3376,139 @@ class MiniCPMv26ChatHandler(Llava15ChatHandler):
     )


+class Gemma3ChatHandler(Llava15ChatHandler):
+    # Chat Format:
+    # '<bos><start_of_turn>user\n{system_prompt}\n\n{prompt}<end_of_turn>\n<start_of_turn>model\n'
+
+    DEFAULT_SYSTEM_MESSAGE = None
+
+    CHAT_FORMAT = (
+        "{{ '<bos>' }}"
+        "{%- if messages[0]['role'] == 'system' -%}"
+        "{%- if messages[0]['content'] is string -%}"
+        "{%- set first_user_prefix = messages[0]['content'] + '\n\n' -%}"
+        "{%- else -%}"
+        "{%- set first_user_prefix = messages[0]['content'][0]['text'] + '\n\n' -%}"
+        "{%- endif -%}"
+        "{%- set loop_messages = messages[1:] -%}"
+        "{%- else -%}"
+        "{%- set first_user_prefix = \"\" -%}"
+        "{%- set loop_messages = messages -%}"
+        "{%- endif -%}"
+        "{%- for message in loop_messages -%}"
+        "{%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%}"
+        "{{ raise_exception(\"Conversation roles must alternate user/assistant/user/assistant/...\") }}"
+        "{%- endif -%}"
+        "{%- if (message['role'] == 'assistant') -%}"
+        "{%- set role = \"model\" -%}"
+        "{%- else -%}"
+        "{%- set role = message['role'] -%}"
+        "{%- endif -%}"
+        "{{ '<start_of_turn>' + role + '\n' + (first_user_prefix if loop.first else \"\") }}"
+        "{%- if message['content'] is string -%}"
+        "{{ message['content'] | trim }}"
+        "{%- elif message['content'] is iterable -%}"
+        "{%- for item in message['content'] -%}"
+        "{%- if item['type'] == 'image_url' -%}"
+        "{{ '<start_of_image>' }}"
+        "{%- elif item['type'] == 'text' -%}"
+        "{{ item['text'] | trim }}"
+        "{%- endif -%}"
+        "{%- endfor -%}"
+        "{%- else -%}"
+        "{{ raise_exception(\"Invalid content type\") }}"
+        "{%- endif -%}"
+        "{{ '<end_of_turn>\n' }}"
+        "{%- endfor -%}"
+        "{%- if add_generation_prompt -%}"
+        "{{ '<start_of_turn>model\n' }}"
+        "{%- endif -%}"
+    )
+
+    @staticmethod
+    def split_text_on_image_urls(text: str, image_urls: List[str]):
+        split_text: List[Tuple[Literal["text", "image_url"], str]] = []
+        copied_urls = image_urls[:]
+        remaining = text
+        image_placeholder = "<start_of_image>"
+
+        while remaining:
+            # Find placeholder
+            pos = remaining.find(image_placeholder)
+            if pos != -1:
+                assert len(copied_urls) > 0
+                if pos > 0:
+                    split_text.append(("text", remaining[:pos]))
+                split_text.append(("text", "\n\n<start_of_image>"))
+                split_text.append(("image_url", copied_urls.pop(0)))
+                split_text.append(("text", "<end_of_image>\n\n"))
+                remaining = remaining[pos + len(image_placeholder):]
+            else:
+                assert len(copied_urls) == 0
+                split_text.append(("text", remaining))
+                remaining = ""
+        return split_text
+
+    def eval_image(self, llama: llama.Llama, image_url: str):
+        import llama_cpp
+
+        n_tokens = 256
+        if llama.n_tokens + n_tokens > llama.n_ctx():
+            raise ValueError(
+                f"Prompt exceeds n_ctx: {llama.n_tokens + n_tokens} > {llama.n_ctx()}"
+            )
+
+        img_bytes = self.load_image(image_url)
+        img_u8_p = self._llava_cpp.clip_image_u8_init()
+        if not self._llava_cpp.clip_image_load_from_bytes(
+            ctypes.create_string_buffer(img_bytes, len(img_bytes)),
+            ctypes.c_size_t(len(img_bytes)),
+            img_u8_p,
+        ):
+            self._llava_cpp.clip_image_u8_free(img_u8_p)
+            raise ValueError("Failed to load image.")
+
+        img_f32 = self._llava_cpp.clip_image_f32_batch()
+        img_f32_p = ctypes.byref(img_f32)
+        if not self._llava_cpp.clip_image_preprocess(self.clip_ctx, img_u8_p, img_f32_p):
+            self._llava_cpp.clip_image_f32_batch_free(img_f32_p)
+            self._llava_cpp.clip_image_u8_free(img_u8_p)
+            raise ValueError("Failed to preprocess image.")
+
+        n_embd = llama_cpp.llama_model_n_embd(llama._model.model)
+        embed = (ctypes.c_float * (n_tokens * n_embd))()
+        if not self._llava_cpp.clip_image_batch_encode(self.clip_ctx, llama.n_threads, img_f32_p, embed):
+            self._llava_cpp.clip_image_f32_batch_free(img_f32_p)
+            self._llava_cpp.clip_image_u8_free(img_u8_p)
+            raise ValueError("Failed to encode image.")
+
+        self._llava_cpp.clip_image_f32_batch_free(img_f32_p)
+        self._llava_cpp.clip_image_u8_free(img_u8_p)
+        llama_cpp.llama_set_causal_attn(llama.ctx, False)
+
+        seq_id_0 = (ctypes.c_int32 * 1)()
+        seq_ids = (ctypes.POINTER(ctypes.c_int32) * (n_tokens + 1))()
+        for i in range(n_tokens):
+            seq_ids[i] = seq_id_0
+
+        batch = llama_cpp.llama_batch()
+        batch.n_tokens = n_tokens
+        batch.token = None
+        batch.embd = embed
+        batch.pos = (ctypes.c_int32 * n_tokens)(*[i + llama.n_tokens for i in range(n_tokens)])
+        batch.seq_id = seq_ids
+        batch.n_seq_id = (ctypes.c_int32 * n_tokens)(*([1] * n_tokens))
+        batch.logits = (ctypes.c_int8 * n_tokens)()
+
+        if llama_cpp.llama_decode(llama.ctx, batch):
+            raise ValueError("Failed to decode image.")
+
+        llama_cpp.llama_set_causal_attn(llama.ctx, True)
+        # Required to avoid issues with hf tokenizer
+        llama.input_ids[llama.n_tokens : llama.n_tokens + n_tokens] = -1
+        llama.n_tokens += n_tokens
+
+
 @register_chat_completion_handler("chatml-function-calling")
 def chatml_function_calling(
     llama: llama.Llama,
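
Note: a usage sketch for the new Gemma3ChatHandler (illustrative only; the model and mmproj file names are placeholders, and the constructor follows the clip_model_path argument that Llava15ChatHandler already exposes):

    from llama_cpp import Llama
    from llama_cpp.llama_chat_format import Gemma3ChatHandler

    chat_handler = Gemma3ChatHandler(clip_model_path="mmproj-gemma3.gguf")  # placeholder path
    llm = Llama(
        model_path="gemma-3-4b-it-q4_k_m.gguf",  # placeholder path
        chat_handler=chat_handler,
        n_ctx=4096,  # leave headroom: each image consumes 256 embedding positions
    )
    response = llm.create_chat_completion(
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "Describe this image."},
                    {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
                ],
            }
        ],
    )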