fix: Incorporate embedding pooling layer fixes (#1194) · coderonion/llama-cpp-python@7bb91f0 · GitHub

Commit 7bb91f0

fix: Incorporate embedding pooling layer fixes (abetlen#1194)

* remove division by token count
* truncate to n_batch, not n_ctx
1 parent ae71ad1 commit 7bb91f0
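
The patched method is Llama.embed(), which is only usable when the model is loaded with embedding=True. A minimal usage sketch of that path, assuming a local GGUF embedding model (the model path and n_batch value below are placeholders, not part of the commit):

from llama_cpp import Llama

# Placeholder model path; any embedding-capable GGUF model works here.
llm = Llama(
    model_path="./models/embedding-model.gguf",
    embedding=True,  # expose pooled embeddings from llama.cpp
    n_batch=512,     # after this fix, inputs are truncated to n_batch rather than n_ctx
)

# embed() returns one pooled vector per input string; with normalize=True each
# vector is L2-normalized (the old division by token count no longer happens).
vectors = llm.embed(["first document", "second document"], normalize=True)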

File tree

1 file changed: +17 −16 lines changed


llama_cpp/llama.py

Lines changed: 17 additions & 16 deletions
@@ -762,7 +762,7 @@ def embed(
         """
         assert self._ctx.ctx is not None
         n_embd = self.n_embd()
-        n_ctx = self.n_ctx()
+        n_batch = self.n_batch
 
         if self.context_params.embedding == False:
             raise RuntimeError(
@@ -782,54 +782,55 @@ def embed(
 
         # decode and fetch embeddings
         data: List[List[float]] = []
-        def decode_batch(sizes: List[int]):
+        def decode_batch(n_seq: int):
             assert self._ctx.ctx is not None
             llama_cpp.llama_kv_cache_clear(self._ctx.ctx)
             self._ctx.decode(self._batch)
             self._batch.reset()
 
             # store embeddings
-            for i, s in enumerate(sizes):
-                embedding = llama_cpp.llama_get_embeddings_ith(self._ctx.ctx, i)[
+            for i in range(n_seq):
+                embedding: List[float] = llama_cpp.llama_get_embeddings_ith(self._ctx.ctx, i)[
                     :n_embd
                 ]
-                norm = np.linalg.norm(embedding) if normalize else s
-                embedding: List[float] = [v / float(norm) for v in embedding]
+                if normalize:
+                    norm = float(np.linalg.norm(embedding))
+                    embedding = [v / norm for v in embedding]
                 data.append(embedding)
 
         # init state
         total_tokens = 0
         t_batch = 0
-        s_sizes: List[int] = []
+        p_batch = 0
 
         # accumulate batches and encode
         for text in inputs:
             tokens = self.tokenize(text.encode("utf-8"))
             if truncate:
-                tokens = tokens[:n_ctx]
+                tokens = tokens[:n_batch]
 
             n_tokens = len(tokens)
             total_tokens += n_tokens
 
             # check for overrun
-            if n_tokens > n_ctx:
+            if n_tokens > n_batch:
                 raise ValueError(
-                    f"Requested tokens ({n_tokens}) exceed context window of {n_ctx}"
+                    f"Requested tokens ({n_tokens}) exceed batch size of {n_batch}"
                 )
 
             # time to eval batch
-            if t_batch + n_tokens > self._n_ctx:
-                decode_batch(s_sizes)
+            if t_batch + n_tokens > n_batch:
+                decode_batch(p_batch)
                 t_batch = 0
-                s_sizes = []
+                p_batch = 0
 
             # add to batch
-            self._batch.add_sequence(tokens, len(s_sizes), False)
+            self._batch.add_sequence(tokens, p_batch, False)
             t_batch += n_tokens
-            s_sizes.append(n_tokens)
+            p_batch += 1
 
         # hanlde last batch
-        decode_batch(s_sizes)
+        decode_batch(p_batch)
 
         if self.verbose:
             llama_cpp.llama_print_timings(self._ctx.ctx)
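
In essence, the patched loop packs whole inputs into a single llama.cpp batch until adding the next one would push the token total past n_batch, decodes, and then reads one pooled embedding per sequence index. A simplified, self-contained sketch of that accumulation pattern, with the tokenize and decode/fetch steps passed in as stand-ins for the llama-cpp-python internals (not its real API):

from typing import Callable, List

def batch_embed(
    texts: List[str],
    tokenize: Callable[[str], List[int]],
    decode_and_fetch: Callable[[List[List[int]]], List[List[float]]],
    n_batch: int,
    normalize: bool = True,
) -> List[List[float]]:
    """Pack inputs into batches of at most n_batch tokens; one pooled vector per input."""
    results: List[List[float]] = []
    pending: List[List[int]] = []
    pending_tokens = 0

    def flush() -> None:
        # Decode the accumulated sequences and collect one embedding per sequence.
        nonlocal pending, pending_tokens
        if not pending:
            return
        for emb in decode_and_fetch(pending):
            if normalize:
                norm = sum(v * v for v in emb) ** 0.5
                emb = [v / norm for v in emb]
            results.append(emb)
        pending, pending_tokens = [], 0

    for text in texts:
        tokens = tokenize(text)[:n_batch]            # truncate to n_batch, as in the fix
        if pending_tokens + len(tokens) > n_batch:   # batch is full: decode what we have
            flush()
        pending.append(tokens)
        pending_tokens += len(tokens)

    flush()  # handle the last, possibly partial, batch
    return results

As in the commit, nothing is divided by the per-sequence token count any more: the pooled vector is returned as-is unless L2 normalization is explicitly requested.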
