Bugfix: Stop sequences can be strings · chabotsi/llama-cpp-python@a8cd169 · GitHub

Commit a8cd169

Bugfix: Stop sequences can be strings
1 parent (f0812c4) · commit a8cd169

File tree: 2 files changed (+8, -7 lines)

llama_cpp/llama.py (6 additions, 6 deletions)

@@ -602,7 +602,7 @@ def _create_completion(
         top_p: float = 0.95,
         logprobs: Optional[int] = None,
         echo: bool = False,
-        stop: Optional[List[str]] = [],
+        stop: Optional[Union[str, List[str]]] = [],
         frequency_penalty: float = 0.0,
         presence_penalty: float = 0.0,
         repeat_penalty: float = 1.1,
@@ -624,7 +624,7 @@ def _create_completion(
         )
         text: bytes = b""
         returned_tokens: int = 0
-        stop = stop if stop is not None else []
+        stop = stop if isinstance(stop, list) else [stop] if isinstance(stop, str) else []
         model_name: str = model if model is not None else self.model_path

         if self.verbose:
@@ -973,7 +973,7 @@ def create_completion(
         top_p: float = 0.95,
         logprobs: Optional[int] = None,
         echo: bool = False,
-        stop: Optional[List[str]] = [],
+        stop: Optional[Union[str, List[str]]] = [],
         frequency_penalty: float = 0.0,
         presence_penalty: float = 0.0,
         repeat_penalty: float = 1.1,
@@ -1042,7 +1042,7 @@ def __call__(
         top_p: float = 0.95,
         logprobs: Optional[int] = None,
         echo: bool = False,
-        stop: Optional[List[str]] = [],
+        stop: Optional[Union[str, List[str]]] = [],
         frequency_penalty: float = 0.0,
         presence_penalty: float = 0.0,
         repeat_penalty: float = 1.1,
@@ -1162,7 +1162,7 @@ def create_chat_completion(
         top_p: float = 0.95,
         top_k: int = 40,
         stream: bool = False,
-        stop: Optional[List[str]] = [],
+        stop: Optional[Union[str, List[str]]] = [],
         max_tokens: int = 256,
         presence_penalty: float = 0.0,
         frequency_penalty: float = 0.0,
@@ -1188,7 +1188,7 @@ def create_chat_completion(
         Returns:
             Generated chat completion or a stream of chat completion chunks.
         """
-        stop = stop if stop is not None else []
+        stop = stop if isinstance(stop, list) else [stop] if isinstance(stop, str) else []
         chat_history = "".join(
             f'### {"Human" if message["role"] == "user" else "Assistant"}:{message["content"]}'
             for message in messages
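The same chained conditional appears at both normalization sites. A minimal standalone sketch of its behavior; the normalize_stop helper name is hypothetical and not part of the commit:

    from typing import List, Optional, Union

    def normalize_stop(stop: Optional[Union[str, List[str]]]) -> List[str]:
        # Mirrors the commit's expression: a list passes through unchanged,
        # a bare string is wrapped in a one-element list, and None becomes [].
        return stop if isinstance(stop, list) else [stop] if isinstance(stop, str) else []

    assert normalize_stop(["###", "\n\n"]) == ["###", "\n\n"]
    assert normalize_stop("###") == ["###"]
    assert normalize_stop(None) == []

After this normalization, the downstream stopping logic can keep treating stop as a plain list of strings, so only the entry-point signatures and the two normalization lines had to change.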

llama_cpp/server/app.py (2 additions, 1 deletion)

@@ -1,4 +1,5 @@
 import json
+import logging
 import multiprocessing
 from threading import Lock
 from typing import List, Optional, Union, Iterator, Dict
@@ -203,7 +204,7 @@ class CreateCompletionRequest(BaseModel):
         default=False,
         description="Whether to echo the prompt in the generated text. Useful for chatbots.",
     )
-    stop: Optional[List[str]] = stop_field
+    stop: Optional[Union[str, List[str]]] = stop_field
     stream: bool = stream_field
     logprobs: Optional[int] = Field(
         default=None,
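On the server side, widening the annotation is all Pydantic needs to accept both payload shapes. A cut-down sketch, assuming pydantic is installed; StopOnly is a hypothetical stand-in for CreateCompletionRequest with a plain default instead of stop_field:

    from typing import List, Optional, Union

    from pydantic import BaseModel

    class StopOnly(BaseModel):
        stop: Optional[Union[str, List[str]]] = None  # same annotation as the commit

    print(StopOnly(stop="###").stop)          # '###'  (a bare string now validates)
    print(StopOnly(stop=["###", "\n"]).stop)  # ['###', '\n']
    print(StopOnly().stop)                    # None

The request handler itself needs no extra branching, since the string form is turned into a list inside _create_completion (see the llama.py diff above).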
