feature: models - openai (#65) · strands-agents/sdk-python@6dda2d8 · GitHub

Commit 6dda2d8

feature: models - openai (#65)
1 parent 5f4b68a commit 6dda2d8

File tree

13 files changed: +1024, -772 lines changed


pyproject.toml

Lines changed: 8 additions & 5 deletions
@@ -54,7 +54,7 @@ dev = [
     "commitizen>=4.4.0,<5.0.0",
     "hatch>=1.0.0,<2.0.0",
     "moto>=5.1.0,<6.0.0",
-    "mypy>=0.981,<1.0.0",
+    "mypy>=1.15.0,<2.0.0",
     "pre-commit>=3.2.0,<4.2.0",
     "pytest>=8.0.0,<9.0.0",
     "pytest-asyncio>=0.26.0,<0.27.0",
@@ -69,15 +69,18 @@ docs = [
 litellm = [
     "litellm>=1.69.0,<2.0.0",
 ]
+llamaapi = [
+    "llama-api-client>=0.1.0,<1.0.0",
+]
 ollama = [
     "ollama>=0.4.8,<1.0.0",
 ]
-llamaapi = [
-    "llama-api-client>=0.1.0,<1.0.0",
+openai = [
+    "openai>=1.68.0,<2.0.0",
 ]
 
 [tool.hatch.envs.hatch-static-analysis]
-features = ["anthropic", "litellm", "llamaapi", "ollama"]
+features = ["anthropic", "litellm", "llamaapi", "ollama", "openai"]
 dependencies = [
     "mypy>=1.15.0,<2.0.0",
     "ruff>=0.11.6,<0.12.0",
@@ -100,7 +103,7 @@ lint-fix = [
 ]
 
 [tool.hatch.envs.hatch-test]
-features = ["anthropic", "litellm", "llamaapi", "ollama"]
+features = ["anthropic", "litellm", "llamaapi", "ollama", "openai"]
 extra-dependencies = [
     "moto>=5.1.0,<6.0.0",
     "pytest>=8.0.0,<9.0.0",

src/strands/models/litellm.py

Lines changed: 10 additions & 255 deletions
@@ -3,23 +3,19 @@
 - Docs: https://docs.litellm.ai/
 """
 
-import json
 import logging
-import mimetypes
-from typing import Any, Iterable, Optional, TypedDict
+from typing import Any, Optional, TypedDict, cast
 
 import litellm
 from typing_extensions import Unpack, override
 
-from ..types.content import ContentBlock, Messages
-from ..types.models import Model
-from ..types.streaming import StreamEvent
-from ..types.tools import ToolResult, ToolSpec, ToolUse
+from ..types.content import ContentBlock
+from .openai import OpenAIModel
 
 logger = logging.getLogger(__name__)
 
 
-class LiteLLMModel(Model):
+class LiteLLMModel(OpenAIModel):
     """LiteLLM model provider implementation."""
 
     class LiteLLMConfig(TypedDict, total=False):
@@ -45,7 +41,7 @@ def __init__(self, client_args: Optional[dict[str, Any]] = None, **model_config:
                 https://github.com/BerriAI/litellm/blob/main/litellm/main.py.
             **model_config: Configuration options for the LiteLLM model.
         """
-        self.config = LiteLLMModel.LiteLLMConfig(**model_config)
+        self.config = dict(model_config)
 
         logger.debug("config=<%s> | initializing", self.config)
 
@@ -68,9 +64,11 @@ def get_config(self) -> LiteLLMConfig:
         Returns:
             The LiteLLM model configuration.
         """
-        return self.config
+        return cast(LiteLLMModel.LiteLLMConfig, self.config)
 
-    def _format_request_message_content(self, content: ContentBlock) -> dict[str, Any]:
+    @override
+    @staticmethod
+    def format_request_message_content(content: ContentBlock) -> dict[str, Any]:
         """Format a LiteLLM content block.
 
         Args:
@@ -79,28 +77,13 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An
         Returns:
             LiteLLM formatted content block.
         """
-        if "image" in content:
-            mime_type = mimetypes.types_map.get(f".{content['image']['format']}", "application/octet-stream")
-            image_data = content["image"]["source"]["bytes"].decode("utf-8")
-            return {
-                "image_url": {
-                    "detail": "auto",
-                    "format": mime_type,
-                    "url": f"data:{mime_type};base64,{image_data}",
                },
-                "type": "image_url",
-            }
-
         if "reasoningContent" in content:
             return {
                 "signature": content["reasoningContent"]["reasoningText"]["signature"],
                 "thinking": content["reasoningContent"]["reasoningText"]["text"],
                 "type": "thinking",
             }
 
-        if "text" in content:
-            return {"text": content["text"], "type": "text"}
-
         if "video" in content:
             return {
                 "type": "video_url",
@@ -110,232 +93,4 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An
                 },
             }
 
-        return {"text": json.dumps(content), "type": "text"}
-
-    def _format_request_message_tool_call(self, tool_use: ToolUse) -> dict[str, Any]:
-        """Format a LiteLLM tool call.
-
-        Args:
-            tool_use: Tool use requested by the model.
-
-        Returns:
-            LiteLLM formatted tool call.
-        """
-        return {
-            "function": {
-                "arguments": json.dumps(tool_use["input"]),
-                "name": tool_use["name"],
-            },
-            "id": tool_use["toolUseId"],
-            "type": "function",
-        }
-
-    def _format_request_tool_message(self, tool_result: ToolResult) -> dict[str, Any]:
-        """Format a LiteLLM tool message.
-
-        Args:
-            tool_result: Tool result collected from a tool execution.
-
-        Returns:
-            LiteLLM formatted tool message.
-        """
-        return {
-            "role": "tool",
-            "tool_call_id": tool_result["toolUseId"],
-            "content": json.dumps(
-                {
-                    "content": tool_result["content"],
-                    "status": tool_result["status"],
-                }
-            ),
-        }
-
-    def _format_request_messages(self, messages: Messages, system_prompt: Optional[str] = None) -> list[dict[str, Any]]:
-        """Format a LiteLLM messages array.
-
-        Args:
-            messages: List of message objects to be processed by the model.
-            system_prompt: System prompt to provide context to the model.
-
-        Returns:
-            A LiteLLM messages array.
-        """
-        formatted_messages: list[dict[str, Any]]
-        formatted_messages = [{"role": "system", "content": system_prompt}] if system_prompt else []
-
-        for message in messages:
-            contents = message["content"]
-
-            formatted_contents = [
-                self._format_request_message_content(content)
-                for content in contents
-                if not any(block_type in content for block_type in ["toolResult", "toolUse"])
-            ]
-            formatted_tool_calls = [
-                self._format_request_message_tool_call(content["toolUse"])
-                for content in contents
-                if "toolUse" in content
-            ]
-            formatted_tool_messages = [
-                self._format_request_tool_message(content["toolResult"])
-                for content in contents
-                if "toolResult" in content
-            ]
-
-            formatted_message = {
-                "role": message["role"],
-                "content": formatted_contents,
-                **({"tool_calls": formatted_tool_calls} if formatted_tool_calls else {}),
-            }
-            formatted_messages.append(formatted_message)
-            formatted_messages.extend(formatted_tool_messages)
-
-        return [message for message in formatted_messages if message["content"] or "tool_calls" in message]
-
-    @override
-    def format_request(
-        self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None
-    ) -> dict[str, Any]:
-        """Format a LiteLLM chat streaming request.
-
-        Args:
-            messages: List of message objects to be processed by the model.
-            tool_specs: List of tool specifications to make available to the model.
-            system_prompt: System prompt to provide context to the model.
-
-        Returns:
-            A LiteLLM chat streaming request.
-        """
-        return {
-            "messages": self._format_request_messages(messages, system_prompt),
-            "model": self.config["model_id"],
-            "stream": True,
-            "stream_options": {"include_usage": True},
-            "tools": [
-                {
-                    "type": "function",
-                    "function": {
-                        "name": tool_spec["name"],
-                        "description": tool_spec["description"],
-                        "parameters": tool_spec["inputSchema"]["json"],
-                    },
-                }
-                for tool_spec in tool_specs or []
-            ],
-            **(self.config.get("params") or {}),
-        }
-
-    @override
-    def format_chunk(self, event: dict[str, Any]) -> StreamEvent:
-        """Format the LiteLLM response events into standardized message chunks.
-
-        Args:
-            event: A response event from the LiteLLM model.
-
-        Returns:
-            The formatted chunk.
-
-        Raises:
-            RuntimeError: If chunk_type is not recognized.
-                This error should never be encountered as we control chunk_type in the stream method.
-        """
-        match event["chunk_type"]:
-            case "message_start":
-                return {"messageStart": {"role": "assistant"}}
-
-            case "content_start":
-                if event["data_type"] == "tool":
-                    return {
-                        "contentBlockStart": {
-                            "start": {
-                                "toolUse": {
-                                    "name": event["data"].function.name,
-                                    "toolUseId": event["data"].id,
-                                }
-                            }
-                        }
-                    }
-
-                return {"contentBlockStart": {"start": {}}}
-
-            case "content_delta":
-                if event["data_type"] == "tool":
-                    return {"contentBlockDelta": {"delta": {"toolUse": {"input": event["data"].function.arguments}}}}
-
-                return {"contentBlockDelta": {"delta": {"text": event["data"]}}}
-
-            case "content_stop":
-                return {"contentBlockStop": {}}
-
-            case "message_stop":
-                match event["data"]:
-                    case "tool_calls":
-                        return {"messageStop": {"stopReason": "tool_use"}}
-                    case "length":
-                        return {"messageStop": {"stopReason": "max_tokens"}}
-                    case _:
-                        return {"messageStop": {"stopReason": "end_turn"}}
-
-            case "metadata":
-                return {
-                    "metadata": {
-                        "usage": {
-                            "inputTokens": event["data"].prompt_tokens,
-                            "outputTokens": event["data"].completion_tokens,
-                            "totalTokens": event["data"].total_tokens,
-                        },
-                        "metrics": {
-                            "latencyMs": 0,  # TODO
-                        },
-                    },
-                }
-
-            case _:
-                raise RuntimeError(f"chunk_type=<{event['chunk_type']} | unknown type")
-
-    @override
-    def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]:
-        """Send the request to the LiteLLM model and get the streaming response.
-
-        Args:
-            request: The formatted request to send to the LiteLLM model.
-
-        Returns:
-            An iterable of response events from the LiteLLM model.
-        """
-        response = self.client.chat.completions.create(**request)
-
-        yield {"chunk_type": "message_start"}
-        yield {"chunk_type": "content_start", "data_type": "text"}
-
-        tool_calls: dict[int, list[Any]] = {}
-
-        for event in response:
-            choice = event.choices[0]
-            if choice.finish_reason:
-                break
-
-            if choice.delta.content:
-                yield {"chunk_type": "content_delta", "data_type": "text", "data": choice.delta.content}
-
-            for tool_call in choice.delta.tool_calls or []:
-                tool_calls.setdefault(tool_call.index, []).append(tool_call)
-
-        yield {"chunk_type": "content_stop", "data_type": "text"}
-
-        for tool_deltas in tool_calls.values():
-            tool_start, tool_deltas = tool_deltas[0], tool_deltas[1:]
-            yield {"chunk_type": "content_start", "data_type": "tool", "data": tool_start}
-
-            for tool_delta in tool_deltas:
-                yield {"chunk_type": "content_delta", "data_type": "tool", "data": tool_delta}
-
-            yield {"chunk_type": "content_stop", "data_type": "tool"}
-
-        yield {"chunk_type": "message_stop", "data": choice.finish_reason}
-
-        # Skip remaining events as we don't have use for anything except the final usage payload
-        for event in response:
-            _ = event
-
-        yield {"chunk_type": "metadata", "data": event.usage}
+        return OpenAIModel.format_request_message_content(content)
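The net effect of this file's change is that LiteLLMModel now inherits from OpenAIModel and keeps only the content types LiteLLM formats differently (reasoningContent and video), delegating text and image blocks, tool calls, request construction, chunk formatting, and streaming to the shared OpenAI implementation. Below is a small illustrative sketch of that fallthrough, assuming the litellm and openai extras are installed; the expected output shape is inferred from the removed LiteLLM code rather than stated in this diff:

# Illustrative sketch (not part of the diff). format_request_message_content is
# now a @staticmethod, so it can be called without instantiating the model.
from strands.models.litellm import LiteLLMModel

# A plain text block has no LiteLLM-specific handling left in the subclass, so
# it falls through to OpenAIModel.format_request_message_content:
formatted = LiteLLMModel.format_request_message_content({"text": "hello"})
print(formatted)  # expected: a {"type": "text", "text": "hello"}-style dict from the OpenAI path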

0 commit comments
