convert : experimental support for `--mmproj` flag · ggml-org/llama.cpp@d5e03e6 · GitHub

Commit d5e03e6

convert : experimental support for --mmproj flag

1 parent 37b9f0d · commit d5e03e6

File tree

4 files changed: +472 -43 lines changed

convert_hf_to_gguf.py

Lines changed: 194 additions & 38 deletions
```diff
@@ -67,14 +67,20 @@ class Model:
     dir_model_card: Path
     remote_hf_model_id: str | None
 
+    # for vision encoders
+    mmproj: bool
+    ignore_vision: bool = False  # subclasses may overwrite this
+    mtmd_model: MultimodalModel | None = None
+
     # subclasses should define this!
     model_arch: gguf.MODEL_ARCH
 
     def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, is_big_endian: bool = False,
                  use_temp_file: bool = False, eager: bool = False,
                  metadata_override: Path | None = None, model_name: str | None = None,
                  split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False,
-                 small_first_shard: bool = False, hparams: dict[str, Any] | None = None, remote_hf_model_id: str | None = None):
+                 small_first_shard: bool = False, hparams: dict[str, Any] | None = None, remote_hf_model_id: str | None = None,
+                 mmproj: bool = False):
         if type(self) is Model:
             raise TypeError(f"{type(self).__name__!r} should not be directly instantiated")
```
```diff
@@ -109,6 +115,7 @@ def get_remote_tensors() -> Iterator[tuple[str, Tensor]]:
         self.metadata_override = metadata_override
         self.model_name = model_name
         self.dir_model_card = dir_model  # overridden in convert_lora_to_gguf.py
+        self.mmproj = mmproj
 
         # Apply heuristics to figure out typical tensor encoding based on first layer tensor encoding type
         if self.ftype == gguf.LlamaFileType.GUESSED:
```
```diff
@@ -125,6 +132,28 @@ def get_remote_tensors() -> Iterator[tuple[str, Tensor]]:
         self.gguf_writer = gguf.GGUFWriter(path=None, arch=gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file,
                                            split_max_tensors=split_max_tensors, split_max_size=split_max_size, dry_run=dry_run, small_first_shard=small_first_shard)
 
+        # vision encoder
+        if mmproj:
+            vision_hparams = self.hparams.get("vision_config")
+            if vision_hparams is None:
+                raise ValueError("Vision config not found in model config")
+            elif self.ignore_vision:
+                raise ValueError("Vision config found, but mmproj conversion for this model is not supported yet")
+            else:
+                self.mtmd_model = MultimodalModel(
+                    hparams=vision_hparams,
+                    ftype=self.ftype,
+                    fname_out=self.fname_out,
+                    endianess=self.endianess,
+                    use_temp_file=self.use_temp_file,
+                )
+
+    @classmethod
+    def add_prefix_to_filename(cls, path: Path, prefix: str) -> Path:
+        stem, suffix = path.stem, path.suffix
+        new_name = f"{prefix}{stem}{suffix}"
+        return path.with_name(new_name)
+
     @classmethod
     def __init_subclass__(cls):
         # can't use an abstract property, because overriding it without type errors
```
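As an aside, a minimal standalone sketch of what the new `add_prefix_to_filename()` helper does; the example path below is hypothetical, not taken from the commit:

```python
from pathlib import Path

def add_prefix_to_filename(path: Path, prefix: str) -> Path:
    # same stem/suffix logic as Model.add_prefix_to_filename() above
    stem, suffix = path.stem, path.suffix
    return path.with_name(f"{prefix}{stem}{suffix}")

out = Path("models/gemma-3-4b-it-F16.gguf")       # hypothetical output path
print(add_prefix_to_filename(out, "mmproj-"))     # models/mmproj-gemma-3-4b-it-F16.gguf
```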
```diff
@@ -272,8 +301,13 @@ def set_gguf_parameters(self):
             self.gguf_writer.add_key_length(head_dim)
             self.gguf_writer.add_value_length(head_dim)
 
-        self.gguf_writer.add_file_type(self.ftype)
-        logger.info(f"gguf: file type = {self.ftype}")
+        if not self.mmproj:
+            self.gguf_writer.add_file_type(self.ftype)
+            logger.info(f"gguf: file type = {self.ftype}")
+        else:
+            assert self.mtmd_model is not None
+            self.mtmd_model.set_gguf_parameters(n_embd_text=n_embd)
+            logger.info(f"mmproj: file type = {self.mtmd_model.ftype}")
 
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         del bid  # unused
```
```diff
@@ -311,6 +345,10 @@ def prepare_tensors(self):
                     break
 
             for new_name, data_torch in (self.modify_tensors(data_torch, name, bid)):
+                # skip adding tensor if we're working with a vision model
+                if self.mmproj:
+                    continue
+
                 # TODO: why do we squeeze here?
                 # data = data_torch.squeeze().numpy()
                 data = data_torch.numpy()
```
```diff
@@ -455,12 +493,18 @@ def prepare_metadata(self, vocab_only: bool):
         self.gguf_writer.add_quantization_version(gguf.GGML_QUANT_VERSION)
 
     def write(self):
-        self.prepare_tensors()
-        self.prepare_metadata(vocab_only=False)
-        self.gguf_writer.write_header_to_file(path=self.fname_out)
-        self.gguf_writer.write_kv_data_to_file()
-        self.gguf_writer.write_tensors_to_file(progress=True)
-        self.gguf_writer.close()
+        if self.mtmd_model is not None:
+            self.prepare_tensors()
+            self.prepare_metadata(vocab_only=False)
+            logger.info("Writing vision model")
+            self.mtmd_model.write()
+        else:
+            self.prepare_tensors()
+            self.prepare_metadata(vocab_only=False)
+            self.gguf_writer.write_header_to_file(path=self.fname_out)
+            self.gguf_writer.write_kv_data_to_file()
+            self.gguf_writer.write_tensors_to_file(progress=True)
+            self.gguf_writer.close()
 
     def write_vocab(self):
         if len(self.gguf_writer.tensors) != 1:
```
```diff
@@ -485,7 +529,10 @@ def get_model_part_names(dir_model: Path, prefix: str, suffix: str) -> list[str]:
     @staticmethod
     def load_hparams(dir_model: Path):
         with open(dir_model / "config.json", "r", encoding="utf-8") as f:
-            return json.load(f)
+            hparams = json.load(f)
+        if "text_config" in hparams:
+            hparams = {**hparams, **hparams["text_config"]}
+        return hparams
 
     @classmethod
     def register(cls, *names: str) -> Callable[[AnyModel], AnyModel]:
```
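To illustrate the new `load_hparams()` behavior, here is a small sketch with a simplified, hypothetical `config.json` for a multimodal checkpoint; the merge lifts the `text_config` keys to the root while leaving `vision_config` in place for the `--mmproj` path:

```python
# Simplified, hypothetical config.json contents for a multimodal checkpoint
hparams = {
    "architectures": ["Gemma3ForConditionalGeneration"],
    "text_config":   {"hidden_size": 2560, "num_hidden_layers": 34},
    "vision_config": {"hidden_size": 1152, "image_size": 896, "patch_size": 14},
}

# same merge as in Model.load_hparams(): text_config keys are lifted to the root
if "text_config" in hparams:
    hparams = {**hparams, **hparams["text_config"]}

print(hparams["hidden_size"])                    # 2560 (text width, now at the root)
print(hparams["vision_config"]["hidden_size"])   # 1152 (still available for --mmproj)
```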
```diff
@@ -1024,6 +1071,101 @@ def _set_vocab_builtin(self, model_name: Literal["gpt-neox", "llama-spm"], vocab_size: int):
             self.gguf_writer.add_add_eos_token(field.parts[-1].tolist()[0])
 
 
+# for converting mmproj file
+class MultimodalModel:
+    hparams: dict
+    dir_model: Path
+    ftype: gguf.LlamaFileType
+    fname_out: Path
+    tensor_map: gguf.TensorNameMap
+    gguf_writer: gguf.GGUFWriter
+
+    def __init__(self, hparams: dict, ftype: gguf.LlamaFileType, fname_out: Path, endianess: gguf.GGUFEndian, use_temp_file: bool):
+        self.hparams = hparams
+        self.ftype = ftype
+        self.fname_out = fname_out
+        self.tensor_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.CLIP_VISION, 128)
+        self.gguf_writer = gguf.GGUFWriter(path=None,
+                                           arch="clip",
+                                           endianess=endianess,
+                                           use_temp_file=use_temp_file)
+
+    def set_gguf_parameters(self, n_embd_text: int):
+        """Function to be called by Model.set_gguf_parameters()"""
+        self.gguf_writer.add_type(gguf.GGUFType.CLIP_VISION)
+        self.gguf_writer.add_file_type(self.ftype)
+        self.gguf_writer.add_uint32(gguf.Keys.ClipVision.PROJECTION_DIM, n_embd_text)
+        self.gguf_writer.add_bool(gguf.Keys.ClipVision.HAS_VISION_ENCODER, True)
+
+        # vision config
+        self.gguf_writer.add_uint32(gguf.Keys.ClipVision.IMAGE_SIZE, self.find_hparam(["image_size"]))
+        self.gguf_writer.add_uint32(gguf.Keys.ClipVision.PATCH_SIZE, self.find_hparam(["patch_size"]))
+        self.gguf_writer.add_uint32(gguf.Keys.ClipVision.EMBEDDING_LENGTH, self.find_hparam(["hidden_size"]))
+        self.gguf_writer.add_uint32(gguf.Keys.ClipVision.FEED_FORWARD_LENGTH, self.find_hparam(["intermediate_size"]))
+        self.gguf_writer.add_uint32(gguf.Keys.ClipVision.BLOCK_COUNT, self.find_hparam(["num_hidden_layers"]))
+        self.gguf_writer.add_uint32(gguf.Keys.ClipVision.Attention.HEAD_COUNT, self.find_hparam(["num_attention_heads"]))
+
+    def find_hparam(self, keys: Iterable[str], optional: bool = False) -> Any:
+        key = next((k for k in keys if k in self.hparams), None)
+        if key is not None:
+            return self.hparams[key]
+        if optional:
+            return None
+        raise KeyError(f"could not find any of: {keys}")
+
+    def get_quantization(self, mapped_name: str, data_torch: Tensor) -> gguf.GGMLQuantizationType:
+        is_1d = len(data_torch.shape) == 1
+        is_embd = "_embd" in mapped_name
+        can_quantize = not is_1d and not is_embd
+        data_qtype = gguf.GGMLQuantizationType.F32
+        if can_quantize:
+            if self.ftype == gguf.LlamaFileType.ALL_F32:
+                data_qtype = gguf.GGMLQuantizationType.F32
+            elif self.ftype == gguf.LlamaFileType.MOSTLY_F16:
+                data_qtype = gguf.GGMLQuantizationType.F16
+            elif self.ftype == gguf.LlamaFileType.MOSTLY_BF16:
+                data_qtype = gguf.GGMLQuantizationType.BF16
+            elif self.ftype == gguf.LlamaFileType.MOSTLY_Q8_0:
+                data_qtype = gguf.GGMLQuantizationType.Q8_0
+            else:
+                raise ValueError(f"Unsupported file type: {self.ftype}")
+        return data_qtype
+
+    def add_tensor(self, original_name: str, data_torch: Tensor) -> None:
+        """Function to be called inside Model.modify_tensors()"""
+        # name mapping
+        new_name = self.tensor_map.get_name(key=original_name, try_suffixes=(".weight", ".bias"))
+        if new_name is None:
+            raise ValueError(f"Can not map tensor {original_name!r}")
+
+        # process data
+        # old_dtype = data_torch.dtype
+        data_qtype = self.get_quantization(new_name, data_torch)
+        data = data_torch.numpy()
+        try:
+            data = gguf.quants.quantize(data, data_qtype)
+        except Exception as e:
+            logger.error(f"Error quantizing tensor '{new_name}': {e}, fallback to F16")
+            data_qtype = gguf.GGMLQuantizationType.F16
+            data = gguf.quants.quantize(data, data_qtype)
+
+        # reverse shape to make it similar to the internal ggml dimension order
+        # TODO: we don't print old_dtype because it's not correct, to be fixed later
+        old_dtype = ""
+        shape_str = f"{{{', '.join(str(n) for n in reversed(data_torch.shape))}}}"
+        logger.info(f"{f'%-32s' % f'{new_name},'} {old_dtype} --> {data_qtype.name}, shape = {shape_str}")
+
+        # add tensor
+        self.gguf_writer.add_tensor(new_name, data, raw_dtype=data_qtype)
+
+    def write(self):
+        """Function to be called by Model.write()"""
+        self.gguf_writer.write_header_to_file(path=self.fname_out)
+        self.gguf_writer.write_kv_data_to_file()
+        self.gguf_writer.write_tensors_to_file(progress=True)
+        self.gguf_writer.close()
+
+
 @Model.register("GPTNeoXForCausalLM")
 class GPTNeoXModel(Model):
     model_arch = gguf.MODEL_ARCH.GPTNEOX
```
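To make the type-selection rule in `MultimodalModel.get_quantization()` concrete, here is a standalone sketch of the same logic; the mapped tensor names and shapes below are hypothetical examples, and `gguf` is the gguf-py package:

```python
import torch
import gguf

def pick_qtype(ftype: gguf.LlamaFileType, mapped_name: str, t: torch.Tensor) -> gguf.GGMLQuantizationType:
    # 1D tensors (biases, norms) and embedding-like tensors always stay in F32;
    # everything else follows the requested output type, as in get_quantization()
    if t.ndim == 1 or "_embd" in mapped_name:
        return gguf.GGMLQuantizationType.F32
    return {
        gguf.LlamaFileType.ALL_F32:     gguf.GGMLQuantizationType.F32,
        gguf.LlamaFileType.MOSTLY_F16:  gguf.GGMLQuantizationType.F16,
        gguf.LlamaFileType.MOSTLY_BF16: gguf.GGMLQuantizationType.BF16,
        gguf.LlamaFileType.MOSTLY_Q8_0: gguf.GGMLQuantizationType.Q8_0,
    }[ftype]  # other file types are rejected in the real code

print(pick_qtype(gguf.LlamaFileType.MOSTLY_F16, "v.blk.0.attn_q.weight", torch.zeros(1152, 1152)))  # F16
print(pick_qtype(gguf.LlamaFileType.MOSTLY_F16, "v.blk.0.attn_q.bias",   torch.zeros(1152)))        # F32 (1D)
print(pick_qtype(gguf.LlamaFileType.MOSTLY_F16, "v.patch_embd.weight",   torch.zeros(1152, 588)))   # F32 (embedding)
```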
```diff
@@ -1781,20 +1923,13 @@ def prepare_tensors(self):
 @Model.register("Llama4ForConditionalGeneration")
 class Llama4Model(LlamaModel):
     model_arch = gguf.MODEL_ARCH.LLAMA4
-    has_vision: bool = False
     undo_permute = False
+    ignore_vision = True
 
     # TODO @ngxson : avoid duplicate this code everywhere by at least support "text_config"
     # same with llama, but we need to merge the text_config into the root level of hparams
     def __init__(self, *args, **kwargs):
-        hparams = kwargs["hparams"] if "hparams" in kwargs else Model.load_hparams(args[0])
-        if "text_config" in hparams:
-            hparams = {**hparams, **hparams["text_config"]}
-        kwargs["hparams"] = hparams
         super().__init__(*args, **kwargs)
-        if "vision_config" in hparams:
-            logger.info("Has vision encoder, but it will be ignored")
-            self.has_vision = True
         # IMPORTANT: the normal "intermediate_size" is renamed to "intermediate_size_mlp", we need to undo this
         self.hparams["intermediate_size_moe"] = self.hparams["intermediate_size"]
         self.hparams["intermediate_size"] = self.hparams["intermediate_size_mlp"]
```
```diff
@@ -1824,7 +1959,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
             name += ".weight"
             data_torch = data_torch.transpose(-1, -2)
 
-        if "multi_modal_projector" in name or "vision_model" in name:
+        if "multi_modal_projector" in name or "mtmd_model" in name:
             return []
         return super().modify_tensors(data_torch, name, bid)
```
```diff
@@ -3474,24 +3609,9 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
 @Model.register("Gemma3ForCausalLM", "Gemma3ForConditionalGeneration")
 class Gemma3Model(Model):
     model_arch = gguf.MODEL_ARCH.GEMMA3
-    has_vision: bool = False
-
-    # we need to merge the text_config into the root level of hparams
-    def __init__(self, *args, **kwargs):
-        hparams = kwargs["hparams"] if "hparams" in kwargs else Model.load_hparams(args[0])
-        if "text_config" in hparams:
-            hparams = {**hparams, **hparams["text_config"]}
-        kwargs["hparams"] = hparams
-        super().__init__(*args, **kwargs)
-        if "vision_config" in hparams:
-            logger.info("Has vision encoder, but it will be ignored")
-            self.has_vision = True
 
     def write(self):
         super().write()
-        if self.has_vision:
-            logger.info("NOTE: this script only convert the language model to GGUF")
-            logger.info("      for the vision model, please use gemma3_convert_encoder_to_gguf.py")
 
     def set_vocab(self):
         self._set_vocab_sentencepiece()
```
```diff
@@ -3524,15 +3644,42 @@ def set_gguf_parameters(self):
             self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
             self.gguf_writer.add_rope_scaling_factor(hparams["rope_scaling"]["factor"])
 
+        if self.mtmd_model is not None:
+            self.mtmd_model.set_gguf_parameters(n_embd_text=hparams["hidden_size"])
+            vgguf = self.mtmd_model.gguf_writer
+            vgguf.add_string(gguf.Keys.ClipVision.PROJECTOR_TYPE, "gemma3")
+            # default values below are taken from HF transformers code
+            vgguf.add_float32(gguf.Keys.ClipVision.Attention.LAYERNORM_EPS, self.mtmd_model.hparams.get("layer_norm_eps", 1e-6))
+            vgguf.add_array(gguf.Keys.ClipVision.IMAGE_MEAN, [0.5, 0.5, 0.5])
+            vgguf.add_array(gguf.Keys.ClipVision.IMAGE_STD, [0.5, 0.5, 0.5])
+            vgguf.add_bool (gguf.Keys.ClipVision.USE_GELU, True)
+
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         del bid  # unused
 
         if name.startswith("language_model."):
             name = name.replace("language_model.", "")
+
         elif name.startswith("multi_modal_projector.") or name.startswith("vision_tower.") \
-                or name.startswith("multimodal_projector.") or name.startswith("vision_model."): # this is for old HF model, should be removed later
-            # ignore vision tensors
-            return []
+                or name.startswith("multimodal_projector.") or name.startswith("mtmd_model."):
+            if self.mmproj:
+                assert self.mtmd_model is not None
+                # process vision tensors
+                name = name.replace("_weight", ".weight")
+                if "fc1" in name:
+                    name = name.replace("fc1", "fc2")
+                else:
+                    name = name.replace("fc2", "fc1")
+
+                # correct norm value; only this "soft_emb_norm" needs to be corrected as it is part of the Gemma projector
+                # the other norm values are part of the SigLIP model, and they are already correct
+                # ref code: Gemma3RMSNorm
+                if "soft_emb_norm.weight" in name:
+                    logger.info(f"Correcting norm value for '{name}'")
+                    data_torch = data_torch + 1
+
+                self.mtmd_model.add_tensor(name, data_torch)
+            return []  # vision tensor already handled
 
         # remove OOV (out-of-vocabulary) rows in token_embd
         if "embed_tokens.weight" in name:
```
```diff
@@ -5554,6 +5701,10 @@ def parse_args() -> argparse.Namespace:
         "--remote", action="store_true",
         help="(Experimental) Read safetensors file remotely without downloading to disk. Config and tokenizer files will still be downloaded. To use this feature, you need to specify Hugging Face model repo name instead of a local directory. For example: 'HuggingFaceTB/SmolLM2-1.7B-Instruct'. Note: To access gated repo, set HF_TOKEN environment variable to your Hugging Face token.",
     )
+    parser.add_argument(
+        "--mmproj", action="store_true",
+        help="(Experimental) Export multimodal projector (mmproj) for vision models. This will only work on some vision models. A prefix 'mmproj-' will be added to the output file name.",
+    )
 
     args = parser.parse_args()
     if not args.print_supported_models and args.model is None:
```
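For illustration only (not part of the commit), a trimmed-down, hypothetical reconstruction of the relevant bit of `parse_args()`, showing how the new flag surfaces in the parsed namespace; the real parser defines many more options:

```python
import argparse

# hypothetical, minimal stand-in for parse_args(); only the pieces relevant to --mmproj
parser = argparse.ArgumentParser()
parser.add_argument("model", type=str, nargs="?")
parser.add_argument("--mmproj", action="store_true",
                    help="(Experimental) Export multimodal projector (mmproj) for vision models.")

args = parser.parse_args(["./gemma-3-4b-it", "--mmproj"])
print(args.model)    # ./gemma-3-4b-it
print(args.mmproj)   # True
```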
```diff
@@ -5633,6 +5784,10 @@ def main() -> None:
 
     hparams = Model.load_hparams(dir_model)
 
+    if args.mmproj:
+        if "mmproj" not in fname_out.name:
+            fname_out = Model.add_prefix_to_filename(fname_out, "mmproj-")
+
     with torch.inference_mode():
         output_type = ftype_map[args.outtype]
         model_architecture = hparams["architectures"][0]
```
```diff
@@ -5649,7 +5804,8 @@
                                      split_max_tensors=args.split_max_tensors,
                                      split_max_size=split_str_to_n_bytes(args.split_max_size), dry_run=args.dry_run,
                                      small_first_shard=args.no_tensor_first_split,
-                                     remote_hf_model_id=str(args.model) if args.remote else None)
+                                     remote_hf_model_id=str(args.model) if args.remote else None,
+                                     mmproj=args.mmproj)
 
         if args.vocab_only:
             logger.info("Exporting model vocab...")
```

examples/llava/clip-impl.h

Lines changed: 0 additions & 3 deletions
```diff
@@ -50,7 +50,6 @@
 // tensor name constants
 //
 
-#define TN_TOKEN_EMBD        "%s.token_embd.weight"
 #define TN_POS_EMBD          "%s.position_embd.weight"
 #define TN_CLASS_EMBD        "v.class_embd"
 #define TN_PATCH_EMBD        "v.patch_embd.weight" // not rename tensor with ".0" postfix for backward compat
```
```diff
@@ -66,8 +65,6 @@
 #define TN_LN_2              "%s.blk.%d.ln2.%s"
 #define TN_LN_PRE            "%s.pre_ln.%s"
 #define TN_LN_POST           "%s.post_ln.%s"
-#define TN_TEXT_PROJ         "text_projection.weight"
-#define TN_VIS_PROJ          "visual_projection.weight"
 #define TN_LLAVA_PROJ        "mm.%d.%s"
 #define TN_MVLM_PROJ_MLP     "mm.model.mlp.%d.%s"
 #define TN_MVLM_PROJ_BLOCK   "mm.model.mb_block.%d.block.%d.%s"
```

0 commit comments