models - correct tool result content (#154) · jer96/sdk-python@76cd7ba · GitHub

Commit 76cd7ba

models - correct tool result content (strands-agents#154)
1 parent af25f98 commit 76cd7ba

File tree

9 files changed, +194 -254 lines changed

src/strands/models/anthropic.py

Lines changed: 6 additions & 1 deletion
@@ -4,6 +4,7 @@
 """
 
 import base64
+import json
 import logging
 import mimetypes
 from typing import Any, Iterable, Optional, TypedDict, cast
@@ -145,7 +146,11 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An
         if "toolResult" in content:
             return {
                 "content": [
-                    self._format_request_message_content(cast(ContentBlock, tool_result_content))
+                    self._format_request_message_content(
+                        {"text": json.dumps(tool_result_content["json"])}
+                        if "json" in tool_result_content
+                        else cast(ContentBlock, tool_result_content)
+                    )
                     for tool_result_content in content["toolResult"]["content"]
                 ],
                 "is_error": content["toolResult"]["status"] == "error",

src/strands/models/litellm.py

Lines changed: 3 additions & 3 deletions
@@ -67,8 +67,8 @@ def get_config(self) -> LiteLLMConfig:
         return cast(LiteLLMModel.LiteLLMConfig, self.config)
 
     @override
-    @staticmethod
-    def format_request_message_content(content: ContentBlock) -> dict[str, Any]:
+    @classmethod
+    def format_request_message_content(cls, content: ContentBlock) -> dict[str, Any]:
         """Format a LiteLLM content block.
 
         Args:
@@ -96,4 +96,4 @@ def format_request_message_content(content: ContentBlock) -> dict[str, Any]:
             },
         }
 
-        return OpenAIModel.format_request_message_content(content)
+        return super().format_request_message_content(content)
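
The decorator swap is the substance of this change: the old code hardcoded OpenAIModel.format_request_message_content, which bypasses any override lower in the MRO, while super() inside a @classmethod dispatches through whatever class the call actually starts from. A self-contained illustration with hypothetical class names:

from typing import Any


class Base:
    @classmethod
    def format_content(cls, content: str) -> dict[str, Any]:
        return {"type": "text", "text": content}


class Specialized(Base):
    @classmethod
    def format_content(cls, content: str) -> dict[str, Any]:
        if content.startswith("reason:"):
            return {"type": "reasoning", "text": content}
        # Zero-argument super() is valid in a classmethod and follows cls's
        # MRO; a hardcoded Base.format_content(content) would skip overrides.
        return super().format_content(content)


print(Specialized.format_content("hello"))      # {'type': 'text', 'text': 'hello'}
print(Specialized.format_content("reason: x"))  # {'type': 'reasoning', 'text': 'reason: x'}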

src/strands/models/llamaapi.py

Lines changed: 19 additions & 7 deletions
@@ -8,7 +8,7 @@
 import json
 import logging
 import mimetypes
-from typing import Any, Iterable, Optional
+from typing import Any, Iterable, Optional, cast
 
 import llama_api_client
 from llama_api_client import LlamaAPIClient
@@ -139,18 +139,30 @@ def _format_request_tool_message(self, tool_result: ToolResult) -> dict[str, Any
         Returns:
             Llama API formatted tool message.
         """
+        contents = cast(
+            list[ContentBlock],
+            [
+                {"text": json.dumps(content["json"])} if "json" in content else content
+                for content in tool_result["content"]
+            ],
+        )
+
         return {
             "role": "tool",
             "tool_call_id": tool_result["toolUseId"],
-            "content": json.dumps(
-                {
-                    "content": tool_result["content"],
-                    "status": tool_result["status"],
-                }
-            ),
+            "content": [self._format_request_message_content(content) for content in contents],
         }
 
     def _format_request_messages(self, messages: Messages, system_prompt: Optional[str] = None) -> list[dict[str, Any]]:
+        """Format a LlamaAPI compatible messages array.
+
+        Args:
+            messages: List of message objects to be processed by the model.
+            system_prompt: System prompt to provide context to the model.
+
+        Returns:
+            An LlamaAPI compatible messages array.
+        """
         formatted_messages: list[dict[str, Any]]
         formatted_messages = [{"role": "system", "content": system_prompt}] if system_prompt else []
 
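The net effect: instead of collapsing the whole tool result (content plus status) into one JSON string the model has to re-parse, each block is normalized (json to text) and then formatted as a real content block. An illustrative before/after; the {"type": "text", ...} output shape is an assumption about what _format_request_message_content produces, not taken from this diff:

import json
from typing import Any

tool_result: dict[str, Any] = {
    "toolUseId": "tooluse_example",  # made-up id
    "status": "success",
    "content": [{"json": {"answer": 42}}, {"text": "done"}],
}

# Old behavior: one opaque string.
old_content = json.dumps({"content": tool_result["content"], "status": tool_result["status"]})

# New behavior: a list of content blocks, with json normalized to text first.
# The block shape below is assumed for illustration.
new_content = [
    {"type": "text", "text": json.dumps(block["json"]) if "json" in block else block["text"]}
    for block in tool_result["content"]
]

print(old_content)  # {"content": [{"json": {"answer": 42}}, {"text": "done"}], "status": "success"}
print(new_content)  # [{'type': 'text', 'text': '{"answer": 42}'}, {'type': 'text', 'text': 'done'}]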

src/strands/models/ollama.py

Lines changed: 114 additions & 99 deletions
@@ -5,13 +5,12 @@
 
 import json
 import logging
-from typing import Any, Iterable, Optional, Union
+from typing import Any, Iterable, Optional, cast
 
 from ollama import Client as OllamaClient
 from typing_extensions import TypedDict, Unpack, override
 
-from ..types.content import ContentBlock, Message, Messages
-from ..types.media import DocumentContent, ImageContent
+from ..types.content import ContentBlock, Messages
 from ..types.models import Model
 from ..types.streaming import StopReason, StreamEvent
 from ..types.tools import ToolSpec
@@ -92,35 +91,31 @@ def get_config(self) -> OllamaConfig:
         """
         return self.config
 
-    @override
-    def format_request(
-        self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None
-    ) -> dict[str, Any]:
-        """Format an Ollama chat streaming request.
+    def _format_request_message_contents(self, role: str, content: ContentBlock) -> list[dict[str, Any]]:
+        """Format Ollama compatible message contents.
+
+        Ollama doesn't support an array of contents, so we must flatten everything into separate message blocks.
 
         Args:
-            messages: List of message objects to be processed by the model.
-            tool_specs: List of tool specifications to make available to the model.
-            system_prompt: System prompt to provide context to the model.
+            role: E.g., user.
+            content: Content block to format.
 
         Returns:
-            An Ollama chat streaming request.
+            Ollama formatted message contents.
 
         Raises:
-            TypeError: If a message contains a content block type that cannot be converted to an Ollama-compatible
-                format.
+            TypeError: If the content block type cannot be converted to an Ollama-compatible format.
         """
+        if "text" in content:
+            return [{"role": role, "content": content["text"]}]
 
-        def format_message(message: Message, content: ContentBlock) -> dict[str, Any]:
-            if "text" in content:
-                return {"role": message["role"], "content": content["text"]}
+        if "image" in content:
+            return [{"role": role, "images": [content["image"]["source"]["bytes"]]}]
 
-            if "image" in content:
-                return {"role": message["role"], "images": [content["image"]["source"]["bytes"]]}
-
-            if "toolUse" in content:
-                return {
-                    "role": "assistant",
+        if "toolUse" in content:
+            return [
+                {
+                    "role": role,
                     "tool_calls": [
                         {
                             "function": {
@@ -130,45 +125,63 @@ def format_message(message: Message, content: ContentBlock) -> dict[str, Any]:
                             }
                         }
                     ],
                 }
+            ]
+
+        if "toolResult" in content:
+            return [
+                formatted_tool_result_content
+                for tool_result_content in content["toolResult"]["content"]
+                for formatted_tool_result_content in self._format_request_message_contents(
+                    "tool",
+                    (
+                        {"text": json.dumps(tool_result_content["json"])}
+                        if "json" in tool_result_content
+                        else cast(ContentBlock, tool_result_content)
+                    ),
+                )
+            ]
 
-            if "toolResult" in content:
-                result_content: Union[str, ImageContent, DocumentContent, Any] = None
-                result_images = []
-                for tool_result_content in content["toolResult"]["content"]:
-                    if "text" in tool_result_content:
-                        result_content = tool_result_content["text"]
-                    elif "json" in tool_result_content:
-                        result_content = tool_result_content["json"]
-                    elif "image" in tool_result_content:
-                        result_content = "see images"
-                        result_images.append(tool_result_content["image"]["source"]["bytes"])
-                    else:
-                        result_content = content["toolResult"]["content"]
+        raise TypeError(f"content_type=<{next(iter(content))}> | unsupported type")
 
-                return {
-                    "role": "tool",
-                    "content": json.dumps(
-                        {
-                            "name": content["toolResult"]["toolUseId"],
-                            "result": result_content,
-                            "status": content["toolResult"]["status"],
-                        }
-                    ),
-                    **({"images": result_images} if result_images else {}),
-                }
+    def _format_request_messages(self, messages: Messages, system_prompt: Optional[str] = None) -> list[dict[str, Any]]:
+        """Format an Ollama compatible messages array.
 
-            raise TypeError(f"content_type=<{next(iter(content))}> | unsupported type")
+        Args:
+            messages: List of message objects to be processed by the model.
+            system_prompt: System prompt to provide context to the model.
 
-        def format_messages() -> list[dict[str, Any]]:
-            return [format_message(message, content) for message in messages for content in message["content"]]
+        Returns:
+            An Ollama compatible messages array.
+        """
+        system_message = [{"role": "system", "content": system_prompt}] if system_prompt else []
 
-        formatted_messages = format_messages()
+        return system_message + [
+            formatted_message
+            for message in messages
+            for content in message["content"]
+            for formatted_message in self._format_request_message_contents(message["role"], content)
+        ]
 
+    @override
+    def format_request(
+        self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None
+    ) -> dict[str, Any]:
+        """Format an Ollama chat streaming request.
+
+        Args:
+            messages: List of message objects to be processed by the model.
+            tool_specs: List of tool specifications to make available to the model.
+            system_prompt: System prompt to provide context to the model.
+
+        Returns:
+            An Ollama chat streaming request.
+
+        Raises:
+            TypeError: If a message contains a content block type that cannot be converted to an Ollama-compatible
+                format.
+        """
         return {
-            "messages": [
-                *([{"role": "system", "content": system_prompt}] if system_prompt else []),
-                *formatted_messages,
-            ],
+            "messages": self._format_request_messages(messages, system_prompt),
             "model": self.config["model_id"],
             "options": {
                 **(self.config.get("options") or {}),
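
The helper's docstring states the constraint driving this refactor: Ollama accepts one content string (or images list) per message rather than an array of mixed blocks, so a single SDK message fans out into several Ollama messages, and toolResult blocks recurse with role "tool". A trimmed, runnable sketch of the fan-out (text and image only; the data is made up):

from typing import Any


def flatten(role: str, content: dict[str, Any]) -> list[dict[str, Any]]:
    # Simplified stand-in for _format_request_message_contents.
    if "text" in content:
        return [{"role": role, "content": content["text"]}]
    if "image" in content:
        return [{"role": role, "images": [content["image"]["source"]["bytes"]]}]
    raise TypeError(f"content_type=<{next(iter(content))}> | unsupported type")


message: dict[str, Any] = {
    "role": "user",
    "content": [{"text": "describe this image"}, {"image": {"source": {"bytes": b"..."}}}],
}

# One SDK message becomes two Ollama messages, as in _format_request_messages.
print([out for block in message["content"] for out in flatten(message["role"], block)])
# [{'role': 'user', 'content': 'describe this image'}, {'role': 'user', 'images': [b'...']}]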
@@ -217,52 +230,54 @@ def format_chunk(self, event: dict[str, Any]) -> StreamEvent:
             RuntimeError: If chunk_type is not recognized.
                 This error should never be encountered as we control chunk_type in the stream method.
         """
-        if event["chunk_type"] == "message_start":
-            return {"messageStart": {"role": "assistant"}}
-
-        if event["chunk_type"] == "content_start":
-            if event["data_type"] == "text":
-                return {"contentBlockStart": {"start": {}}}
-
-            tool_name = event["data"].function.name
-            return {"contentBlockStart": {"start": {"toolUse": {"name": tool_name, "toolUseId": tool_name}}}}
-
-        if event["chunk_type"] == "content_delta":
-            if event["data_type"] == "text":
-                return {"contentBlockDelta": {"delta": {"text": event["data"]}}}
-
-            tool_arguments = event["data"].function.arguments
-            return {"contentBlockDelta": {"delta": {"toolUse": {"input": json.dumps(tool_arguments)}}}}
-
-        if event["chunk_type"] == "content_stop":
-            return {"contentBlockStop": {}}
-
-        if event["chunk_type"] == "message_stop":
-            reason: StopReason
-            if event["data"] == "tool_use":
-                reason = "tool_use"
-            elif event["data"] == "length":
-                reason = "max_tokens"
-            else:
-                reason = "end_turn"
-
-            return {"messageStop": {"stopReason": reason}}
-
-        if event["chunk_type"] == "metadata":
-            return {
-                "metadata": {
-                    "usage": {
-                        "inputTokens": event["data"].eval_count,
-                        "outputTokens": event["data"].prompt_eval_count,
-                        "totalTokens": event["data"].eval_count + event["data"].prompt_eval_count,
-                    },
-                    "metrics": {
-                        "latencyMs": event["data"].total_duration / 1e6,
+        match event["chunk_type"]:
+            case "message_start":
+                return {"messageStart": {"role": "assistant"}}
+
+            case "content_start":
+                if event["data_type"] == "text":
+                    return {"contentBlockStart": {"start": {}}}
+
+                tool_name = event["data"].function.name
+                return {"contentBlockStart": {"start": {"toolUse": {"name": tool_name, "toolUseId": tool_name}}}}
+
+            case "content_delta":
+                if event["data_type"] == "text":
+                    return {"contentBlockDelta": {"delta": {"text": event["data"]}}}
+
+                tool_arguments = event["data"].function.arguments
+                return {"contentBlockDelta": {"delta": {"toolUse": {"input": json.dumps(tool_arguments)}}}}
+
+            case "content_stop":
+                return {"contentBlockStop": {}}
+
+            case "message_stop":
+                reason: StopReason
+                if event["data"] == "tool_use":
+                    reason = "tool_use"
+                elif event["data"] == "length":
+                    reason = "max_tokens"
+                else:
+                    reason = "end_turn"
+
+                return {"messageStop": {"stopReason": reason}}
+
+            case "metadata":
+                return {
+                    "metadata": {
+                        "usage": {
+                            "inputTokens": event["data"].eval_count,
+                            "outputTokens": event["data"].prompt_eval_count,
+                            "totalTokens": event["data"].eval_count + event["data"].prompt_eval_count,
+                        },
+                        "metrics": {
+                            "latencyMs": event["data"].total_duration / 1e6,
+                        },
                     },
-                },
-            }
+                }
 
-        raise RuntimeError(f"chunk_type=<{event['chunk_type']} | unknown type")
+            case _:
+                raise RuntimeError(f"chunk_type=<{event['chunk_type']} | unknown type")
 
     @override
     def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]:
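
The format_chunk rewrite is behavior-preserving: the if/elif chain becomes a match statement (which requires Python 3.10+), and the trailing raise moves into the wildcard case so unknown chunk types still fail loudly. The stop-reason mapping inside "message_stop" is easy to verify in isolation; extracted as a standalone function:

def map_stop_reason(data: str) -> str:
    # Same mapping as the "message_stop" case above: Ollama's "length"
    # becomes the SDK's "max_tokens"; anything unrecognized is "end_turn".
    match data:
        case "tool_use":
            return "tool_use"
        case "length":
            return "max_tokens"
        case _:
            return "end_turn"


assert map_stop_reason("length") == "max_tokens"
assert map_stop_reason("tool_use") == "tool_use"
assert map_stop_reason("stop") == "end_turn"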
