llama : add support for jina-reranker-v2 (#13900) · ggml-org/llama.cpp@e83ba3e · GitHub

Commit e83ba3e

llama : add support for jina-reranker-v2 (#13900)
1 parent 2b13162 commit e83ba3e

5 files changed: +119 −72 lines changed

convert_hf_to_gguf.py

Lines changed: 85 additions & 35 deletions
@@ -3782,44 +3782,93 @@ def _xlmroberta_set_vocab(self) -> None:
         from sentencepiece import sentencepiece_model_pb2 as model
 
         tokenizer_path = self.dir_model / 'sentencepiece.bpe.model'
+
+        tokenizer_json = {}
+        tokenizer_config_json = {}
         if not tokenizer_path.is_file():
-            raise FileNotFoundError(f"File not found: {tokenizer_path}")
+            tokenizer_path = self.dir_model / 'tokenizer.json'
+            tokenizer_config_path = self.dir_model / 'tokenizer_config.json'
 
-        sentencepiece_model = model.ModelProto()  # pyright: ignore[reportAttributeAccessIssue]
-        sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read())
-        assert sentencepiece_model.trainer_spec.model_type == 1  # UNIGRAM
+            if not tokenizer_path.is_file():
+                raise FileNotFoundError(f"File not found: {tokenizer_path}")
 
-        add_prefix = sentencepiece_model.normalizer_spec.add_dummy_prefix
-        remove_whitespaces = sentencepiece_model.normalizer_spec.remove_extra_whitespaces
-        precompiled_charsmap = sentencepiece_model.normalizer_spec.precompiled_charsmap
+            from base64 import b64decode
+            from transformers import AutoTokenizer
+            tokenizer = AutoTokenizer.from_pretrained(self.dir_model)
 
-        tokenizer = SentencePieceProcessor()
-        tokenizer.LoadFromFile(str(tokenizer_path))
+            with open(tokenizer_path, "r", encoding="utf-8") as fp:
+                tokenizer_json = json.load(fp)
 
-        vocab_size = self.hparams.get('vocab_size', tokenizer.vocab_size())
+            if tokenizer_config_path.is_file():
+                with open(tokenizer_config_path, "r", encoding="utf-8") as fp:
+                    tokenizer_config_json = json.load(fp)
+
+            add_prefix = tokenizer.add_prefix_space
+            remove_whitespaces = tokenizer.clean_up_tokenization_spaces
+            precompiled_charsmap = b64decode(tokenizer_json["normalizer"]["precompiled_charsmap"])
+
+            vocab_size = self.hparams.get("vocab_size", tokenizer.vocab_size)
+        else:
+            sentencepiece_model = model.ModelProto()  # pyright: ignore[reportAttributeAccessIssue]
+            sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read())
+            assert sentencepiece_model.trainer_spec.model_type == 1  # UNIGRAM
+
+            add_prefix = sentencepiece_model.normalizer_spec.add_dummy_prefix
+            remove_whitespaces = sentencepiece_model.normalizer_spec.remove_extra_whitespaces
+            precompiled_charsmap = sentencepiece_model.normalizer_spec.precompiled_charsmap
+
+            tokenizer = SentencePieceProcessor()
+            tokenizer.LoadFromFile(str(tokenizer_path))
+
+            vocab_size = self.hparams.get('vocab_size', tokenizer.vocab_size())
 
         tokens: list[bytes] = [f"[PAD{i}]".encode("utf-8") for i in range(vocab_size)]
         scores: list[float] = [-10000.0] * vocab_size
         toktypes: list[int] = [SentencePieceTokenTypes.UNUSED] * vocab_size
 
-        for token_id in range(tokenizer.vocab_size()):
-            piece = tokenizer.IdToPiece(token_id)
-            text = piece.encode("utf-8")
-            score = tokenizer.GetScore(token_id)
+        if isinstance(tokenizer, SentencePieceProcessor):
+            for token_id in range(tokenizer.vocab_size()):
+                piece = tokenizer.IdToPiece(token_id)
+                text = piece.encode("utf-8")
+                score = tokenizer.GetScore(token_id)
 
-            toktype = SentencePieceTokenTypes.NORMAL
-            if tokenizer.IsUnknown(token_id):
-                toktype = SentencePieceTokenTypes.UNKNOWN
-            elif tokenizer.IsControl(token_id):
-                toktype = SentencePieceTokenTypes.CONTROL
-            elif tokenizer.IsUnused(token_id):
-                toktype = SentencePieceTokenTypes.UNUSED
-            elif tokenizer.IsByte(token_id):
-                toktype = SentencePieceTokenTypes.BYTE
+                toktype = SentencePieceTokenTypes.NORMAL
+                if tokenizer.IsUnknown(token_id):
+                    toktype = SentencePieceTokenTypes.UNKNOWN
+                elif tokenizer.IsControl(token_id):
+                    toktype = SentencePieceTokenTypes.CONTROL
+                elif tokenizer.IsUnused(token_id):
+                    toktype = SentencePieceTokenTypes.UNUSED
+                elif tokenizer.IsByte(token_id):
+                    toktype = SentencePieceTokenTypes.BYTE
 
-            tokens[token_id] = text
-            scores[token_id] = score
-            toktypes[token_id] = toktype
+                tokens[token_id] = text
+                scores[token_id] = score
+                toktypes[token_id] = toktype
+        else:
+            added_vocab = tokenizer.get_added_vocab()
+            unk_token = tokenizer_config_json.get("unk_token")
+            unk_token_id = added_vocab.get(unk_token, tokenizer_json["model"].get("unk_id", 3))
+
+            for token_id in range(vocab_size):
+                piece = tokenizer._convert_id_to_token(token_id)
+                text = piece.encode("utf-8")
+                score = tokenizer_json["model"]["vocab"][token_id][1]
+
+                toktype = SentencePieceTokenTypes.NORMAL
+                if token_id == unk_token_id:
+                    toktype = SentencePieceTokenTypes.UNKNOWN
+                elif token_id in tokenizer.all_special_ids:
+                    toktype = SentencePieceTokenTypes.CONTROL
+                elif token_id in added_vocab.values():
+                    toktype = SentencePieceTokenTypes.USER_DEFINED
+                # No reliable way to detect this, but jina doesn't have any
+                # elif tokenizer.IsByte(token_id):
+                #     toktype = SentencePieceTokenTypes.BYTE
+
+                tokens[token_id] = text
+                scores[token_id] = score
+                toktypes[token_id] = toktype
 
         if vocab_size > len(tokens):
             pad_count = vocab_size - len(tokens)
@@ -3829,15 +3878,16 @@ def _xlmroberta_set_vocab(self) -> None:
                scores.append(-1000.0)
                toktypes.append(SentencePieceTokenTypes.UNUSED)
 
-        # realign tokens (see HF tokenizer code)
-        tokens = [b'<s>', b'<pad>', b'</s>', b'<unk>'] + tokens[3:-1]
-        scores = [0.0, 0.0, 0.0, 0.0] + scores[3:-1]
-        toktypes = [
-            SentencePieceTokenTypes.CONTROL,
-            SentencePieceTokenTypes.CONTROL,
-            SentencePieceTokenTypes.CONTROL,
-            SentencePieceTokenTypes.UNKNOWN,
-        ] + toktypes[3:-1]
+        if isinstance(tokenizer, SentencePieceProcessor):
+            # realign tokens (see HF tokenizer code)
+            tokens = [b'<s>', b'<pad>', b'</s>', b'<unk>'] + tokens[3:-1]
+            scores = [0.0, 0.0, 0.0, 0.0] + scores[3:-1]
+            toktypes = [
+                SentencePieceTokenTypes.CONTROL,
+                SentencePieceTokenTypes.CONTROL,
+                SentencePieceTokenTypes.CONTROL,
+                SentencePieceTokenTypes.UNKNOWN,
+            ] + toktypes[3:-1]
 
         self.gguf_writer.add_tokenizer_model("t5")
         self.gguf_writer.add_tokenizer_pre("default")
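
For context, the fallback branch above reads the Hugging Face tokenizer.json directly when no sentencepiece.bpe.model is shipped, as is the case for jina-reranker-v2. Below is a minimal standalone sketch of that flow (not part of the commit; the function name and paths are illustrative, and it assumes an XLM-R-style unigram tokenizer.json whose normalizer carries a precompiled_charsmap):

    # Sketch: recover the vocab pieces, scores and precompiled charsmap from a
    # HF tokenizer.json, mirroring the converter's fallback branch above.
    import json
    from base64 import b64decode
    from pathlib import Path

    def read_unigram_tokenizer(model_dir: str):
        tokenizer_json = json.loads((Path(model_dir) / "tokenizer.json").read_text(encoding="utf-8"))

        # base64-encoded normalization table; the converter stores the decoded bytes in GGUF metadata
        charsmap = b64decode(tokenizer_json["normalizer"]["precompiled_charsmap"])

        # for a unigram model, "vocab" is a list of [piece, score] pairs indexed by token id
        vocab = tokenizer_json["model"]["vocab"]
        pieces = [piece.encode("utf-8") for piece, _score in vocab]
        scores = [float(score) for _piece, score in vocab]
        return charsmap, pieces, scores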

gguf-py/gguf/constants.py

Lines changed: 1 addition & 0 deletions
@@ -1036,6 +1036,7 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.POS_EMBD,
         MODEL_TENSOR.OUTPUT_NORM,
         MODEL_TENSOR.ATTN_OUT_NORM,
+        MODEL_TENSOR.ATTN_QKV,
         MODEL_TENSOR.ATTN_Q,
         MODEL_TENSOR.ATTN_K,
         MODEL_TENSOR.ATTN_V,

gguf-py/gguf/tensor_mapping.py

Lines changed: 2 additions & 0 deletions
@@ -157,6 +157,7 @@ class TensorNameMap:
             "h.{bid}.attn.c_attn",                                  # gpt2
             "transformer.h.{bid}.mixer.Wqkv",                       # phi2
             "encoder.layers.{bid}.attn.Wqkv",                       # nomic-bert
+            "encoder.layers.{bid}.mixer.Wqkv",                      # jina
             "model.layers.{bid}.self_attn.qkv_proj",                # phi3
             "encoder.layers.{bid}.self_attention.query_key_value",  # chatglm
             "transformer.layers.{bid}.attn.qkv_proj",               # openelm
@@ -224,6 +225,7 @@ class TensorNameMap:
             "model.layers.layers.{bid}.self_attn.o_proj",                   # plamo
             "model.layers.{bid}.attention.wo",                              # internlm2
             "encoder.layers.{bid}.attn.out_proj",                           # nomic-bert
+            "encoder.layers.{bid}.mixer.out_proj",                          # jina
             "transformer.decoder_layer.{bid}.multi_head_attention.linear",  # Grok
             "transformer.blocks.{bid}.norm_attn_norm.attn.out_proj",        # dbrx
             "encoder.layers.{bid}.self_attention.dense",                    # chatglm

src/llama-arch.cpp

Lines changed: 1 addition & 0 deletions
@@ -450,6 +450,7 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
         { LLM_TENSOR_TOKEN_TYPES,    "token_types" },
         { LLM_TENSOR_POS_EMBD,       "position_embd" },
         { LLM_TENSOR_ATTN_OUT_NORM,  "blk.%d.attn_output_norm" },
+        { LLM_TENSOR_ATTN_QKV,       "blk.%d.attn_qkv" },
         { LLM_TENSOR_ATTN_Q,         "blk.%d.attn_q" },
         { LLM_TENSOR_ATTN_K,         "blk.%d.attn_k" },
         { LLM_TENSOR_ATTN_V,         "blk.%d.attn_v" },

src/llama-model.cpp

Lines changed: 30 additions & 37 deletions
@@ -2132,7 +2132,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                 for (int i = 0; i < n_layer; ++i) {
                     auto & layer = layers[i];
 
-                    if (arch == LLM_ARCH_BERT) {
+                    layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED);
+                    layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED);
+
+                    if (!layer.wqkv) {
                         layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0);
                         layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, 0);
 
@@ -2141,12 +2144,6 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
 
                         layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0);
                         layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, 0);
-                    } else {
-                        layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
-                    }
-
-                    if (arch == LLM_ARCH_NOMIC_BERT_MOE) {
-                        layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0);
                     }
 
                     layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
@@ -5910,48 +5907,44 @@ struct llm_build_bert : public llm_graph_context {
             ggml_tensor * Vcur;
 
             // self-attention
-            if (model.arch == LLM_ARCH_BERT || model.arch == LLM_ARCH_JINA_BERT_V2) {
-                Qcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wq, cur), model.layers[il].bq);
-
-                if (model.layers[il].attn_q_norm) {
-                    Qcur = build_norm(Qcur,
-                            model.layers[il].attn_q_norm,
-                            model.layers[il].attn_q_norm_b,
-                            LLM_NORM, il);
-                }
-
-                Kcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wk, cur), model.layers[il].bk);
-
-                if (model.layers[il].attn_k_norm) {
-                    Kcur = build_norm(Kcur,
-                            model.layers[il].attn_k_norm,
-                            model.layers[il].attn_k_norm_b,
-                            LLM_NORM, il);
-                }
-
-                Vcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wv, cur), model.layers[il].bv);
-
-                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
-                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
-                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
-            } else {
-                // compute Q and K and RoPE them
+            if (model.layers[il].wqkv) {
                 cur = build_lora_mm(model.layers[il].wqkv, cur);
                 cb(cur, "wqkv", il);
 
-                if (model.arch == LLM_ARCH_NOMIC_BERT_MOE) {
+                if (model.layers[il].bqkv) {
                     cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
                     cb(cur, "bqkv", il);
                 }
 
                 Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
                 Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
                 Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
+            } else {
+                Qcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wq, cur), model.layers[il].bq);
+                Kcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wk, cur), model.layers[il].bk);
+                Vcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wv, cur), model.layers[il].bv);
+            }
 
-                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
-                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
-                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+            if (model.layers[il].attn_q_norm) {
+                Qcur = build_norm(Qcur,
+                        model.layers[il].attn_q_norm,
+                        model.layers[il].attn_q_norm_b,
+                        LLM_NORM, il);
+            }
+
+            if (model.layers[il].attn_k_norm) {
+                Kcur = build_norm(Kcur,
+                        model.layers[il].attn_k_norm,
+                        model.layers[il].attn_k_norm_b,
+                        LLM_NORM, il);
+            }
+
+            Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+            Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+            Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
 
+            // RoPE
+            if (model.arch == LLM_ARCH_NOMIC_BERT || model.arch == LLM_ARCH_NOMIC_BERT_MOE) {
                 Qcur = ggml_rope_ext(
                     ctx0, Qcur, inp_pos, nullptr,
                     n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
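
The graph-build change above keys on the presence of a fused wqkv tensor rather than on the architecture, then slices the fused projection into Q, K and V via ggml_view_2d at element offsets 0, n_embd and n_embd + n_embd_gqa. An illustrative NumPy sketch of the same split (not llama.cpp code; the sizes are made up):

    # Sketch: split a fused QKV projection output at the offsets used by the
    # ggml_view_2d calls above: Q at 0, K at n_embd, V at n_embd + n_embd_gqa.
    import numpy as np

    n_embd, n_embd_gqa, n_tokens = 1024, 1024, 4  # illustrative sizes
    qkv = np.zeros((n_tokens, n_embd + 2 * n_embd_gqa), dtype=np.float32)

    q = qkv[:, :n_embd]
    k = qkv[:, n_embd:n_embd + n_embd_gqa]
    v = qkv[:, n_embd + n_embd_gqa:]
    assert q.shape == (n_tokens, n_embd)
    assert k.shape == v.shape == (n_tokens, n_embd_gqa)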

0 commit comments