diff --git a/llama_cpp/llama_chat_format.py b/llama_cpp/llama_chat_format.py index 17575c700..1f1998c3f 100644 --- a/llama_cpp/llama_chat_format.py +++ b/llama_cpp/llama_chat_format.py @@ -7,6 +7,7 @@ import dataclasses import random import string +import warnings from contextlib import ExitStack from typing import ( @@ -28,9 +29,7 @@ import numpy as np import numpy.typing as npt -import llama_cpp.llama as llama -import llama_cpp.llama_types as llama_types -import llama_cpp.llama_grammar as llama_grammar +from llama_cpp import llama, llama_grammar, llama_types from ._logger import logger from ._utils import suppress_stdout_stderr, Singleton @@ -3373,6 +3372,155 @@ class MiniCPMv26ChatHandler(Llava15ChatHandler): ) +def _accumulate_chunks( + chunks_iterator: Iterator[llama_types.CreateCompletionStreamResponse], + chunks_list: List[llama_types.CreateCompletionStreamResponse], +) -> Iterator[llama_types.CreateCompletionStreamResponse]: + for chunk in chunks_iterator: + chunks_list.append(chunk) + yield chunk + + +def _convert_chunks_to_completion( + chunks: List[llama_types.CreateCompletionStreamResponse], +) -> llama_types.CreateCompletionResponse: + """Convert a list of completion chunks to a completion.""" + # Accumulate completion response values + text: str = "" + finish_reason: Optional[str] = None + logprobs: Optional[llama_types.CompletionLogprobs] = None + prompt_tokens = 0 + completion_tokens = 0 + total_tokens = 0 + completion_id: Optional[str] = None + completion_model: Optional[str] = None + completion_created: Optional[int] = None + for chunk in chunks: + # Extract the id, model, and created values from the first chunk + if completion_id is None: + completion_id = chunk["id"] + completion_model = chunk["model"] + completion_created = chunk["created"] + # Extract the usage if present in the chunk + usage = chunk.get("usage") + if usage: + prompt_tokens += usage.get("prompt_tokens", 0) + completion_tokens += usage.get("completion_tokens", 0) + total_tokens += usage.get("total_tokens", 0) + # Accumulate the chunk text + choice = chunk["choices"][0] + text += choice.get("text", "") + # Extract the finish_reason and logprobs if present in the chunk + if choice.get("finish_reason"): + finish_reason = choice["finish_reason"] + if choice.get("logprobs"): + logprobs = choice["logprobs"] + # Create the completion response + completion: llama_types.CreateCompletionResponse = { + "id": completion_id or "unknown_id", + "object": "text_completion", + "created": completion_created or 0, + "model": completion_model or "unknown_model", + "choices": [ + { + "text": text, + "index": 0, + "logprobs": logprobs, # TODO: Improve accumulation of logprobs + "finish_reason": finish_reason, # type: ignore[typeddict-item] + } + ], + } + # Add usage section if present in the chunks + if (prompt_tokens + completion_tokens + total_tokens) > 0: + completion["usage"] = { + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + "total_tokens": total_tokens, + } + return completion + + +def _stream_tool_calls( + llama: llama.Llama, + prompt: str, + tools: List[llama_types.ChatCompletionTool], + tool_name: str, + completion_kwargs: dict[str, Any], + follow_up_gbnf_tool_grammar: str, +) -> Iterator[llama_types.CreateChatCompletionStreamResponse]: + # Generate a tool call completions + tool = next((tool for tool in tools if tool["function"]["name"] == tool_name), None) + completions: List[llama_types.CreateCompletionResponse] = [] + completions_tool_name: List[str] = [] + 
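
Note on the two helpers above: _accumulate_chunks tees a completion stream, appending every chunk to a caller-supplied list while still yielding it, and _convert_chunks_to_completion later folds that list into a single non-streaming completion. A minimal sketch of how they compose, assuming both private helpers are importable from llama_cpp.llama_chat_format once this patch is applied; the chunk dicts are hand-written stand-ins for real llama.cpp output.

    from llama_cpp.llama_chat_format import _accumulate_chunks, _convert_chunks_to_completion

    def _demo_stream():
        # Two hand-written chunks shaped like CreateCompletionStreamResponse
        yield {"id": "cmpl-0", "object": "text_completion", "created": 0, "model": "demo",
               "choices": [{"text": "Hel", "index": 0, "logprobs": None, "finish_reason": None}]}
        yield {"id": "cmpl-0", "object": "text_completion", "created": 0, "model": "demo",
               "choices": [{"text": "lo", "index": 0, "logprobs": None, "finish_reason": "stop"}]}

    retained = []
    for chunk in _accumulate_chunks(_demo_stream(), retained):
        pass  # the handler converts and yields each chunk to the caller at this point

    merged = _convert_chunks_to_completion(retained)
    assert merged["choices"][0]["text"] == "Hello"
    assert merged["choices"][0]["finish_reason"] == "stop"
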
finish_reason_chat_chunk = None + while tool is not None and len(completions) <= 16: + # Generate the parameter values for the selected tool + prompt += f"functions.{tool_name}:\n" + try: + grammar = llama_grammar.LlamaGrammar.from_json_schema( + json.dumps(tool["function"]["parameters"]), verbose=llama.verbose + ) + except Exception as e: + warnings.warn( + f"Failed to parse function body as JSON schema, falling back to default grammar\n\n{e}", + category=RuntimeWarning, + stacklevel=2, + ) + grammar = llama_grammar.LlamaGrammar.from_string( + llama_grammar.JSON_GBNF, verbose=llama.verbose + ) + completion_or_chunks = llama.create_completion( + prompt=prompt, + **{ + **completion_kwargs, + "max_tokens": None, + "grammar": grammar, + }, + ) + chunks: List[llama_types.CreateCompletionResponse] = [] + chat_chunks = _convert_completion_to_chat_function( + tool_name, + _accumulate_chunks(completion_or_chunks, chunks), # type: ignore[arg-type] + stream=True, + ) + for chat_chunk in chat_chunks: + # Don't return the finish_reason chunk + if chat_chunk["choices"] and chat_chunk["choices"][0].get("finish_reason"): + finish_reason_chat_chunk = chat_chunk + break + # Update this tool call's index + if chat_chunk["choices"] and chat_chunk["choices"][0]["delta"].get("tool_calls"): + chat_chunk["choices"][0]["delta"]["tool_calls"][0]["index"] = len(completions) + yield chat_chunk + completion = _convert_chunks_to_completion(chunks) + completions.append(completion) + completions_tool_name.append(tool_name) + prompt += completion["choices"][0]["text"] + prompt += "\n" + # Determine whether to call another tool or stop + response = cast( + llama_types.CreateCompletionResponse, + llama.create_completion( + prompt=prompt, + **{ + **completion_kwargs, + "temperature": 0, + "stream": False, + "stop": [*completion_kwargs["stop"], ":", ""], + "max_tokens": None, + "grammar": llama_grammar.LlamaGrammar.from_string( + follow_up_gbnf_tool_grammar, verbose=llama.verbose + ), + }, + ), + ) + tool_name = response["choices"][0]["text"][len("functions.") :] + tool = next((tool for tool in tools if tool["function"]["name"] == tool_name), None) + # Yield the finish_reason chunk + if finish_reason_chat_chunk is not None: + yield finish_reason_chat_chunk + + @register_chat_completion_handler("chatml-function-calling") def chatml_function_calling( llama: llama.Llama, @@ -3402,7 +3550,7 @@ def chatml_function_calling( grammar: Optional[llama.LlamaGrammar] = None, logprobs: Optional[bool] = None, top_logprobs: Optional[int] = None, - **kwargs, # type: ignore + **kwargs: Any, ) -> Union[ llama_types.CreateChatCompletionResponse, Iterator[llama_types.CreateChatCompletionStreamResponse], @@ -3416,18 +3564,21 @@ def chatml_function_calling( "{% if tool_calls %}" "\n\nYou have access to the following functions:\n" "{% for tool in tools %}" + '\n{% if tool.function.get("description") %}/* {{ tool.function.description | trim }} */{% endif %}' "\nfunctions.{{ tool.function.name }}:\n" "{{ tool.function.parameters | tojson }}" "\n{% endfor %}" - "\n\nYou can respond to users messages with either a single message or one or more function calls." - "\n\nTo respond with a message begin the message with 'message:', use the following format:" + "\nYou must respond to user messages with either a single message or with one or more function calls." 
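
The try/except above is the argument-generation policy used throughout this handler: compile a grammar from the selected tool's JSON schema so sampling can only produce valid arguments, and fall back to the generic JSON grammar when the schema cannot be compiled. A standalone sketch of that policy using only APIs already present in this diff (LlamaGrammar.from_json_schema, LlamaGrammar.from_string, JSON_GBNF); the helper name and example schema are invented for illustration.

    import json
    import warnings

    from llama_cpp import llama_grammar

    def grammar_for_parameters(parameters: dict, verbose: bool = False) -> llama_grammar.LlamaGrammar:
        """Grammar from a tool's JSON schema, falling back to the generic JSON grammar."""
        try:
            return llama_grammar.LlamaGrammar.from_json_schema(json.dumps(parameters), verbose=verbose)
        except Exception as e:
            warnings.warn(f"Falling back to the generic JSON grammar: {e}", RuntimeWarning, stacklevel=2)
            return llama_grammar.LlamaGrammar.from_string(llama_grammar.JSON_GBNF, verbose=verbose)

    # Invented example schema: constrains the sampled text to a {"location": ...} object
    grammar = grammar_for_parameters(
        {"type": "object", "properties": {"location": {"type": "string"}}}
    )
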
+ "\n\nTo respond with a message use the following format:" "\n\nmessage:" "\n" - "\n\nTo respond with one or more function calls begin the message with 'functions.:', use the following format:" - "\n\nfunctions.:" + "\n\nTo respond with one or more function calls use the following format:" + "\n\n" + "\nfunctions.:" '\n{ "arg1": "value1", "arg2": "value2" }' "\nfunctions.:" '\n{ "arg1": "value1", "arg2": "value2" }' + "\n" "{% endif %}" "<|im_end|>\n" "{% endif %}" @@ -3438,7 +3589,7 @@ def chatml_function_calling( "{% endif %}" # Assistant message "{% if message.role == 'assistant' %}" - ## Reglar message + ## Regular message "{% if message.content and message.content | length > 0 %}" "{% if tool_calls %}" "message:\n" @@ -3465,35 +3616,55 @@ def chatml_function_calling( # Convert legacy functions to tools if functions is not None: - tools = [ - { - "type": "function", - "function": function, - } - for function in functions - ] + tools = [{"type": "function", "function": function} for function in functions] # Convert legacy function_call to tool_choice if function_call is not None: - if isinstance(function_call, str) and ( - function_call == "none" or function_call == "auto" - ): + if isinstance(function_call, str) and (function_call in ("none", "auto")): tool_choice = function_call if isinstance(function_call, dict) and "name" in function_call: - tool_choice = { - "type": "function", - "function": { - "name": function_call["name"], - }, - } + tool_choice = {"type": "function", "function": {"name": function_call["name"]}} + # Collect the llama.create_completion keyword arguments so we don't have to repeat these with + # each completion call stop = ( [stop, "<|im_end|>"] if isinstance(stop, str) - else stop + ["<|im_end|>"] if stop else ["<|im_end|>"] + else [*stop, "<|im_end|>"] + if stop + else ["<|im_end|>"] ) + grammar = ( # It is assumed the grammar applies to messages only, not tool calls + grammar + if grammar is not None + else ( + _grammar_for_response_format(response_format) + if response_format is not None and response_format["type"] == "json_object" + else None + ) + ) + completion_kwargs = { + "temperature": temperature, + "top_p": top_p, + "top_k": top_k, + "min_p": min_p, + "typical_p": typical_p, + "stream": stream, + "stop": stop, + "max_tokens": max_tokens, + "presence_penalty": presence_penalty, + "frequency_penalty": frequency_penalty, + "repeat_penalty": repeat_penalty, + "tfs_z": tfs_z, + "mirostat_mode": mirostat_mode, + "mirostat_tau": mirostat_tau, + "mirostat_eta": mirostat_eta, + "model": model, + "logits_processor": logits_processor, + "grammar": grammar, + } - # Case 1: No tool choice by user + # Case 1: No tool use if ( tool_choice is None or (isinstance(tool_choice, str) and tool_choice == "none") @@ -3501,316 +3672,184 @@ def chatml_function_calling( or len(tools) == 0 ): prompt = template_renderer.render( - messages=messages, - tools=[], - tool_calls=None, - add_generation_prompt=True, + messages=messages, tools=[], tool_calls=None, add_generation_prompt=True ) - - if response_format is not None and response_format["type"] == "json_object": - grammar = _grammar_for_response_format(response_format) - return _convert_completion_to_chat( llama.create_completion( prompt=prompt, - temperature=temperature, - top_p=top_p, - top_k=top_k, - min_p=min_p, - typical_p=typical_p, - stream=stream, - stop=stop, - max_tokens=max_tokens, - presence_penalty=presence_penalty, - frequency_penalty=frequency_penalty, - repeat_penalty=repeat_penalty, - tfs_z=tfs_z, - 
mirostat_mode=mirostat_mode, - mirostat_tau=mirostat_tau, - mirostat_eta=mirostat_eta, - model=model, - logits_processor=logits_processor, - grammar=grammar, + **completion_kwargs, # type: ignore[arg-type] logprobs=top_logprobs if logprobs else None, ), stream=stream, ) - # Case 2: Tool choice by user + # Ensure there is a system prompt to attach the tool metadata to + if not any(message["role"] == "system" for message in messages): + messages = [*messages, {"role": "system", "content": ""}] + + # Case 2: Automatic or fixed tool choice + # Case 2 step 1: Determine whether to respond with a message or a tool call + assert (isinstance(tool_choice, str) and tool_choice == "auto") or isinstance(tool_choice, dict) if isinstance(tool_choice, dict): - tool_name = tool_choice["function"]["name"] - tool = next( - (tool for tool in tools if tool["function"]["name"] == tool_name), None + tools = [t for t in tools if t["function"]["name"] == tool_choice["function"]["name"]] + assert tools + function_names = " | ".join([f'''"functions.{t['function']['name']}:"''' for t in tools]) + prompt = template_renderer.render( + messages=messages, tools=tools, tool_calls=True, add_generation_prompt=True + ) + initial_gbnf_tool_grammar = ( + ( + 'root ::= "" "\\n" functions | "message:"\n' + f"functions ::= {function_names}\n" ) - if tool is None: - raise ValueError(f"Tool with name '{tool_name}' not found in tools") + if tool_choice == "auto" + else f'root ::= "" "\\n" functions\nfunctions ::= {function_names}\n' + ) + completion = cast( + llama_types.CreateCompletionResponse, + llama.create_completion( + prompt=prompt, + **{ # type: ignore[arg-type] + **completion_kwargs, + "temperature": 0, + "stream": False, + "stop": [":"], + "max_tokens": None, + "grammar": llama_grammar.LlamaGrammar.from_string( + initial_gbnf_tool_grammar, verbose=llama.verbose + ), + }, + ), + ) + text = completion["choices"][0]["text"] + tool_name = None if text.startswith("message") else text.split("\n")[-1][len("functions.") :] + + # Case 2 step 2A: Respond with a message + if tool_name is None: prompt = template_renderer.render( - messages=messages, - tools=tools, - tool_calls=True, - add_generation_prompt=True, + messages=messages, tools=[], tool_calls=None, add_generation_prompt=True ) + return _convert_completion_to_chat( + llama.create_completion( + prompt=prompt, + **completion_kwargs, # type: ignore[arg-type] + logprobs=top_logprobs if logprobs else None, + ), + stream=stream, + ) + + # Case 2 step 2B: One or more function calls + follow_up_gbnf_tool_grammar = ( + 'root ::= functions | "" | "<|im_end|>"\n' + f"functions ::= {function_names}\n" + ) + prompt += "\n" + if stream: + return _stream_tool_calls( + llama, prompt, tools, tool_name, completion_kwargs, follow_up_gbnf_tool_grammar + ) + tool = next((tool for tool in tools if tool["function"]["name"] == tool_name), None) + completions: List[llama_types.CreateCompletionResponse] = [] + completions_tool_name: List[str] = [] + while tool is not None and len(completions) <= 16: + # Generate the parameter values for the selected tool prompt += f"functions.{tool_name}:\n" try: grammar = llama_grammar.LlamaGrammar.from_json_schema( json.dumps(tool["function"]["parameters"]), verbose=llama.verbose ) except Exception as e: + warnings.warn( + f"Failed to parse function body as JSON schema, falling back to default grammar\n\n{e}", + category=RuntimeWarning, + stacklevel=2, + ) grammar = llama_grammar.LlamaGrammar.from_string( llama_grammar.JSON_GBNF, verbose=llama.verbose ) - if 
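
To make the decision step above concrete: the first, grammar-constrained completion (stop=[":"], temperature 0) can only begin a plain reply or name a function, and the handler derives tool_name from its text exactly as in the expression above. A short sketch with invented completion texts.

    # Invented decision texts; the parsing mirrors
    # None if text.startswith("message") else text.split("\n")[-1][len("functions.") :]
    for text in ("message", "functions.get_weather", "something\nfunctions.get_weather"):
        tool_name = None if text.startswith("message") else text.split("\n")[-1][len("functions.") :]
        print(repr(text), "->", tool_name)
    # 'message' -> None
    # 'functions.get_weather' -> get_weather
    # 'something\nfunctions.get_weather' -> get_weather
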
llama.verbose: - print( - "Failed to parse function body as JSON schema, falling back to default grammar" - ) - print(e) completion_or_chunks = llama.create_completion( prompt=prompt, - temperature=temperature, - top_p=top_p, - top_k=top_k, - min_p=min_p, - typical_p=typical_p, - stream=stream, - stop=stop, - max_tokens=max_tokens, - presence_penalty=presence_penalty, - frequency_penalty=frequency_penalty, - repeat_penalty=repeat_penalty, - tfs_z=tfs_z, - mirostat_mode=mirostat_mode, - mirostat_tau=mirostat_tau, - mirostat_eta=mirostat_eta, - model=model, - logits_processor=logits_processor, - grammar=grammar, - ) - return _convert_completion_to_chat_function( - tool_name, completion_or_chunks, stream + **{ # type: ignore[arg-type] + **completion_kwargs, + "max_tokens": None, + "grammar": grammar, + }, ) - - # Case 3: Automatic tool choice - assert isinstance(tool_choice, str) and tool_choice == "auto" - function_names = " | ".join( - [f'''"functions.{tool['function']['name']}:"''' for tool in tools] - ) - initial_gbnf_tool_grammar = ( - """root ::= functions | "message:"\n""" - f"""functions ::= {function_names}\n""" - ) - follow_up_gbnf_tool_grammar = ( - """root ::= functions | "<|im_end|>"\n""" - f"""functions ::= {function_names}\n""" - ) - prompt = template_renderer.render( - messages=messages, - tools=tools, - tool_calls=True, - add_generation_prompt=True, - ) - completion_or_chunks = llama.create_completion( - prompt=prompt, - temperature=0, - top_p=top_p, - top_k=top_k, - min_p=min_p, - typical_p=typical_p, - stream=False, - stop=[":"], - max_tokens=None, - presence_penalty=presence_penalty, - frequency_penalty=frequency_penalty, - repeat_penalty=repeat_penalty, - tfs_z=tfs_z, - mirostat_mode=mirostat_mode, - mirostat_tau=mirostat_tau, - mirostat_eta=mirostat_eta, - model=model, - logits_processor=logits_processor, - grammar=llama_grammar.LlamaGrammar.from_string( - initial_gbnf_tool_grammar, verbose=llama.verbose - ), - ) - completion: llama_types.CreateCompletionResponse = completion_or_chunks # type: ignore - text = completion["choices"][0]["text"] - if "message" in text: - return _convert_completion_to_chat( + completion = cast(llama_types.CreateCompletionResponse, completion_or_chunks) + completions.append(completion) + completions_tool_name.append(tool_name) + prompt += completion["choices"][0]["text"] + prompt += "\n" + # Determine whether to call another tool or stop + response = cast( + llama_types.CreateCompletionResponse, llama.create_completion( - prompt=prompt + "message:\n", - temperature=temperature, - top_p=top_p, - top_k=top_k, - min_p=min_p, - typical_p=typical_p, - stream=stream, - stop=["<|im_end|>"], - logprobs=top_logprobs if logprobs else None, - max_tokens=None, - presence_penalty=presence_penalty, - frequency_penalty=frequency_penalty, - repeat_penalty=repeat_penalty, - tfs_z=tfs_z, - mirostat_mode=mirostat_mode, - mirostat_tau=mirostat_tau, - mirostat_eta=mirostat_eta, - model=model, - logits_processor=logits_processor, - grammar=llama_grammar.LlamaGrammar.from_string( - follow_up_gbnf_tool_grammar, verbose=llama.verbose - ), + prompt=prompt, + **{ # type: ignore[arg-type] + **completion_kwargs, + "temperature": 0, + "stream": False, + "stop": [*completion_kwargs["stop"], ":", ""], # type: ignore[misc] + "max_tokens": None, + "grammar": llama_grammar.LlamaGrammar.from_string( + follow_up_gbnf_tool_grammar, verbose=llama.verbose + ), + }, ), - stream=stream, ) - - # One or more function calls - tool_name = text[len("functions.") :] - tool = next((tool for 
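
The while loop above chains tool calls: after each grammar-constrained argument object, a greedy follow-up completion restricted by follow_up_gbnf_tool_grammar either names another function or emits a stop string, and the loop exits as soon as the resulting name no longer matches a known tool (with the len(completions) <= 16 guard as a hard backstop). A sketch of that exit condition with an invented tool registry.

    # Invented registry and follow-up texts; a non-matching name ends the chain
    tools_by_name = {
        "get_weather": {"type": "function", "function": {"name": "get_weather", "parameters": {}}},
    }

    def next_tool(follow_up_text: str):
        name = follow_up_text[len("functions.") :]
        return tools_by_name.get(name)

    assert next_tool("functions.get_weather") is not None  # chain another call
    assert next_tool("") is None                            # only a stop sequence was produced: loop exits
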
tool in tools if tool["function"]["name"] == tool_name), None) - if not stream: - completions: List[llama_types.CreateCompletionResponse] = [] - completions_tool_name: List[str] = [] - while tool is not None: - prompt += f"functions.{tool_name}:\n" - try: - grammar = llama_grammar.LlamaGrammar.from_json_schema( - json.dumps(tool["function"]["parameters"]), verbose=llama.verbose - ) - except Exception as e: - grammar = llama_grammar.LlamaGrammar.from_string( - llama_grammar.JSON_GBNF, verbose=llama.verbose - ) - if llama.verbose: - print( - "Failed to parse function body as JSON schema, falling back to default grammar" - ) - print(e) - completion_or_chunks = llama.create_completion( - prompt=prompt, - temperature=temperature, - top_p=top_p, - top_k=top_k, - min_p=min_p, - typical_p=typical_p, - stream=False, - stop=stop, - max_tokens=None, - presence_penalty=presence_penalty, - frequency_penalty=frequency_penalty, - repeat_penalty=repeat_penalty, - tfs_z=tfs_z, - mirostat_mode=mirostat_mode, - mirostat_tau=mirostat_tau, - mirostat_eta=mirostat_eta, - model=model, - logits_processor=logits_processor, - grammar=grammar, - ) - completion_or_chunks = cast( - llama_types.CreateCompletionResponse, completion_or_chunks - ) - completions.append(completion_or_chunks) - completions_tool_name.append(tool_name) - prompt += completion_or_chunks["choices"][0]["text"] - prompt += "\n" - - response = llama.create_completion( - prompt=prompt, - temperature=temperature, - top_p=top_p, - top_k=top_k, - min_p=min_p, - typical_p=typical_p, - stream=False, - stop=stop, - max_tokens=None, - presence_penalty=presence_penalty, - frequency_penalty=frequency_penalty, - repeat_penalty=repeat_penalty, - tfs_z=tfs_z, - mirostat_mode=mirostat_mode, - mirostat_tau=mirostat_tau, - mirostat_eta=mirostat_eta, - model=model, - logits_processor=logits_processor, - grammar=llama_grammar.LlamaGrammar.from_string( - follow_up_gbnf_tool_grammar, verbose=llama.verbose - ), - ) - response = cast(llama_types.CreateCompletionResponse, response) - - tool_name = response["choices"][0]["text"][len("functions.") :] - tool = next( - (tool for tool in tools if tool["function"]["name"] == tool_name), None - ) - - # Merge completions - function_call_dict: Union[ - Dict[str, str], - Dict[ - Literal["function_call"], - llama_types.ChatCompletionRequestAssistantMessageFunctionCall, - ], - ] = ( + tool_name = response["choices"][0]["text"][len("functions.") :] + tool = next((tool for tool in tools if tool["function"]["name"] == tool_name), None) + # Merge the completions into a single chat completion + chat_completion: llama_types.CreateChatCompletionResponse = { + "id": "chat" + completion["id"], + "object": "chat.completion", + "created": completion["created"], + "model": completion["model"], + "choices": [ { - "function_call": { - "name": tool_name, - "arguments": completions[0]["choices"][0]["text"], - } - } - if len(completions) == 1 - else {} - ) - return { - "id": "chat" + completion["id"], - "object": "chat.completion", - "created": completion["created"], - "model": completion["model"], - "choices": [ - { - "finish_reason": "tool_calls", - "index": 0, - "logprobs": _convert_text_completion_logprobs_to_chat(completion["choices"][0]["logprobs"]), - "message": { - "role": "assistant", - "content": None, - "tool_calls": [ - { - "id": "call_" - + f"_{i}_" - + tool_name - + "_" - + completion["id"], - "type": "function", - "function": { - "name": tool_name, - "arguments": completion["choices"][0]["text"], - }, - } - for i, (tool_name, 
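
For reference, this merge produces an assistant message in which every accumulated completion becomes one tool_calls entry, indexed in generation order, with the call id built from the index, the tool name, and the completion id. An invented example of the resulting shape for two chained calls (the single-call case additionally mirrors the first entry into the legacy function_call field, as handled at the end of this hunk).

    # Invented example of the merged assistant message for two chained get_weather calls
    merged_message = {
        "role": "assistant",
        "content": None,
        "tool_calls": [
            {
                "id": "call__0_get_weather_cmpl-abc",  # "call_" + f"_{i}_" + tool_name + "_" + completion["id"]
                "type": "function",
                "function": {"name": "get_weather", "arguments": '{"location": "Paris"}'},
            },
            {
                "id": "call__1_get_weather_cmpl-def",
                "type": "function",
                "function": {"name": "get_weather", "arguments": '{"location": "New York"}'},
            },
        ],
    }
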
completion) in enumerate( - zip(completions_tool_name, completions) - ) - ], - **function_call_dict, - }, - } - ], - "usage": { - "completion_tokens": sum( - ( - completion["usage"]["completion_tokens"] - if "usage" in completion - else 0 - ) - for completion in completions - ), - "prompt_tokens": sum( - completion["usage"]["prompt_tokens"] if "usage" in completion else 0 - for completion in completions - ), - "total_tokens": sum( - completion["usage"]["total_tokens"] if "usage" in completion else 0 - for completion in completions + "finish_reason": "tool_calls", + "index": 0, + "logprobs": _convert_text_completion_logprobs_to_chat( + completion["choices"][0]["logprobs"] ), - }, + "message": { + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "call_" + f"_{i}_" + tool_name + "_" + completion["id"], + "type": "function", + "function": { + "name": tool_name, + "arguments": completion["choices"][0]["text"], + }, + } + for i, (tool_name, completion) in enumerate( + zip(completions_tool_name, completions) + ) + ], + }, + } + ], + "usage": { + "completion_tokens": sum( + (completion["usage"]["completion_tokens"] if "usage" in completion else 0) + for completion in completions + ), + "prompt_tokens": sum( + completion["usage"]["prompt_tokens"] if "usage" in completion else 0 + for completion in completions + ), + "total_tokens": sum( + completion["usage"]["total_tokens"] if "usage" in completion else 0 + for completion in completions + ), + }, + } + if len(completions) == 1: + single_function_call: llama_types.ChatCompletionResponseFunctionCall = { + "name": tool_name, + "arguments": completions[0]["choices"][0]["text"], } - - raise ValueError("Automatic streaming tool choice is not supported") + chat_completion["choices"][0]["message"]["function_call"] = single_function_call + return chat_completion diff --git a/pyproject.toml b/pyproject.toml index 9983ef777..1f0aab57b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,7 +45,8 @@ test = [ "sse-starlette>=1.6.1", "starlette-context>=0.3.6,<0.4", "pydantic-settings>=2.0.1", - "huggingface-hub>=0.23.0" + "huggingface-hub>=0.23.0", + "typeguard>=4.2.1", ] dev = [ "black>=23.3.0", diff --git a/tests/test_llama_chat_format.py b/tests/test_llama_chat_format.py index f031bf72b..42bbac1f5 100644 --- a/tests/test_llama_chat_format.py +++ b/tests/test_llama_chat_format.py @@ -1,14 +1,29 @@ import json +import os +import platform +from collections.abc import Iterator +from typing import cast +import pytest import jinja2 +from typeguard import ForwardRefPolicy, check_type from llama_cpp import ( ChatCompletionRequestUserMessage, + Llama, + llama_chat_format, + llama_supports_gpu_offload, + llama_types ) -import llama_cpp.llama_types as llama_types -import llama_cpp.llama_chat_format as llama_chat_format - from llama_cpp.llama_chat_format import hf_tokenizer_config_to_chat_formatter +from llama_cpp.llama_types import ( + ChatCompletionRequestMessage, + ChatCompletionTool, + ChatCompletionToolChoiceOption, + CreateChatCompletionResponse, + CreateChatCompletionStreamResponse, +) + def test_mistral_instruct(): chat_template = "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and 
assistant roles are supported!') }}{% endif %}{% endfor %}" @@ -87,3 +102,118 @@ def test_hf_tokenizer_config_str_to_chat_formatter(): ) assert chat_formatter_respoonse.prompt == ("[INST] Hello, world! [/INST]" "") + + +def is_accelerator_available() -> bool: + """Check if an accelerator is available.""" + return llama_supports_gpu_offload() or (os.cpu_count() or 1) >= 8 + + +@pytest.mark.parametrize( + "stream", + [ + pytest.param(True, id="stream=True"), + pytest.param(False, id="stream=False"), + ], +) +@pytest.mark.parametrize( + "tool_choice", + [ + pytest.param("none", id="tool_choice=none"), + pytest.param("auto", id="tool_choice=auto"), + pytest.param( + {"type": "function", "function": {"name": "get_weather"}}, id="tool_choice=fixed" + ), + ], +) +@pytest.mark.parametrize( + "user_prompt_expected_tool_calls", + [ + pytest.param( + ("Is 7 a prime number?", 0), + id="expected_tool_calls=0", + ), + pytest.param( + ("What's the weather like in Paris today?", 1), + id="expected_tool_calls=1", + ), + pytest.param( + ("What's the weather like in Paris today? What about New York?", 2), + id="expected_tool_calls=2", + ), + ], +) +@pytest.mark.parametrize( + "llm_repo_id", + [ + pytest.param("bartowski/Llama-3.2-3B-Instruct-GGUF", id="llama_3.2_3B"), + pytest.param( + "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF", + id="llama_3.1_8B", + marks=pytest.mark.skipif( + not is_accelerator_available(), reason="Accelerator not available" + ), + ), + ], +) +@pytest.mark.skipif( + platform.system() == "Darwin" and (os.cpu_count() or 1) < 8, + reason="Insufficient resources on macOS", +) +def test_llama_cpp_python_tool_use( + llm_repo_id: str, + user_prompt_expected_tool_calls: tuple[str, int], + tool_choice: ChatCompletionToolChoiceOption, + stream: bool, +) -> None: + """Test the upgraded chatml-function-calling llama-cpp-python chat handler.""" + user_prompt, expected_tool_calls = user_prompt_expected_tool_calls + if isinstance(tool_choice, dict) and expected_tool_calls == 0: + pytest.skip("Nonsensical") + llm = Llama.from_pretrained( + repo_id=llm_repo_id, + filename="*Q4_K_M.gguf", + n_ctx=4096, + n_gpu_layers=-1, + verbose=False, + chat_format="chatml-function-calling", + ) + messages: list[ChatCompletionRequestMessage] = [{"role": "user", "content": user_prompt}] + tools: list[ChatCompletionTool] = [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get the weather for a location.", + "parameters": { + "type": "object", + "properties": {"location": {"type": "string", "description": "A city name."}}, + }, + }, + } + ] + response = llm.create_chat_completion( + messages=messages, tools=tools, tool_choice=tool_choice, stream=stream + ) + if stream: + response = cast(Iterator[CreateChatCompletionStreamResponse], response) + num_tool_calls = 0 + for chunk in response: + check_type(chunk, CreateChatCompletionStreamResponse) + tool_calls = chunk["choices"][0]["delta"].get("tool_calls") + if isinstance(tool_calls, list): + num_tool_calls = max(tool_call["index"] for tool_call in tool_calls) + 1 + assert num_tool_calls == (expected_tool_calls if tool_choice != "none" else 0) + else: + response = cast(CreateChatCompletionResponse, response) + check_type( + response, CreateChatCompletionResponse, forward_ref_policy=ForwardRefPolicy.IGNORE + ) + if expected_tool_calls == 0 or tool_choice == "none": + assert response["choices"][0]["message"].get("tool_calls") is None + else: + assert len(response["choices"][0]["message"]["tool_calls"]) == expected_tool_calls + assert all( + 
tool_call["function"]["name"] == tools[0]["function"]["name"] + for tool_call in response["choices"][0]["message"]["tool_calls"] + )