feat: Support batch embeddings (#1186) · coderonion/llama-cpp-python@d7a6791 · GitHub

Commit d7a6791

iamlemec and abetlen authored
feat: Support batch embeddings (abetlen#1186)
* handle batched embeddings
* fix normalization issue
* fix type hints, ensure no breaking changes to embed
* Clear kv cache / reset internal state after embedding complete

Co-authored-by: Andrei <abetlen@gmail.com>
1 parent 36b8432 commit d7a6791
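
For context, the `embed` method introduced here accepts either a single string or a list of strings. A minimal usage sketch based on the signatures in this diff; the model path is hypothetical, not part of the commit:

from llama_cpp import Llama

# hypothetical local model path; embedding=True is required by embed()
llm = Llama(model_path="./models/model.gguf", embedding=True)

# single string -> a single embedding (List[float])
vec = llm.embed("hello world")

# list of strings -> one embedding per input, decoded in shared batches
vecs = llm.embed(["first text", "second text", "third text"])

# optionally return the total token count alongside the embeddings
vecs, n_tokens = llm.embed(["first text", "second text"], return_count=True)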

File tree

2 files changed: +123 −34

llama_cpp/_internals.py

Lines changed: 22 additions & 0 deletions
@@ -510,6 +510,14 @@ def __del__(self):
             self._llama_batch_free(self.batch)
             self.batch = None
 
+    def n_tokens(self) -> int:
+        assert self.batch is not None
+        return self.batch.n_tokens
+
+    def reset(self):
+        assert self.batch is not None
+        self.batch.n_tokens = 0
+
     def set_batch(self, batch: Sequence[int], n_past: int, logits_all: bool):
         assert self.batch is not None
         n_tokens = len(batch)
@@ -522,6 +530,20 @@ def set_batch(self, batch: Sequence[int], n_past: int, logits_all: bool):
             self.batch.logits[i] = logits_all
         self.batch.logits[n_tokens - 1] = True
 
+    def add_sequence(self, batch: Sequence[int], seq_id: int, logits_all: bool):
+        assert self.batch is not None
+        n_tokens = len(batch)
+        n_tokens0 = self.batch.n_tokens
+        self.batch.n_tokens += n_tokens
+        for i in range(n_tokens):
+            j = n_tokens0 + i
+            self.batch.token[j] = batch[i]
+            self.batch.pos[j] = i
+            self.batch.seq_id[j][0] = seq_id
+            self.batch.n_seq_id[j] = 1
+            self.batch.logits[j] = logits_all
+        self.batch.logits[n_tokens - 1] = True
+
 
 class _LlamaTokenDataArray:
     def __init__(self, *, n_vocab: int):
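
For readers unfamiliar with the llama.cpp batch layout, here is a plain-Python sketch of what `add_sequence` does: append one sequence's tokens after whatever is already in the batch, restart positions at 0 for the new sequence, and tag each slot with its sequence id. All names below are illustrative, not part of the library:

from dataclasses import dataclass, field
from typing import List

@dataclass
class ToySlot:
    token: int    # token id
    pos: int      # position within its own sequence
    seq_id: int   # which input sequence this token belongs to
    logits: bool  # whether to return output for this slot

@dataclass
class ToyBatch:
    slots: List[ToySlot] = field(default_factory=list)

    def add_sequence(self, tokens: List[int], seq_id: int) -> None:
        # tokens are appended at the current batch offset, but their
        # positions restart at 0 for the new sequence
        for i, tok in enumerate(tokens):
            self.slots.append(ToySlot(token=tok, pos=i, seq_id=seq_id, logits=False))

batch = ToyBatch()
batch.add_sequence([101, 102, 103], seq_id=0)
batch.add_sequence([201, 202], seq_id=1)
# slots now hold positions [0, 1, 2] for seq 0 and [0, 1] for seq 1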

llama_cpp/llama.py

Lines changed: 101 additions & 34 deletions
@@ -717,10 +717,53 @@ def create_embedding(
         Returns:
             An embedding object.
         """
-        assert self._ctx.ctx is not None
         assert self._model.model is not None
         model_name: str = model if model is not None else self.model_path
 
+        # get numeric embeddings
+        embeds: List[List[float]]
+        total_tokens: int
+        embeds, total_tokens = self.embed(input, return_count=True)  # type: ignore
+
+        # convert to CreateEmbeddingResponse
+        data: List[Embedding] = [
+            {
+                "object": "embedding",
+                "embedding": emb,
+                "index": idx,
+            }
+            for idx, emb in enumerate(embeds)
+        ]
+
+        return {
+            "object": "list",
+            "data": data,
+            "model": model_name,
+            "usage": {
+                "prompt_tokens": total_tokens,
+                "total_tokens": total_tokens,
+            },
+        }
+
+    def embed(
+        self,
+        input: Union[str, List[str]],
+        normalize: bool = True,
+        truncate: bool = True,
+        return_count: bool = False,
+    ):
+        """Embed a string.
+
+        Args:
+            input: The utf-8 encoded string to embed.
+
+        Returns:
+            A list of embeddings
+        """
+        assert self._ctx.ctx is not None
+        n_embd = self.n_embd()
+        n_ctx = self.n_ctx()
+
         if self.context_params.embedding == False:
             raise RuntimeError(
                 "Llama model must be created with embedding=True to call this method"
@@ -734,48 +777,72 @@ def create_embedding(
         else:
             inputs = input
 
-        data: List[Embedding] = []
+        # reset batch
+        self._batch.reset()
+
+        # decode and fetch embeddings
+        data: List[List[float]] = []
+        def decode_batch(sizes: List[int]):
+            assert self._ctx.ctx is not None
+            llama_cpp.llama_kv_cache_clear(self._ctx.ctx)
+            self._ctx.decode(self._batch)
+            self._batch.reset()
+
+            # store embeddings
+            for i, s in enumerate(sizes):
+                embedding = llama_cpp.llama_get_embeddings_ith(self._ctx.ctx, i)[
+                    :n_embd
+                ]
+                norm = np.linalg.norm(embedding) if normalize else s
+                embedding: List[float] = [v / float(norm) for v in embedding]
+                data.append(embedding)
+
+        # init state
         total_tokens = 0
-        for index, input in enumerate(inputs):
-            tokens = self.tokenize(input.encode("utf-8"), special=True)
-            self.reset()
-            self.eval(tokens)
+        t_batch = 0
+        s_sizes: List[int] = []
+
+        # accumulate batches and encode
+        for text in inputs:
+            tokens = self.tokenize(text.encode("utf-8"))
+            if truncate:
+                tokens = tokens[:n_ctx]
+
             n_tokens = len(tokens)
             total_tokens += n_tokens
-            embedding = llama_cpp.llama_get_embeddings(self._ctx.ctx)[
-                : llama_cpp.llama_n_embd(self._model.model)
-            ]
 
-            data.append(
-                {
-                    "object": "embedding",
-                    "embedding": embedding,
-                    "index": index,
-                }
-            )
+            # check for overrun
+            if n_tokens > n_ctx:
+                raise ValueError(
+                    f"Requested tokens ({n_tokens}) exceed context window of {n_ctx}"
+                )
+
+            # time to eval batch
+            if t_batch + n_tokens > self._n_ctx:
+                decode_batch(s_sizes)
+                t_batch = 0
+                s_sizes = []
+
+            # add to batch
+            self._batch.add_sequence(tokens, len(s_sizes), False)
+            t_batch += n_tokens
+            s_sizes.append(n_tokens)
+
+        # handle last batch
+        decode_batch(s_sizes)
+
         if self.verbose:
             llama_cpp.llama_print_timings(self._ctx.ctx)
 
-        return {
-            "object": "list",
-            "data": data,
-            "model": model_name,
-            "usage": {
-                "prompt_tokens": total_tokens,
-                "total_tokens": total_tokens,
-            },
-        }
-
-    def embed(self, input: str) -> List[float]:
-        """Embed a string.
+        output = data[0] if isinstance(input, str) else data
 
-        Args:
-            input: The utf-8 encoded string to embed.
+        llama_cpp.llama_kv_cache_clear(self._ctx.ctx)
+        self.reset()
 
-        Returns:
-            A list of embeddings
-        """
-        return list(map(float, self.create_embedding(input)["data"][0]["embedding"]))
+        if return_count:
+            return output, total_tokens
+        else:
+            return output
 
     def _create_completion(
         self,