From ee84ca12d26bf07d29941b324ef9701139c02a25 Mon Sep 17 00:00:00 2001
From: Douglas Hanley
Date: Thu, 15 Feb 2024 01:31:25 -0600
Subject: [PATCH 1/2] remove division by token count

---
 llama_cpp/llama.py | 23 ++++++++++++-----------
 1 file changed, 12 insertions(+), 11 deletions(-)

diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index 3e09a20b5..a5158f88a 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -782,25 +782,26 @@ def embed(
         # decode and fetch embeddings
         data: List[List[float]] = []

-        def decode_batch(sizes: List[int]):
+        def decode_batch(n_seq: int):
             assert self._ctx.ctx is not None
             llama_cpp.llama_kv_cache_clear(self._ctx.ctx)
             self._ctx.decode(self._batch)
             self._batch.reset()

             # store embeddings
-            for i, s in enumerate(sizes):
-                embedding = llama_cpp.llama_get_embeddings_ith(self._ctx.ctx, i)[
+            for i in range(n_seq):
+                embedding: List[float] = llama_cpp.llama_get_embeddings_ith(self._ctx.ctx, i)[
                     :n_embd
                 ]
-                norm = np.linalg.norm(embedding) if normalize else s
-                embedding: List[float] = [v / float(norm) for v in embedding]
+                if normalize:
+                    norm = float(np.linalg.norm(embedding))
+                    embedding = [v / norm for v in embedding]
                 data.append(embedding)

         # init state
         total_tokens = 0
         t_batch = 0
-        s_sizes: List[int] = []
+        p_batch = 0

         # accumulate batches and encode
         for text in inputs:
@@ -819,17 +820,17 @@ def decode_batch(sizes: List[int]):

             # time to eval batch
             if t_batch + n_tokens > self._n_ctx:
-                decode_batch(s_sizes)
+                decode_batch(p_batch)
                 t_batch = 0
-                s_sizes = []
+                p_batch = 0

             # add to batch
-            self._batch.add_sequence(tokens, len(s_sizes), False)
+            self._batch.add_sequence(tokens, p_batch, False)
             t_batch += n_tokens
-            s_sizes.append(n_tokens)
+            p_batch += 1

         # hanlde last batch
-        decode_batch(s_sizes)
+        decode_batch(p_batch)

         if self.verbose:
             llama_cpp.llama_print_timings(self._ctx.ctx)

From fa7f1cd45e80f350d91e6b83aceaa7bccf008201 Mon Sep 17 00:00:00 2001
From: Douglas Hanley
Date: Thu, 15 Feb 2024 11:43:53 -0600
Subject: [PATCH 2/2] truncate to n_batch, not n_ctx

---
 llama_cpp/llama.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index a5158f88a..f3c7b4fff 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -762,7 +762,7 @@ def embed(
         """
         assert self._ctx.ctx is not None
         n_embd = self.n_embd()
-        n_ctx = self.n_ctx()
+        n_batch = self.n_batch

         if self.context_params.embedding == False:
             raise RuntimeError(
@@ -807,19 +807,19 @@ def decode_batch(n_seq: int):
         for text in inputs:
             tokens = self.tokenize(text.encode("utf-8"))
             if truncate:
-                tokens = tokens[:n_ctx]
+                tokens = tokens[:n_batch]

             n_tokens = len(tokens)
             total_tokens += n_tokens

             # check for overrun
-            if n_tokens > n_ctx:
+            if n_tokens > n_batch:
                 raise ValueError(
-                    f"Requested tokens ({n_tokens}) exceed context window of {n_ctx}"
+                    f"Requested tokens ({n_tokens}) exceed batch size of {n_batch}"
                 )

             # time to eval batch
-            if t_batch + n_tokens > self._n_ctx:
+            if t_batch + n_tokens > n_batch:
                 decode_batch(p_batch)
                 t_batch = 0
                 p_batch = 0
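For reference, the behavioral change in the first patch can be sketched outside the patch itself: with `normalize=True` each embedding is divided by its own L2 norm, and the old fallback of dividing by the sequence's token count is removed entirely. Below is a minimal standalone sketch in plain NumPy; the function name is illustrative only and is not part of the library's API.

```python
import numpy as np
from typing import List


def normalize_embedding(embedding: List[float], normalize: bool = True) -> List[float]:
    # When normalize is True, scale by the vector's L2 norm;
    # otherwise return the raw embedding unscaled (no division by token count).
    if normalize:
        norm = float(np.linalg.norm(embedding))
        return [v / norm for v in embedding]
    return embedding
```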