@@ -2603,7 +2603,12 @@ def __call__(
 
         image_urls = self.get_image_urls(messages)
         template = jinja2.Template(self.CHAT_FORMAT)
-        text = template.render(messages=messages, add_generation_prompt=True)
+        text = template.render(
+            messages=messages,
+            add_generation_prompt=True,
+            eos_token=llama.detokenize([llama.token_eos()]),
+            bos_token=llama.detokenize([llama.token_bos()]),
+        )
         split_text = self.split_text_on_image_urls(text, image_urls)
 
         def embed_image_bytes(image_bytes: bytes):
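The render call now supplies explicit eos_token / bos_token values taken from the model's vocabulary, since Jinja2 chat templates routinely reference those variables and would otherwise render them as empty strings. A minimal standalone sketch of the idea, using a hypothetical template and placeholder token strings where the handler passes llama.detokenize(...) output:

import jinja2

# Hypothetical chat template that references bos_token/eos_token, as many formats do.
CHAT_FORMAT = (
    "{{ bos_token }}"
    "{% for message in messages %}"
    "{{ message['role'] }}: {{ message['content'] }}{{ eos_token }}\n"
    "{% endfor %}"
    "{% if add_generation_prompt %}assistant: {% endif %}"
)

text = jinja2.Template(CHAT_FORMAT).render(
    messages=[{"role": "user", "content": "What is in this image?"}],
    add_generation_prompt=True,
    bos_token="<s>",    # handler passes llama.detokenize([llama.token_bos()])
    eos_token="</s>",   # handler passes llama.detokenize([llama.token_eos()])
)
print(text)  # "<s>user: What is in this image?</s>" followed by "assistant: "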
@@ -2624,9 +2629,9 @@ def embed_image_bytes(image_bytes: bytes):
 
         # Evaluate prompt
         llama.reset()
-        for i, (type_, value) in enumerate(split_text):
+        for type_, value in split_text:
             if type_ == "text":
-                tokens = llama.tokenize(value.encode("utf8"), add_bos=i == 0)
+                tokens = llama.tokenize(value.encode("utf8"), add_bos=False, special=True)
                 if llama.n_tokens + len(tokens) > llama.n_ctx():
                     raise ValueError("Prompt exceeds n_ctx")  # TODO: Fix
                 llama.eval(tokens)
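With the special tokens now embedded in the rendered text, each text chunk is tokenized with add_bos=False and special=True, so control strings such as <|im_start|> collapse to their dedicated special-token ids instead of being split into plain-text pieces, and no extra BOS is prepended on top of the one the template already emits. A rough illustration, assuming a loaded llama_cpp.Llama instance named llama whose vocabulary defines these control tokens:

chunk = "<|im_start|>user\nWhat is the picture about?<|im_end|>"

# special=False: the control strings are tokenized as literal text (many tokens).
as_text = llama.tokenize(chunk.encode("utf8"), add_bos=False, special=False)

# special=True: each control string maps to its single special-token id.
as_special = llama.tokenize(chunk.encode("utf8"), add_bos=False, special=True)

assert len(as_special) < len(as_text)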
@@ -2644,6 +2649,8 @@ def embed_image_bytes(image_bytes: bytes):
                         llama.n_batch,
                         n_past_p,
                     )
+                # Required to avoid issues with hf tokenizer
+                llama.input_ids[llama.n_tokens : n_past.value] = -1
                 llama.n_tokens = n_past.value
 
         # Get prompt tokens to avoid a cache miss
@@ -3033,6 +3040,7 @@ class NanoLlavaChatHandler(Llava15ChatHandler):
     # Answer the question<|im_end|><|im_start|>user
     # <image>
     # What is the picture about?<|im_end|><|im_start|>assistant
+    DEFAULT_SYSTEM_MESSAGE = "Answer the question"
 
     CHAT_FORMAT = (
         "{% for message in messages %}"