Add support to get embeddings from high-level api. Closes #4 · coderonion/llama-cpp-python@70b8a1e · GitHub

Commit 70b8a1e

Add support to get embeddings from high-level api. Closes abetlen#4
1 parent 9ba5c3c · commit 70b8a1e

File tree

2 files changed: 26 additions & 0 deletions


examples/high_level_api_embedding.py

Lines changed: 12 additions & 0 deletions
@@ -0,0 +1,12 @@
+import json
+import argparse
+
+from llama_cpp import Llama
+
+parser = argparse.ArgumentParser()
+parser.add_argument("-m", "--model", type=str, default=".//models/...")
+args = parser.parse_args()
+
+llm = Llama(model_path=args.model, embedding=True)
+
+print(llm.embed("Hello world!"))
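
A minimal sketch (not part of this commit) of how the new `Llama.embed` call might be exercised, e.g. comparing two strings by cosine similarity. The model path and the `cosine_similarity` helper are illustrative assumptions:

    import math

    from llama_cpp import Llama

    # Assumption: a GGML-format model converted for llama.cpp lives at this path.
    llm = Llama(model_path="./models/ggml-model.bin", embedding=True)

    def cosine_similarity(a, b):
        # Plain-Python cosine similarity between two equal-length float vectors.
        dot = sum(x * y for x, y in zip(a, b))
        norm_a = math.sqrt(sum(x * x for x in a))
        norm_b = math.sqrt(sum(x * x for x in b))
        return dot / (norm_a * norm_b)

    # embed() returns a flat list of floats, one embedding per call.
    print(cosine_similarity(llm.embed("Hello world!"), llm.embed("Hi, planet!")))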

llama_cpp/llama.py

Lines changed: 14 additions & 0 deletions
@@ -105,6 +105,20 @@ def detokenize(self, tokens: List[int]) -> bytes:
             output += llama_cpp.llama_token_to_str(self.ctx, token)
         return output

+    def embed(self, text: str):
+        """Embed a string.
+
+        Args:
+            text: The utf-8 encoded string to embed.
+
+        Returns:
+            A list of embeddings.
+        """
+        tokens = self.tokenize(text.encode("utf-8"))
+        self._eval(tokens, 0)
+        embeddings = llama_cpp.llama_get_embeddings(self.ctx)
+        return embeddings[:llama_cpp.llama_n_embd(self.ctx)]
+
     def _eval(self, tokens: List[int], n_past):
         rc = llama_cpp.llama_eval(
             self.ctx,
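
A note on the trailing slice: `llama_cpp.llama_get_embeddings` returns a raw C float pointer with no length attached, so the method copies exactly `llama_n_embd` values out of it. Below is a generic ctypes illustration of that slicing behavior, using a fabricated four-float buffer standing in for the real embedding storage:

    import ctypes

    # Fabricated stand-in for the context's embedding buffer (4 floats).
    buf = (ctypes.c_float * 4)(0.1, 0.2, 0.3, 0.4)
    ptr = ctypes.cast(buf, ctypes.POINTER(ctypes.c_float))

    # Slicing a ctypes pointer copies that many elements into a Python list,
    # which is what embeddings[:llama_cpp.llama_n_embd(self.ctx)] relies on.
    print(ptr[:4])  # e.g. [0.10000000149011612, 0.20000000298023224, ...]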
