Support SPM infill by CISC · Pull Request #1492 · abetlen/llama-cpp-python · GitHub

Support SPM infill #1492

Merged

merged 16 commits on Jun 13, 2024

33 changes: 33 additions & 0 deletions examples/high_level_api/high_level_api_infill.py
@@ -0,0 +1,33 @@
import argparse

from llama_cpp import Llama

parser = argparse.ArgumentParser()
parser.add_argument("-m", "--model", type=str, default="../models/7B/ggml-models.bin")
parser.add_argument("-p", "--prompt", type=str, default="def add(")
parser.add_argument("-s", "--suffix", type=str, default="\n    return sum\n\n")
parser.add_argument("-i", "--spm-infill", action='store_true')
args = parser.parse_args()

llm = Llama(model_path=args.model, n_gpu_layers=-1, spm_infill=args.spm_infill)

output = llm.create_completion(
    temperature = 0.0,
    repeat_penalty = 1.0,
    prompt = args.prompt,
    suffix = args.suffix,
)

# Models sometimes repeat suffix in response, attempt to filter that
response = output["choices"][0]["text"]
response_stripped = response.rstrip()
unwanted_response_suffix = args.suffix.rstrip()
unwanted_response_length = len(unwanted_response_suffix)

filtered = False
if unwanted_response_suffix and response_stripped[-unwanted_response_length:] == unwanted_response_suffix:
    response = response_stripped[:-unwanted_response_length]
    filtered = True

print(f"Fill-in-Middle completion{' (filtered)' if filtered else ''}:\n\n{args.prompt}\033[32m{response}\033[{'33' if filtered else '0'}m{args.suffix}\033[0m")
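
For orientation, here is a minimal sketch (not part of the diff) of the two infill prompt layouts that the spm_infill flag chooses between in the llama.py changes further down. PRE, SUF and MID stand in for the model's prefix/suffix/middle special token ids, and the BOS/EOS/CLS/SEP handling of the real code is omitted.

from typing import List

def build_infill_prompt(
    prompt_ids: List[int],
    suffix_ids: List[int],
    PRE: int,
    SUF: int,
    MID: int,
    spm_infill: bool = False,
) -> List[int]:
    prefix_tokens = [PRE] + prompt_ids   # code before the gap
    suffix_tokens = [SUF] + suffix_ids   # code after the gap
    middle_tokens = [MID]                # the model generates the middle from here
    if spm_infill:
        # Suffix/Prefix/Middle ordering, which some models prefer
        return suffix_tokens + prefix_tokens + middle_tokens
    # Default Prefix/Suffix/Middle ordering
    return prefix_tokens + suffix_tokens + middle_tokens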

8 changes: 8 additions & 0 deletions llama_cpp/_internals.py
@@ -170,6 +170,14 @@ def token_eot(self) -> int:
        assert self.model is not None
        return llama_cpp.llama_token_eot(self.model)

+    def add_bos_token(self) -> int:
+        assert self.model is not None
+        return llama_cpp.llama_add_bos_token(self.model)
+
+    def add_eos_token(self) -> int:
+        assert self.model is not None
+        return llama_cpp.llama_add_eos_token(self.model)
+
    # Tokenization

    def tokenize(self, text: bytes, add_bos: bool, special: bool):
73 changes: 46 additions & 27 deletions llama_cpp/llama.py
@@ -115,6 +115,7 @@ def __init__(
        type_k: Optional[int] = None,
        type_v: Optional[int] = None,
        # Misc
+        spm_infill: bool = False,
        verbose: bool = True,
        # Extra Params
        **kwargs,  # type: ignore
@@ -185,6 +186,7 @@ def __init__(
            verbose: Print verbose output to stderr.
            type_k: KV cache data type for K (default: f16)
            type_v: KV cache data type for V (default: f16)
+            spm_infill: Use Suffix/Prefix/Middle pattern for infill (instead of Prefix/Suffix/Middle) as some models prefer this.

        Raises:
            ValueError: If the model path does not exist.
@@ -343,6 +345,8 @@ def __init__(
        self.lora_scale = lora_scale
        self.lora_path = lora_path

+        self.spm_infill = spm_infill
+
        if not os.path.exists(model_path):
            raise ValueError(f"Model path does not exist: {model_path}")

@@ -972,14 +976,33 @@ def _create_completion(

        completion_id: str = f"cmpl-{str(uuid.uuid4())}"
        created: int = int(time.time())
+        bos_token_id: int = self.token_bos()
+        cls_token_id: int = self._model.token_cls()
+        sep_token_id: int = self._model.token_sep()
        prefix_token_id: int = self._model.token_prefix()
        middle_token_id: int = self._model.token_middle()
        suffix_token_id: int = self._model.token_suffix()
+        add_space_prefix: bool = self.metadata.get("tokenizer.ggml.add_space_prefix", "true") == "true"
+        bos_tokens: List[int] = [cls_token_id if cls_token_id != -1 else bos_token_id]
+        eos_tokens: List[int] = [sep_token_id if sep_token_id != -1 else self.token_eos()]
+
+        if (isinstance(prompt, list) and suffix is None) or self._model.add_bos_token() == 0 or bos_tokens[:1] == [-1]:
+            bos_tokens = []
+
+        if (isinstance(prompt, list) and suffix is None) or (self._model.add_eos_token() != 1 and sep_token_id == -1):
+            eos_tokens = []
+
+        suffix_space_prefix: int = 0
+        # Tokenizer hack to remove leading space
+        if add_space_prefix and suffix_token_id >= 0 and suffix:
+            suffix = "☺" + suffix
+            suffix_space_prefix = 2
+
        # If prompt is empty, initialize completion with BOS token to avoid
        # detokenization including a space at the beginning of the completion
-        completion_tokens: List[int] = [] if len(prompt) > 0 else [self.token_bos()]
+        completion_tokens: List[int] = [] if len(prompt) > 0 else [bos_token_id]
        # Add blank space to start of prompt to match OG llama tokenizer
-        prompt_tokens: List[int] = (
+        prefix_tokens: List[int] = (
            (
                [prefix_token_id]
                if prefix_token_id >= 0 and suffix is not None
@@ -988,38 +1011,33 @@ def _create_completion(
            +
            (
                (
-                    self.tokenize(prompt.encode("utf-8"), add_bos=(prefix_token_id < 0 or suffix is None), special=(prefix_token_id < 0 or suffix is None))
+                    self.tokenize(prompt.encode("utf-8"), add_bos=False, special=(prefix_token_id < 0 or suffix is None))
                    if prompt != ""
-                    else (
-                        []
-                        if prefix_token_id >= 0 and suffix is not None
-                        else [self.token_bos()]
-                    )
+                    else []
                )
                if isinstance(prompt, str)
                else prompt
            )
-            +
-            (
-                (
-                    [suffix_token_id]
-                    +
-                    (
-                        self.tokenize(suffix.encode("utf-8"), add_bos=False, special=False)
-                        if suffix
-                        else []
-                    )
-                )
-                if suffix_token_id >= 0 and suffix is not None
-                else []
-            )
-            +
-            (
-                [middle_token_id]
-                if middle_token_id >= 0 and suffix is not None
-                else []
-            )
        )
+        suffix_tokens: List[int] = (
+            (
+                [suffix_token_id]
+                +
+                (
+                    self.tokenize(suffix.encode("utf-8"), add_bos=False, special=False)[suffix_space_prefix:]
+                    if suffix
+                    else []
+                )
+            )
+            if suffix_token_id >= 0 and suffix is not None
+            else []
+        )
+        middle_tokens: List[int] = (
+            [middle_token_id]
+            if middle_token_id >= 0 and suffix is not None
+            else []
+        )
+        prompt_tokens: List[int] = bos_tokens + ((suffix_tokens + prefix_tokens + middle_tokens) if self.spm_infill else (prefix_tokens + suffix_tokens + middle_tokens)) + eos_tokens
        text: bytes = b""
        returned_tokens: int = 0
        stop = (
@@ -1176,7 +1194,7 @@ def logit_bias_processor(
                    # not sure how to handle this branch when dealing
                    # with CJK output, so keep it unchanged
                    for token in remaining_tokens:
-                        if token == self.token_bos():
+                        if token == bos_token_id:
                            continue
                        token_end_position += len(self.detokenize([token], prev_tokens=prompt_tokens + completion_tokens[:returned_tokens]))
                        # Check if stop sequence is in the token
@@ -1303,7 +1321,7 @@ def logit_bias_processor(

                    logprobs_or_none: Optional[CompletionLogprobs] = None
                    if logprobs is not None:
-                        if token == self.token_bos():
+                        if token == bos_token_id:
                            continue
                        token_str = self.detokenize([token]).decode(
                            "utf-8", errors="ignore"
@@ -1431,7 +1449,7 @@ def logit_bias_processor(
            for idx, (token, token_str, logprobs_token) in enumerate(
                zip(all_tokens, all_token_strs, all_logprobs)
            ):
-                if token == self.token_bos():
+                if token == bos_token_id:
                    continue
                text_offsets.append(
                    text_offset
@@ -1858,6 +1876,7 @@ def __getstate__(self):
            type_k=self.context_params.type_k,
            type_v=self.context_params.type_v,
            # Misc
+            spm_infill=self.spm_infill,
            verbose=self.verbose,
        )

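A side note on the "Tokenizer hack to remove leading space" in _create_completion above: with tokenizers that have tokenizer.ggml.add_space_prefix enabled, tokenizing the suffix on its own would inject a leading space into the suffix tokens. The diff works around this by prepending a sentinel character ("☺") and then slicing off the first tokens. Below is a rough stand-alone sketch of the same idea, where tokenize is a stand-in for Llama.tokenize and the hard-coded 2 mirrors the suffix_space_prefix value used in the diff.

from typing import Callable, List

def tokenize_suffix_without_leading_space(
    tokenize: Callable[..., List[int]], suffix: str
) -> List[int]:
    # Prepend a throwaway character so the tokenizer's automatic space prefix
    # attaches to it rather than to the real suffix text.
    sentinel = "\u263a"  # "☺", the character used in the diff
    tokens = tokenize((sentinel + suffix).encode("utf-8"), add_bos=False, special=False)
    # Drop the tokens produced by the injected space and the sentinel
    # (assumed to be exactly two, matching suffix_space_prefix = 2 above).
    return tokens[2:]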