Incorporate embedding pooling layer fixes by iamlemec · Pull Request #1194 · abetlen/llama-cpp-python · GitHub

Incorporate embedding pooling layer fixes #1194


Merged (2 commits, Feb 15, 2024)

Changes from all commits

llama_cpp/llama.py: 33 changes (17 additions, 16 deletions)

@@ -762,7 +762,7 @@ def embed(
         """
         assert self._ctx.ctx is not None
         n_embd = self.n_embd()
-        n_ctx = self.n_ctx()
+        n_batch = self.n_batch

         if self.context_params.embedding == False:
             raise RuntimeError(
@@ -782,54 +782,55 @@ def embed(

         # decode and fetch embeddings
         data: List[List[float]] = []
-        def decode_batch(sizes: List[int]):
+        def decode_batch(n_seq: int):
             assert self._ctx.ctx is not None
             llama_cpp.llama_kv_cache_clear(self._ctx.ctx)
             self._ctx.decode(self._batch)
             self._batch.reset()

             # store embeddings
-            for i, s in enumerate(sizes):
-                embedding = llama_cpp.llama_get_embeddings_ith(self._ctx.ctx, i)[
+            for i in range(n_seq):
+                embedding: List[float] = llama_cpp.llama_get_embeddings_ith(self._ctx.ctx, i)[
                     :n_embd
                 ]
-                norm = np.linalg.norm(embedding) if normalize else s
-                embedding: List[float] = [v / float(norm) for v in embedding]
+                if normalize:
+                    norm = float(np.linalg.norm(embedding))
+                    embedding = [v / norm for v in embedding]
                 data.append(embedding)

         # init state
         total_tokens = 0
         t_batch = 0
-        s_sizes: List[int] = []
+        p_batch = 0

         # accumulate batches and encode
         for text in inputs:
             tokens = self.tokenize(text.encode("utf-8"))
             if truncate:
-                tokens = tokens[:n_ctx]
+                tokens = tokens[:n_batch]

             n_tokens = len(tokens)
             total_tokens += n_tokens

             # check for overrun
-            if n_tokens > n_ctx:
+            if n_tokens > n_batch:
                 raise ValueError(
-                    f"Requested tokens ({n_tokens}) exceed context window of {n_ctx}"
+                    f"Requested tokens ({n_tokens}) exceed batch size of {n_batch}"
                 )

             # time to eval batch
-            if t_batch + n_tokens > self._n_ctx:
-                decode_batch(s_sizes)
+            if t_batch + n_tokens > n_batch:
+                decode_batch(p_batch)
                 t_batch = 0
-                s_sizes = []
+                p_batch = 0

             # add to batch
-            self._batch.add_sequence(tokens, len(s_sizes), False)
+            self._batch.add_sequence(tokens, p_batch, False)
             t_batch += n_tokens
-            s_sizes.append(n_tokens)
+            p_batch += 1

         # hanlde last batch
-        decode_batch(s_sizes)
+        decode_batch(p_batch)

         if self.verbose:
             llama_cpp.llama_print_timings(self._ctx.ctx)
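
Editor's note: the batching change is easier to see in isolation. The method now packs sequences greedily until the next one would overflow n_batch tokens, decodes, and fetches each sequence's pooled embedding by its index via llama_get_embeddings_ith, so a running sequence count (p_batch) replaces the old per-sequence size list. A minimal standalone sketch of that packing loop, using hypothetical names pack_and_decode and decode_fn that are not part of the PR or of llama-cpp-python:

    from typing import Callable, List

    def pack_and_decode(
        token_lists: List[List[int]],
        n_batch: int,
        decode_fn: Callable[[List[List[int]]], None],
    ) -> None:
        """Greedily pack token sequences into batches of at most n_batch tokens."""
        batch: List[List[int]] = []  # sequences staged for the next decode
        t_batch = 0                  # tokens accumulated in the current batch
        for tokens in token_lists:
            if len(tokens) > n_batch:
                raise ValueError(
                    f"Requested tokens ({len(tokens)}) exceed batch size of {n_batch}"
                )
            # flush the current batch before it would overflow
            if t_batch + len(tokens) > n_batch:
                decode_fn(batch)
                batch, t_batch = [], 0
            batch.append(tokens)     # this sequence's index in the batch is len(batch) - 1
            t_batch += len(tokens)
        # handle the last, possibly partial, batch
        if batch:
            decode_fn(batch)

The old s_sizes bookkeeping existed because, with normalize=False, the previous code divided each embedding by its sequence length; the PR title suggests the upstream pooling-layer fixes now return embeddings already pooled per sequence, so only the optional L2 normalization remains.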
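
For callers, the visible effect is that embedding inputs are validated and truncated against n_batch rather than n_ctx. A hedged usage sketch of the embed path this diff touches; the model path and the n_batch value are placeholders, not taken from the PR:

    from llama_cpp import Llama

    # Placeholder GGUF path; any embedding-capable model applies.
    llm = Llama(
        model_path="./models/embedding-model.gguf",
        embedding=True,  # required, otherwise embed() raises RuntimeError
        n_batch=512,     # per-sequence token limit for embed() after this change
    )

    # Each input becomes one sequence in the batch. Inputs are truncated to n_batch
    # tokens when truncate=True, and each sequence's pooled embedding is read back
    # with llama_get_embeddings_ith, then L2-normalized when normalize=True.
    vectors = llm.embed(["first sentence", "second sentence"], normalize=True, truncate=True)
    print(len(vectors), len(vectors[0]))  # number of inputs, n_embd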