Format · Zephyr800/llama-cpp-python@7a3f878 · GitHub

Commit 7a3f878

Format
1 parent 422ebc8 commit 7a3f878

File tree

3 files changed, +104 -71 lines changed


llama_cpp/llama.py

Lines changed: 6 additions & 4 deletions
@@ -1036,9 +1036,9 @@ def eval(self, tokens: Sequence[int]):
     offset = (
         0 if self.context_params.logits_all else n_tokens - 1
     ) # NOTE: Only save the last token logits if logits_all is False
-    self.scores[n_past + offset : n_past + n_tokens, :].reshape(
-        -1
-    )[:] = self._ctx.get_logits()[offset * cols: rows * cols]
+    self.scores[n_past + offset : n_past + n_tokens, :].reshape(-1)[
+        :
+    ] = self._ctx.get_logits()[offset * cols : rows * cols]
     # Update n_tokens
     self.n_tokens += n_tokens

@@ -1135,7 +1135,9 @@ def sample(
     else:
         self._ctx.sample_top_k(candidates=self._candidates, k=top_k, min_keep=1)
         self._ctx.sample_tail_free(candidates=self._candidates, z=tfs_z, min_keep=1)
-        self._ctx.sample_typical(candidates=self._candidates, p=typical_p, min_keep=1)
+        self._ctx.sample_typical(
+            candidates=self._candidates, p=typical_p, min_keep=1
+        )
         self._ctx.sample_top_p(candidates=self._candidates, p=top_p, min_keep=1)
         self._ctx.sample_min_p(candidates=self._candidates, p=min_p, min_keep=1)
         self._ctx.sample_temp(candidates=self._candidates, temp=temp)
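
Aside: the first hunk above only reflows an existing NumPy idiom. As a minimal, self-contained sketch of that flat-view assignment (all shapes and values below are illustrative, not taken from the library):

import numpy as np

# Illustrative stand-ins; names mirror Llama.eval() but the values are made up.
n_vocab = 8
n_tokens = 4
rows, cols = n_tokens, n_vocab
scores = np.zeros((16, n_vocab), dtype=np.single)
logits = np.arange(1, rows * cols + 1, dtype=np.single)  # stand-in for get_logits()
n_past, offset = 0, 0

# The row slice is C-contiguous, so .reshape(-1) returns a view and the flat
# assignment writes straight through into the scores buffer.
scores[n_past + offset : n_past + n_tokens, :].reshape(-1)[
    :
] = logits[offset * cols : rows * cols]
assert scores[0, 0] == 1.0 and scores[n_tokens - 1, cols - 1] == rows * cols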

llama_cpp/llama_chat_format.py

Lines changed: 73 additions & 37 deletions
@@ -532,6 +532,7 @@ def format_phind(
     _prompt = _format_add_colon_single(_system_message, _messages, _sep)
     return ChatFormatterResponse(prompt=_prompt)
 
+
 @register_chat_format("intel")
 def format_intel(
     messages: List[llama_types.ChatCompletionRequestMessage],

@@ -588,6 +589,7 @@ def format_mistrallite(
     _prompt = _format_no_colon_single(system_message, _messages, _sep)
     return ChatFormatterResponse(prompt=_prompt)
 
+
 @register_chat_format("chatml")
 def format_chatml(
     messages: List[llama_types.ChatCompletionRequestMessage],

@@ -604,6 +606,7 @@ def format_chatml(
     _prompt = _format_chatml(system_message, _messages, _sep)
     return ChatFormatterResponse(prompt=_prompt, stop=_sep)
 
+
 @register_chat_format("openchat")
 def format_openchat(
     messages: List[llama_types.ChatCompletionRequestMessage],

@@ -612,7 +615,9 @@ def format_openchat(
     system_template = "{system_message}<|end_of_turn|>"
     system_message = _get_system_message(messages)
     system_message = system_template.format(system_message=system_message)
-    _roles = dict(user="GPT4 Correct User: ", assistant="<|end_of_turn|>GPT4 Correct Assistant: ")
+    _roles = dict(
+        user="GPT4 Correct User: ", assistant="<|end_of_turn|>GPT4 Correct Assistant: "
+    )
     _sep = "<|end_of_turn|>"
     _messages = _map_roles(messages, _roles)
     _messages.append((_roles["assistant"], None))

@@ -651,46 +656,60 @@ def functionary_chat_handler(
 ) -> Union[llama_types.ChatCompletion, Iterator[llama_types.ChatCompletionChunk]]:
     SYSTEM_MESSAGE = """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. The assistant calls functions with appropriate input when necessary"""
 
-    def generate_type_definition(param: Dict[str, llama_types.JsonType], indent_level: int, shared_defs) -> str:
-        indent = ' ' * indent_level
-        if '$ref' in param:
+    def generate_type_definition(
+        param: Dict[str, llama_types.JsonType], indent_level: int, shared_defs
+    ) -> str:
+        indent = " " * indent_level
+        if "$ref" in param:
             # Reference to a shared definition
-            ref_name = param['$ref'].split('/')[-1] # Extract the type name from the reference
+            ref_name = param["$ref"].split("/")[
+                -1
+            ] # Extract the type name from the reference
             return ref_name
-        elif param.get('type') == 'array':
-            items = param.get('items', {})
+        elif param.get("type") == "array":
+            items = param.get("items", {})
             item_type = generate_type_definition(items, indent_level + 1, shared_defs)
             return f"Array<{item_type}>"
-        elif param.get('type') == 'object':
-            properties = param.get('properties', {})
+        elif param.get("type") == "object":
+            properties = param.get("properties", {})
             nested_schema = "{\n"
             for nested_param_name, nested_param in properties.items():
-                nested_param_type = generate_type_definition(nested_param, indent_level + 1, shared_defs)
-                nested_schema += f"{indent} {nested_param_name}: {nested_param_type},\n"
+                nested_param_type = generate_type_definition(
+                    nested_param, indent_level + 1, shared_defs
+                )
+                nested_schema += (
+                    f"{indent} {nested_param_name}: {nested_param_type},\n"
+                )
             nested_schema += indent + "}"
             return nested_schema
-        elif 'enum' in param:
+        elif "enum" in param:
             # Enum type
-            return " | ".join([f'"{enum_value}"' for enum_value in param['enum']])
+            return " | ".join([f'"{enum_value}"' for enum_value in param["enum"]])
         else:
             # Simple type
-            return param.get('type', 'any')
+            return param.get("type", "any")
 
     def generate_shared_definitions(shared_defs, indent_level: int) -> str:
-        indent = ' ' * indent_level
+        indent = " " * indent_level
         shared_definitions = ""
         for def_name, def_properties in shared_defs.items():
             shared_definitions += f"{indent}type {def_name} = "
-            if def_properties.get('type') == 'object':
-                shared_definitions += generate_type_definition(def_properties, indent_level, shared_defs)
-            elif 'enum' in def_properties:
+            if def_properties.get("type") == "object":
+                shared_definitions += generate_type_definition(
+                    def_properties, indent_level, shared_defs
+                )
+            elif "enum" in def_properties:
                 # Enum type
-                shared_definitions += " | ".join([f'"{enum_value}"' for enum_value in def_properties['enum']])
+                shared_definitions += " | ".join(
+                    [f'"{enum_value}"' for enum_value in def_properties["enum"]]
+                )
             shared_definitions += ";\n"
         return shared_definitions
 
     def generate_schema_from_functions(functions, namespace="functions") -> str:
-        schema = "// Supported function definitions that should be called when necessary.\n"
+        schema = (
+            "// Supported function definitions that should be called when necessary.\n"
+        )
         schema += f"namespace {namespace} {{\n\n"
 
         # Generate shared definitions

@@ -706,10 +725,10 @@ def generate_schema_from_functions(functions, namespace="functions") -> str:
         description = function.get("description", "")
         parameters = function.get("parameters", {})
         required_params = parameters.get("required", [])
-
+
         schema += f" // {description}\n"
         schema += f" type {function_name} = (_: {{\n"
-
+
         for param_name, param in parameters.get("properties", {}).items():
             param_description = param.get("description", "")
             param_type = generate_type_definition(param, 2, shared_definitions)

@@ -733,13 +752,18 @@ def prepare_messages_for_inference(
                 role="system", content=generate_schema_from_functions(functions)
             )
         )
-
+
         if tools is not None:
             all_messages.append(
                 llama_types.ChatCompletionRequestSystemMessage(
-                    role="system", content=generate_schema_from_functions(
-                        [tool["function"] for tool in tools if tool["type"] == "function"]
-                    )
+                    role="system",
+                    content=generate_schema_from_functions(
+                        [
+                            tool["function"]
+                            for tool in tools
+                            if tool["type"] == "function"
+                        ]
+                    ),
                 )
             )
 

@@ -790,7 +814,9 @@ def message_to_str(msg: llama_types.ChatCompletionRequestMessage):
         elif "function_call" in msg:
             return f"assistant to={msg['function_call']['name']}:\n{msg['function_call']['arguments']}</s>\n"
         elif "tool_calls" in msg and len(msg["tool_calls"]) > 0:
-            for tool_call in msg["tool_calls"]: # NOTE: probably doesn't work with the functionary model
+            for tool_call in msg[
+                "tool_calls"
+            ]: # NOTE: probably doesn't work with the functionary model
                 return f"assistant to={tool_call['id']}:\n{tool_call['function']['arguments']}</s>\n"
         elif msg["content"] is None:
             return "assistant"

@@ -800,12 +826,14 @@ def message_to_str(msg: llama_types.ChatCompletionRequestMessage):
             raise ValueError(f"Unsupported role: {msg['role']}")
 
         return "".join([message_to_str(msg) for msg in all_messages])
-
+
     if tools is not None:
         functions = [tool["function"] for tool in tools if tool["type"] == "function"]
-
+
     if tool_choice is not None:
-        function_call = tool_choice if isinstance(tool_choice, str) else tool_choice["function"]
+        function_call = (
+            tool_choice if isinstance(tool_choice, str) else tool_choice["function"]
+        )
 
     prompt = prepare_messages_for_inference(messages, functions, tools)
 

@@ -861,19 +889,27 @@ def message_to_str(msg: llama_types.ChatCompletionRequestMessage):
         if tool["type"] == "function" and tool["function"]["name"] == function_call:
             function_body = tool["function"]["parameters"]
             break
-
+
     if function_body is not None:
         try:
             with suppress_stdout_stderr(disable=llama.verbose):
-                grammar_text = llama_grammar.json_schema_to_gbnf(json.dumps(function_body))
-                grammar = llama_grammar.LlamaGrammar.from_string(llama_grammar.json_schema_to_gbnf(json.dumps(function_body)))
+                grammar_text = llama_grammar.json_schema_to_gbnf(
+                    json.dumps(function_body)
+                )
+                grammar = llama_grammar.LlamaGrammar.from_string(
+                    llama_grammar.json_schema_to_gbnf(json.dumps(function_body))
+                )
             print(grammar_text)
         except Exception as e:
             if llama.verbose:
-                print("Failed to parse function body as JSON schema, falling back to default grammar")
+                print(
+                    "Failed to parse function body as JSON schema, falling back to default grammar"
+                )
                 print(e)
             with suppress_stdout_stderr(disable=llama.verbose):
-                grammar = llama_grammar.LlamaGrammar.from_string(llama_grammar.JSON_GBNF)
+                grammar = llama_grammar.LlamaGrammar.from_string(
+                    llama_grammar.JSON_GBNF
+                )
     else:
         with suppress_stdout_stderr(disable=llama.verbose):
             grammar = llama_grammar.LlamaGrammar.from_string(llama_grammar.JSON_GBNF)

@@ -929,9 +965,9 @@ def message_to_str(msg: llama_types.ChatCompletionRequestMessage):
                     "function": {
                         "name": function_call,
                         "arguments": completion["choices"][0]["text"],
-                    }
+                    },
                 }
-            ]
+            ],
         },
         "finish_reason": "tool_calls",
     }
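
Aside: the functionary hunks above only reflow the schema generator. A hedged illustration of its input, with the rough shape of the prompt text it produces shown in comments (the function name and fields below are made up, and generate_schema_from_functions is a nested helper, so this snippet is standalone rather than importable from llama_cpp):

functions = [
    {
        "name": "get_weather",
        "description": "Get the current weather for a city",
        "parameters": {
            "type": "object",
            "properties": {
                "city": {"type": "string", "description": "City name"},
                "unit": {"enum": ["celsius", "fahrenheit"]},
            },
            "required": ["city"],
        },
    }
]

# generate_schema_from_functions(functions) renders this, roughly, as a
# TypeScript-flavoured block injected into the system prompt:
#
#   namespace functions {
#   // Get the current weather for a city
#   type get_weather = (_: {
#   city: string,
#   unit: "celsius" | "fahrenheit",
#   }) => any;
#   } // namespace functions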

llama_cpp/server/app.py

Lines changed: 25 additions & 30 deletions
@@ -30,7 +30,7 @@
 
 
 # Disable warning for model and model_alias settings
-BaseSettings.model_config['protected_namespaces'] = ()
+BaseSettings.model_config["protected_namespaces"] = ()
 
 
 class Settings(BaseSettings):

@@ -68,7 +68,9 @@ class Settings(BaseSettings):
         description="Use mlock.",
     )
     # Context Params
-    seed: int = Field(default=llama_cpp.LLAMA_DEFAULT_SEED, description="Random seed. -1 for random.")
+    seed: int = Field(
+        default=llama_cpp.LLAMA_DEFAULT_SEED, description="Random seed. -1 for random."
+    )
     n_ctx: int = Field(default=2048, ge=1, description="The context size.")
     n_batch: int = Field(
         default=512, ge=1, description="The batch size to use per eval."

@@ -83,30 +85,16 @@ class Settings(BaseSettings):
         ge=0,
         description="The number of threads to use when batch processing.",
     )
-    rope_scaling_type: int = Field(
-        default=llama_cpp.LLAMA_ROPE_SCALING_UNSPECIFIED
-    )
-    rope_freq_base: float = Field(
-        default=0.0, description="RoPE base frequency"
-    )
+    rope_scaling_type: int = Field(default=llama_cpp.LLAMA_ROPE_SCALING_UNSPECIFIED)
+    rope_freq_base: float = Field(default=0.0, description="RoPE base frequency")
     rope_freq_scale: float = Field(
         default=0.0, description="RoPE frequency scaling factor"
     )
-    yarn_ext_factor: float = Field(
-        default=-1.0
-    )
-    yarn_attn_factor: float = Field(
-        default=1.0
-    )
-    yarn_beta_fast: float = Field(
-        default=32.0
-    )
-    yarn_beta_slow: float = Field(
-        default=1.0
-    )
-    yarn_orig_ctx: int = Field(
-        default=0
-    )
+    yarn_ext_factor: float = Field(default=-1.0)
+    yarn_attn_factor: float = Field(default=1.0)
+    yarn_beta_fast: float = Field(default=32.0)
+    yarn_beta_slow: float = Field(default=1.0)
+    yarn_orig_ctx: int = Field(default=0)
     mul_mat_q: bool = Field(
         default=True, description="if true, use experimental mul_mat_q kernels"
     )

@@ -122,7 +110,7 @@ class Settings(BaseSettings):
     # LoRA Params
     lora_base: Optional[str] = Field(
         default=None,
-        description="Optional path to base model, useful if using a quantized base model and you want to apply LoRA to an f16 model."
+        description="Optional path to base model, useful if using a quantized base model and you want to apply LoRA to an f16 model.",
     )
     lora_path: Optional[str] = Field(
         default=None,

@@ -384,7 +372,9 @@ def create_app(settings: Optional[Settings] = None):
     chat_handler = None
     if settings.chat_format == "llava-1-5":
         assert settings.clip_model_path is not None
-        chat_handler = llama_cpp.llama_chat_format.Llava15ChatHandler(clip_model_path=settings.clip_model_path, verbose=settings.verbose)
+        chat_handler = llama_cpp.llama_chat_format.Llava15ChatHandler(
+            clip_model_path=settings.clip_model_path, verbose=settings.verbose
+        )
     ##
 
     llama = llama_cpp.Llama(

@@ -587,9 +577,10 @@ async def get_event_publisher(
 
 grammar = Field(
     default=None,
-    description="A CBNF grammar (as string) to be used for formatting the model's output."
+    description="A CBNF grammar (as string) to be used for formatting the model's output.",
 )
 
+
 class CreateCompletionRequest(BaseModel):
     prompt: Union[str, List[str]] = Field(
         default="", description="The prompt to generate completions for."

@@ -690,7 +681,8 @@ async def create_completion(
         kwargs["grammar"] = llama_cpp.LlamaGrammar.from_string(body.grammar)
 
     iterator_or_completion: Union[
-        llama_cpp.CreateCompletionResponse, Iterator[llama_cpp.CreateCompletionStreamResponse]
+        llama_cpp.CreateCompletionResponse,
+        Iterator[llama_cpp.CreateCompletionStreamResponse],
     ] = await run_in_threadpool(llama, **kwargs)
 
     if isinstance(iterator_or_completion, Iterator):

@@ -748,7 +740,9 @@ class ChatCompletionRequestMessage(BaseModel):
     role: Literal["system", "user", "assistant", "function"] = Field(
         default="user", description="The role of the message."
     )
-    content: Optional[str] = Field(default="", description="The content of the message.")
+    content: Optional[str] = Field(
+        default="", description="The content of the message."
+    )
 
 
 class CreateChatCompletionRequest(BaseModel):

@@ -770,9 +764,10 @@ class CreateChatCompletionRequest(BaseModel):
     tool_choice: Optional[llama_cpp.ChatCompletionToolChoiceOption] = Field(
         default=None,
         description="A tool to apply to the generated completions.",
-    ) # TODO: verify
+    )  # TODO: verify
     max_tokens: Optional[int] = Field(
-        default=None, description="The maximum number of tokens to generate. Defaults to inf"
+        default=None,
+        description="The maximum number of tokens to generate. Defaults to inf",
    )
     temperature: float = temperature_field
     top_p: float = top_p_field
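
Aside: the Settings hunks above only collapse Field(...) calls onto one line. A minimal pydantic-settings sketch of how such fields behave (the class and the selection of fields here are illustrative, not the server's real Settings):

from pydantic import Field
from pydantic_settings import BaseSettings

class MiniSettings(BaseSettings):
    # Defaults are overridable from the environment, e.g. N_CTX=4096.
    n_ctx: int = Field(default=2048, ge=1, description="The context size.")
    yarn_beta_fast: float = Field(default=32.0)

print(MiniSettings().model_dump())  # {'n_ctx': 2048, 'yarn_beta_fast': 32.0} with no env overrides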

0 commit comments
