mtmd : add methods to access `mtmd_image_tokens` (#12906) · ggml-org/llama.cpp@b9154ec · GitHub
mtmd : add methods to access mtmd_image_tokens (#12906)
* mtmd : add more api around mtmd_image_tokens
* mtmd : ability to calc image hash
* shared_ptr for mtmd_image_tokens
* move hash to user-defined ID (fixed)
* fix prompt_modified
* rm redundant data member
1 parent 2db9ba1 commit b9154ec
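
The commit message mentions computing an image hash and storing it in the new user-defined ID field (mtmd_bitmap::id, see mtmd.h below). The hashing utility itself is not part of this diff, so the following is only a sketch of how a caller might fill the field, assuming a simple FNV-1a hash over the raw pixel data; bitmap_fnv1a_id is a hypothetical helper, not an API added by this commit:

    #include <cstdint>
    #include <string>

    // Hypothetical caller-side helper: derive mtmd_bitmap::id from the pixel data.
    // FNV-1a is just one reasonable choice; the hash used upstream is not shown here.
    static std::string bitmap_fnv1a_id(const mtmd_bitmap & bmp) {
        uint64_t h = 0xcbf29ce484222325ull;        // FNV-1a offset basis
        for (unsigned char byte : bmp.data) {
            h = (h ^ byte) * 0x100000001b3ull;     // FNV-1a prime
        }
        return std::to_string(h);
    }

    // usage: bitmap.id = bitmap_fnv1a_id(bitmap); before passing it to mtmd_tokenize()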

File tree: 3 files changed, +92 −44 lines


examples/llava/gemma3-cli.cpp

Lines changed: 6 additions & 5 deletions
@@ -184,18 +184,19 @@ static int eval_message(gemma3_context & ctx, common_chat_msg & msg, std::vector
     text.text = formatted_chat.prompt;
     text.add_special = add_bos;
     text.parse_special = true;
-    mtmd_input_chunks_ptr chunks(mtmd_tokenize(ctx.ctx_vision.get(), text, bitmaps));
-    if (chunks == nullptr) {
-        LOG_ERR("Unable to tokenize prompt\n");
+    mtmd_input_chunks chunks;
+    int32_t res = mtmd_tokenize(ctx.ctx_vision.get(), chunks, text, bitmaps);
+    if (res != 0) {
+        LOG_ERR("Unable to tokenize prompt, res = %d\n", res);
         return 1;
     }
 
-    if (mtmd_helper_eval(ctx.ctx_vision.get(), ctx.lctx, chunks.get(), ctx.n_past, 0, ctx.n_batch)) {
+    if (mtmd_helper_eval(ctx.ctx_vision.get(), ctx.lctx, chunks, ctx.n_past, 0, ctx.n_batch)) {
         LOG_ERR("Unable to eval prompt\n");
         return 1;
     }
 
-    ctx.n_past += mtmd_helper_get_n_tokens(chunks.get());
+    ctx.n_past += mtmd_helper_get_n_tokens(chunks);
 
     return 0;
 }
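
Note the ownership change in this diff: chunks is now a plain std::vector owned by the caller, and each image chunk holds its tokens through mtmd_image_tokens_ptr, so the old mtmd_input_chunks_free() call is gone. A minimal lifetime sketch, reusing ctx, text and bitmaps from the example above:

    {
        mtmd_input_chunks chunks; // std::vector<mtmd_input_chunk>, owned by the caller
        if (mtmd_tokenize(ctx.ctx_vision.get(), chunks, text, bitmaps) != 0) {
            return 1; // 1 = image/marker count mismatch, 2 = preprocessing error
        }
        size_t n_tokens = mtmd_helper_get_n_tokens(chunks); // for n_past bookkeeping
        // ... evaluate chunks ...
    } // chunks leaves scope; mtmd_image_tokens_deleter frees each image's tokens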

examples/llava/mtmd.cpp

Lines changed: 60 additions & 28 deletions
@@ -16,6 +16,7 @@ struct mtmd_context {
     struct clip_ctx * ctx_clip;
     const struct llama_model * text_model;
     std::vector<float> image_embd_v; // image embedding vector
+
     bool print_timings;
     int n_threads;
     std::string image_marker;
@@ -24,7 +25,11 @@ struct mtmd_context {
 
     mtmd_context(const char * mmproj_fname,
                  const llama_model * text_model,
-                 const mtmd_context_params & ctx_params) : print_timings(ctx_params.print_timings), n_threads(ctx_params.n_threads), image_marker(ctx_params.image_marker) {
+                 const mtmd_context_params & ctx_params) :
+        print_timings(ctx_params.print_timings),
+        n_threads    (ctx_params.n_threads),
+        image_marker (ctx_params.image_marker)
+    {
         clip_context_params ctx_clip_params;
         ctx_clip_params.use_gpu = ctx_params.use_gpu;
         ctx_clip_params.verbosity = ctx_params.verbosity;
@@ -49,6 +54,7 @@ struct mtmd_image_tokens {
     uint32_t ny; // number of tokens in y direction
     uint32_t n_tokens() const { return nx * ny; }
     clip_image_f32_batch batch_f32; // preprocessed image patches
+    std::string id; // optional user-defined ID, useful for KV cache tracking
 };
 
 mtmd_context * mtmd_init_from_file(const char * mmproj_fname,
@@ -88,10 +94,10 @@ static std::vector<llama_token> mtmd_tokenize_text_internal(
     return result;
 }
 
-mtmd_input_chunks * mtmd_tokenize(mtmd_context * ctx,
-                                  const mtmd_input_text & text,
-                                  const std::vector<mtmd_bitmap> & bitmaps) {
-    mtmd_input_chunks * output = new mtmd_input_chunks;
+int32_t mtmd_tokenize(mtmd_context * ctx,
+                      std::vector<mtmd_input_chunk> & output,
+                      const mtmd_input_text & text,
+                      const std::vector<mtmd_bitmap> & bitmaps) {
     auto vocab = llama_model_get_vocab(ctx->text_model);
 
     std::string prompt_modified(text.text);
@@ -105,9 +111,9 @@ mtmd_input_chunks * mtmd_tokenize(mtmd_context * ctx,
         string_replace_all(prompt_modified, ctx->image_marker, marker_modified);
     }
 
-    std::vector<std::string> parts = string_split_str(text.text, ctx->image_marker);
-    output->clear();
-    output->reserve(parts.size());
+    std::vector<std::string> parts = string_split_str(prompt_modified, ctx->image_marker);
+    output.clear();
+    output.reserve(parts.size());
 
     size_t i_img = 0;
 
@@ -123,14 +129,14 @@ mtmd_input_chunks * mtmd_tokenize(mtmd_context * ctx,
             std::move(tokens),
             {},
         };
-        output->emplace_back(std::move(chunk));
+        output.emplace_back(std::move(chunk));
 
         if (&parts.back() != &part) {
            // add image token to middle of 2 parts
 
            if (i_img >= bitmaps.size()) {
                LOG_ERR("%s: error: not enough images for %d parts\n", __func__, (int)parts.size());
-                return nullptr;
+                return 1;
            }
 
            // shim layer
@@ -145,34 +151,48 @@ mtmd_input_chunks * mtmd_tokenize(mtmd_context * ctx,
             bool ok = clip_image_preprocess(ctx->ctx_clip, img_u8.get(), &batch_f32);
             if (!ok) {
                 LOG_ERR("Unable to preprocess image\n");
-                return nullptr;
+                return 2;
             }
 
-            mtmd_image_tokens * image_tokens = new mtmd_image_tokens;
+            mtmd_image_tokens_ptr image_tokens(new mtmd_image_tokens);
             image_tokens->nx = clip_n_patches(ctx->ctx_clip); // TODO @ngxson : use clip_n_patches_by_image
             image_tokens->ny = 1; // TODO
             image_tokens->batch_f32 = std::move(batch_f32);
+            image_tokens->id = bitmaps[i_img].id; // optional
 
             mtmd_input_chunk chunk{
                 MTMD_INPUT_CHUNK_TYPE_IMAGE,
                 {},
-                image_tokens,
+                std::move(image_tokens),
             };
-            output->emplace_back(std::move(chunk));
+            output.emplace_back(std::move(chunk));
             i_img++;
         }
     }
 
-    return output;
+    return 0;
 }
 
-void mtmd_input_chunks_free(mtmd_input_chunks * chunks) {
-    for (auto & chunk : *chunks) {
-        if (chunk.type == MTMD_INPUT_CHUNK_TYPE_IMAGE && chunk.tokens_image) {
-            delete chunk.tokens_image;
-        }
+void mtmd_image_tokens_free(mtmd_image_tokens * image_tokens) {
+    if (image_tokens) {
+        delete image_tokens;
     }
-    delete chunks;
+}
+
+size_t mtmd_image_tokens_get_n_tokens(const mtmd_image_tokens * image_tokens) {
+    return image_tokens->n_tokens();
+}
+
+size_t mtmd_image_tokens_get_nx(const mtmd_image_tokens * image_tokens) {
+    return image_tokens->nx;
+}
+
+size_t mtmd_image_tokens_get_ny(const mtmd_image_tokens * image_tokens) {
+    return image_tokens->ny;
+}
+
+std::string mtmd_image_tokens_get_id(const mtmd_image_tokens * image_tokens) {
+    return image_tokens->id;
 }
 
 int32_t mtmd_encode(mtmd_context * ctx, const mtmd_image_tokens * image_tokens) {
@@ -190,9 +210,9 @@ float * mtmd_get_output_embd(mtmd_context * ctx) {
     return ctx->image_embd_v.data();
 }
 
-size_t mtmd_helper_get_n_tokens(mtmd_input_chunks * chunks) {
+size_t mtmd_helper_get_n_tokens(mtmd_input_chunks & chunks) {
     size_t n_tokens = 0;
-    for (auto & chunk : *chunks) {
+    for (auto & chunk : chunks) {
         if (chunk.type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
             n_tokens += chunk.tokens_text.size();
         } else if (chunk.type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
@@ -241,16 +261,16 @@ struct decode_embd_batch {
 
 int32_t mtmd_helper_eval(mtmd_context * ctx,
                          llama_context * lctx,
-                         mtmd_input_chunks * chunks,
+                         mtmd_input_chunks & chunks,
                          llama_pos pos0,
                          llama_seq_id seq_id,
                          int32_t n_batch) {
     int32_t ret;
     llama_pos n_past = pos0;
     llama_batch text_batch = llama_batch_init(n_batch, 0, 1);
 
-    for (auto & chunk : *chunks) {
-        bool is_last = &chunk == &chunks->back();
+    for (auto & chunk : chunks) {
+        bool is_last = &chunk == &chunks.back();
         if (chunk.type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
             // TODO @ngxson : may need to split into smaller batches
             text_batch.n_tokens = chunk.tokens_text.size();
@@ -279,7 +299,7 @@ int32_t mtmd_helper_eval(mtmd_context * ctx,
             if (ctx->print_timings) {
                 LOG_INF("encoding image...\n");
             }
-            ret = mtmd_encode(ctx, chunk.tokens_image);
+            ret = mtmd_encode(ctx, chunk.tokens_image.get());
             if (ret != 0) {
                 LOG_ERR("failed to encode image\n");
                 llama_batch_free(text_batch);
@@ -289,7 +309,7 @@ int32_t mtmd_helper_eval(mtmd_context * ctx,
                 LOG_INF("image encoded in %" PRId64 " ms\n", ggml_time_ms() - t0);
             }
 
-            int32_t n_tokens = chunk.tokens_image->n_tokens();
+            int32_t n_tokens = mtmd_image_tokens_get_n_tokens(chunk.tokens_image.get());
             float * embd = mtmd_get_output_embd(ctx);
             decode_embd_batch batch_img(embd, n_tokens, n_past, 0);
             int64_t t1 = ggml_time_ms();
@@ -339,3 +359,15 @@ int32_t mtmd_helper_bitmap_init_from_file(const char * fname, mtmd_bitmap & outp
     std::memcpy(output.data.data(), data, output.nx * output.ny * 3);
     return 0;
 }
+
+bool mtmd_decode_use_non_causal(mtmd_context * ctx) {
+    projector_type proj_type = clip_get_projector_type(ctx->ctx_clip);
+    if (proj_type == PROJECTOR_TYPE_GEMMA3) {
+        return true;
+    }
+    return false;
+}
+
+void mtmd_image_tokens_deleter::operator()(mtmd_image_tokens * val) {
+    mtmd_image_tokens_free(val);
+}
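
For callers that bypass mtmd_helper_eval(), the new accessors make a hand-rolled image loop possible; this mirrors what the helper does internally in the diff above. A sketch only: decode_embeddings() stands in for caller-provided batching and decoding (the helper uses its internal decode_embd_batch for this):

    // Sketch of a manual image pass over tokenized chunks (ctx, lctx, chunks as above).
    for (auto & chunk : chunks) {
        if (chunk.type != MTMD_INPUT_CHUNK_TYPE_IMAGE) {
            continue; // text chunks go through llama_decode() as usual
        }
        if (mtmd_encode(ctx, chunk.tokens_image.get()) != 0) {
            return 1; // image encoding failed
        }
        size_t  n_tokens = mtmd_image_tokens_get_n_tokens(chunk.tokens_image.get());
        float * embd     = mtmd_get_output_embd(ctx);
        // gemma3-style projectors need a non-causal mask during decode
        bool non_causal = mtmd_decode_use_non_causal(ctx);
        decode_embeddings(lctx, embd, n_tokens, non_causal); // hypothetical caller code
    }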

examples/llava/mtmd.h

Lines changed: 26 additions & 11 deletions
@@ -39,12 +39,18 @@ struct mtmd_bitmap {
     uint32_t nx;
     uint32_t ny;
     std::vector<unsigned char> data;
+    std::string id; // optional user-defined id, for ex: can be set to image hash, useful for KV cache tracking
 };
 
+struct mtmd_image_tokens_deleter {
+    void operator()(mtmd_image_tokens * val); // forward declaration
+};
+using mtmd_image_tokens_ptr = std::unique_ptr<mtmd_image_tokens, mtmd_image_tokens_deleter>;
+
 struct mtmd_input_chunk {
     mtmd_input_chunk_type type;
     std::vector<llama_token> tokens_text;
-    mtmd_image_tokens * tokens_image = nullptr;
+    mtmd_image_tokens_ptr tokens_image;
 };
 
 using mtmd_input_chunks = std::vector<mtmd_input_chunk>;
@@ -82,12 +88,21 @@ MTMD_API void mtmd_free(mtmd_context * ctx);
 // 3. "<end_of_image>\ndescribe it in detail."
 // number of bitmaps must be equal to the number of image markers in the prompt
 // this function is thread-safe (shared ctx)
-MTMD_API mtmd_input_chunks * mtmd_tokenize(mtmd_context * ctx,
+// return values:
+//   0 on success
+//   1 on number of images not matching the number of markers
+//   2 on image preprocessing error
+MTMD_API int32_t mtmd_tokenize(mtmd_context * ctx,
+                               std::vector<mtmd_input_chunk> & output,
                                const mtmd_input_text & text,
                                const std::vector<mtmd_bitmap> & bitmaps);
 
-// free image chunk data
-MTMD_API void mtmd_input_chunks_free(mtmd_input_chunks * chunks);
+// access mtmd_image_tokens
+MTMD_API size_t mtmd_image_tokens_get_n_tokens(const mtmd_image_tokens * image_tokens);
+MTMD_API size_t mtmd_image_tokens_get_nx(const mtmd_image_tokens * image_tokens);
+MTMD_API size_t mtmd_image_tokens_get_ny(const mtmd_image_tokens * image_tokens);
+MTMD_API std::string mtmd_image_tokens_get_id(const mtmd_image_tokens * image_tokens);
+MTMD_API void mtmd_image_tokens_free(mtmd_image_tokens * image_tokens);
 
 // returns 0 on success
 MTMD_API int32_t mtmd_encode(mtmd_context * ctx,
@@ -96,12 +111,17 @@ MTMD_API int32_t mtmd_encode(mtmd_context * ctx,
 // get output embeddings from the last encode pass
 MTMD_API float * mtmd_get_output_embd(mtmd_context * ctx);
 
+// whether we need to set non-causal mask before llama_decode
+MTMD_API bool mtmd_decode_use_non_causal(mtmd_context * ctx);
+
+
+
 //
 // helper functions (can be implemented based on other functions)
 //
 
 // helper to count the total number of tokens from a list of chunks, useful to keep track of n_past
-MTMD_API size_t mtmd_helper_get_n_tokens(mtmd_input_chunks * chunks);
+MTMD_API size_t mtmd_helper_get_n_tokens(mtmd_input_chunks & chunks);
 
 // helper function that automatically:
 // 1. run llama_decode() on text chunks
@@ -110,7 +130,7 @@ MTMD_API size_t mtmd_helper_get_n_tokens(mtmd_input_chunks * chunks);
 // otherwise, returns 0 on success
 MTMD_API int32_t mtmd_helper_eval(mtmd_context * ctx,
                                   llama_context * lctx,
-                                  mtmd_input_chunks * chunks,
+                                  mtmd_input_chunks & chunks,
                                   llama_pos pos0,
                                   llama_seq_id seq_id,
                                   int32_t n_batch);
@@ -132,11 +152,6 @@ struct mtmd_context_deleter {
 };
 using mtmd_context_ptr = std::unique_ptr<mtmd_context, mtmd_context_deleter>;
 
-struct mtmd_input_chunks_deleter {
-    void operator()(mtmd_input_chunks * val) { mtmd_input_chunks_free(val); }
-};
-using mtmd_input_chunks_ptr = std::unique_ptr<mtmd_input_chunks, mtmd_input_chunks_deleter>;
-
 #else
 
 static_assert(false && "C header is not yet supported by this library");
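
The id accessor exists for the KV cache tracking mentioned in the header comments: a caller can remember which id was already evaluated and skip re-encoding the same image. The cache map below is caller-side state, not part of this API; a sketch under that assumption:

    #include <map>
    #include <string>

    // Hypothetical caller-side index: image id -> position of its first KV cell.
    static std::map<std::string, llama_pos> g_image_pos;

    static bool image_already_cached(const mtmd_image_tokens * image_tokens, llama_pos & pos_out) {
        std::string id = mtmd_image_tokens_get_id(image_tokens);
        if (id.empty()) {
            return false; // no user-defined id: always re-encode
        }
        auto it = g_image_pos.find(id);
        if (it == g_image_pos.end()) {
            return false; // first time we see this image
        }
        pos_out = it->second; // KV cells for this image can be reused
        return true;
    }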
