8000 mtmd : Expose helper_decode_image_chunk (#13366) · ggml-org/llama.cpp@f05a6d7 · GitHub
[go: up one dir, main page]

Skip to content

Commit f05a6d7

Browse files
authored
mtmd : Expose helper_decode_image_chunk (#13366)
* mtmd: Expose helper_decode_image, output_embd_copy, image_tokens_copy/free * Slim down * Cleanups
1 parent ee01d71 commit f05a6d7

File tree

2 files changed

+90
-51
lines changed

2 files changed

+90
-51
lines changed

tools/mtmd/mtmd.cpp

Lines changed: 78 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -580,6 +580,79 @@ struct decode_embd_batch {
580580
}
581581
};
582582

583+
// Helper function for decoding an image whose embeddings have already been calculated
584+
int32_t mtmd_helper_decode_image_chunk(
585+
mtmd_context * ctx,
586+
struct llama_context * lctx,
587+
const mtmd_input_chunk * chunk,
588+
float * encoded_embd,
589+
llama_pos n_past,
590+
llama_seq_id seq_id,
591+
int32_t n_batch,
592+
llama_pos * new_n_past) {
593+
if (mtmd_input_chunk_get_type(chunk) != MTMD_INPUT_CHUNK_TYPE_IMAGE) {
594+
LOG_ERR("failed to decode image chunk: input chunk not of image type\n");
595+
return -1;
596+
}
597+
const auto image_tokens = mtmd_input_chunk_get_tokens_image(chunk);
598+
if (!image_tokens) {
599+
LOG_ERR("failed to decode image chunk: image tokens are null\n");
600+
return -1;
601+
}
602+
603+
int n_mmproj_embd = clip_n_mmproj_embd(ctx->ctx_clip);
604+
int n_pos_per_embd = mtmd_decode_use_mrope(ctx) ? 4 : 1;
605+
606+
int32_t n_tokens = mtmd_image_tokens_get_n_tokens(image_tokens);
607+
int32_t i_batch = 0;
608+
int32_t n_img_batches = GGML_PAD(n_tokens, n_batch) / n_batch;
609+
decode_embd_batch batch_embd(encoded_embd, n_tokens, n_pos_per_embd, n_mmproj_embd);
610+
611+
const int nx = mtmd_image_tokens_get_nx(image_tokens);
612+
const int ny = mtmd_image_tokens_get_ny(image_tokens);
613+
614+
if (mtmd_decode_use_mrope(ctx)) {
615+
batch_embd.set_position_mrope(n_past, nx, ny, seq_id);
616+
} else {
617+
batch_embd.set_position_normal(n_past, seq_id);
618+
}
619+
620+
if (mtmd_decode_use_non_causal(ctx)) {
621+
llama_set_causal_attn(lctx, false);
622+
// TODO @ngxson : need to make sure only one image is processed at a time, and n_ubatch must be enough to hold the image
623+
}
624+
625+
while (i_batch < n_img_batches) { // split into batches
626+
int pos_offset = i_batch*n_batch;
627+
int n_tokens_batch = std::min(n_batch, n_tokens - pos_offset);
628+
llama_batch batch_embd_view = batch_embd.get_view(pos_offset, n_tokens_batch);
629+
630+
LOG_INF("decoding image batch %d/%d, n_tokens_batch = %d\n", i_batch+1, n_img_batches, n_tokens_batch);
631+
632+
int64_t t1 = ggml_time_ms();
633+
int32_t ret = llama_decode(lctx, batch_embd_view);
634+
if (ret != 0) {
635+
LOG_ERR("failed to decode image\n");
636+
llama_set_causal_attn(lctx, true); // restore causal attn
637+
return ret;
638+
}
639+
640+
if (ctx->print_timings) {
641+
LOG_INF("image decoded (batch %d/%d) in %" PRId64 " ms\n", i_batch+1, n_img_batches, ggml_time_ms() - t1);
642+
}
643+
644+
i_batch++;
645+
}
646+
647+
n_past += mtmd_image_tokens_get_n_pos(image_tokens);
648+
*new_n_past = n_past;
649+
650+
if (mtmd_decode_use_non_causal(ctx)) {
651+
llama_set_causal_attn(lctx, true);
652+
}
653+
return 0;
654+
}
655+
583656
int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx,
584657
struct llama_context * lctx,
585658
const mtmd_input_chunk * chunk,
@@ -591,8 +664,6 @@ int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx,
591664
int32_t ret;
592665
llama_batch text_batch = llama_batch_init(n_batch, 0, 1);
593666
auto chunk_type = mtmd_input_chunk_get_type(chunk);
594-
int n_mmproj_embd = clip_n_mmproj_embd(ctx->ctx_clip);
595-
int n_pos_per_embd = mtmd_decode_use_mrope(ctx) ? 4 : 1;
596667

597668
if (chunk_type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
598669
size_t n_tokens;
@@ -637,57 +708,13 @@ int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx,
637708
if (ctx->print_timings) {
638709
LOG_INF("image/slice encoded in %" PRId64 " ms\n", ggml_time_ms() - t0);
639710
}
640-
641-
int32_t n_tokens = mtmd_image_tokens_get_n_tokens(image_tokens);
642-
int32_t i_batch = 0;
643-
int32_t n_img_batches = GGML_PAD(n_tokens, n_batch) / n_batch;
644711
float * embd = mtmd_get_output_embd(ctx);
645-
decode_embd_batch batch_embd(embd, n_tokens, n_pos_per_embd, n_mmproj_embd);
646-
647-
const int nx = mtmd_image_tokens_get_nx(image_tokens);
648-
const int ny = mtmd_image_tokens_get_ny(image_tokens);
649-
650-
if (mtmd_decode_use_mrope(ctx)) {
651-
batch_embd.set_position_mrope(n_past, nx, ny, seq_id);
652-
} else {
653-
batch_embd.set_position_normal(n_past, seq_id);
654-
}
655-
656-
if (mtmd_decode_use_non_causal(ctx)) {
657-
llama_set_causal_attn(lctx, false);
658-
// TODO @ngxson : need to make sure only one image is processed at a time, and n_ubatch must be enough to hold the image
659-
}
660-
661-
while (i_batch < n_img_batches) { // split into batches
662-
int pos_offset = i_batch*n_batch;
663-
int n_tokens_batch = std::min(n_batch, n_tokens - pos_offset);
664-
llama_batch batch_embd_view = batch_embd.get_view(pos_offset, n_tokens_batch);
665-
666-
LOG_INF("decoding image batch %d/%d, n_tokens_batch = %d\n", i_batch+1, n_img_batches, n_tokens_batch);
667-
668-
int64_t t1 = ggml_time_ms();
669-
ret = llama_decode(lctx, batch_embd_view);
670-
if (ret != 0) {
671-
LOG_ERR("failed to decode image\n");
672-
llama_set_causal_attn(lctx, true); // restore causal attn
673-
llama_batch_free(text_batch);
674-
return ret;
675-
}
676-
677-
if (ctx->print_timings) {
678-
LOG_INF("image decoded (batch %d/%d) in %" PRId64 " ms\n", i_batch+1, n_img_batches, ggml_time_ms() - t1);
679-
}
680-
681-
i_batch++;
682-
}
683-
684-
n_past += mtmd_image_tokens_get_n_pos(image_tokens);
685-
*new_n_past = n_past;
686-
687-
if (mtmd_decode_use_non_causal(ctx)) {
688-
llama_set_causal_attn(lctx, true);
712+
ret = mtmd_helper_decode_image_chunk(ctx, lctx, chunk, embd, n_past, seq_id, n_batch, new_n_past);
713+
if (ret != 0) {
714+
LOG_ERR("failed to decode image\n");
715+
llama_batch_free(text_batch);
716+
return ret;
689717
}
690-
691718
} else {
692719
GGML_ABORT("chunk type not supported");
693720
}

tools/mtmd/mtmd.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,18 @@ MTMD_API int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx,
231231
bool logits_last,
232232
llama_pos * new_n_past);
233233

234+
// helper function to decode an image whose embeddings have already been calculated
235+
// this helper will handle batching and pre/post decoding setup (for ex. gemma 3 requires non-causal attention)
236+
// ret 0 on success, -1 on chunk not being a valid image chunk, 1 on decode failure
237+
MTMD_API int32_t mtmd_helper_decode_image_chunk(mtmd_context * ctx,
238+
struct llama_context * lctx,
239+
const mtmd_input_chunk * chunk,
240+
float * encoded_embd,
241+
llama_pos n_past,
242+
llama_seq_id seq_id,
243+
int32_t n_batch,
244+
llama_pos * new_n_past);
245+
234246
/////////////////////////////////////////
235247

236248
// test function, to be used in test-mtmd-c-api.c

0 commit comments

Comments
 (0)
0