feature: models - openai (#65) · B-00/sdk-python@6dda2d8 · GitHub
Commit 6dda2d8

feature: models - openai (strands-agents#65)
1 parent 5f4b68a commit 6dda2d8

File tree

13 files changed, +1024 -772 lines changed

pyproject.toml

Lines changed: 8 additions & 5 deletions
@@ -54,7 +54,7 @@ dev = [
     "commitizen>=4.4.0,<5.0.0",
     "hatch>=1.0.0,<2.0.0",
     "moto>=5.1.0,<6.0.0",
-    "mypy>=0.981,<1.0.0",
+    "mypy>=1.15.0,<2.0.0",
     "pre-commit>=3.2.0,<4.2.0",
     "pytest>=8.0.0,<9.0.0",
     "pytest-asyncio>=0.26.0,<0.27.0",
@@ -69,15 +69,18 @@ docs = [
 litellm = [
     "litellm>=1.69.0,<2.0.0",
 ]
+llamaapi = [
+    "llama-api-client>=0.1.0,<1.0.0",
+]
 ollama = [
     "ollama>=0.4.8,<1.0.0",
 ]
-llamaapi = [
-    "llama-api-client>=0.1.0,<1.0.0",
+openai = [
+    "openai>=1.68.0,<2.0.0",
 ]

 [tool.hatch.envs.hatch-static-analysis]
-features = ["anthropic", "litellm", "llamaapi", "ollama"]
+features = ["anthropic", "litellm", "llamaapi", "ollama", "openai"]
 dependencies = [
     "mypy>=1.15.0,<2.0.0",
     "ruff>=0.11.6,<0.12.0",
@@ -100,7 +103,7 @@ lint-fix = [
 ]

 [tool.hatch.envs.hatch-test]
-features = ["anthropic", "litellm", "llamaapi", "ollama"]
+features = ["anthropic", "litellm", "llamaapi", "ollama", "openai"]
 extra-dependencies = [
     "moto>=5.1.0,<6.0.0",
     "pytest>=8.0.0,<9.0.0",

src/strands/models/litellm.py

Lines changed: 10 additions & 255 deletions
@@ -3,23 +3,19 @@
 - Docs: https://docs.litellm.ai/
 """

-import json
 import logging
-import mimetypes
-from typing import Any, Iterable, Optional, TypedDict
+from typing import Any, Optional, TypedDict, cast

 import litellm
 from typing_extensions import Unpack, override

-from ..types.content import ContentBlock, Messages
-from ..types.models import Model
-from ..types.streaming import StreamEvent
-from ..types.tools import ToolResult, ToolSpec, ToolUse
+from ..types.content import ContentBlock
+from .openai import OpenAIModel

 logger = logging.getLogger(__name__)


-class LiteLLMModel(Model):
+class LiteLLMModel(OpenAIModel):
     """LiteLLM model provider implementation."""

     class LiteLLMConfig(TypedDict, total=False):
@@ -45,7 +41,7 @@ def __init__(self, client_args: Optional[dict[str, Any]] = None, **model_config:
             https://github.com/BerriAI/litellm/blob/main/litellm/main.py.
         **model_config: Configuration options for the LiteLLM model.
         """
-        self.config = LiteLLMModel.LiteLLMConfig(**model_config)
+        self.config = dict(model_config)

         logger.debug("config=<%s> | initializing", self.config)

@@ -68,9 +64,11 @@ def get_config(self) -> LiteLLMConfig:
         Returns:
             The LiteLLM model configuration.
         """
-        return self.config
+        return cast(LiteLLMModel.LiteLLMConfig, self.config)

-    def _format_request_message_content(self, content: ContentBlock) -> dict[str, Any]:
+    @override
+    @staticmethod
+    def format_request_message_content(content: ContentBlock) -> dict[str, Any]:
         """Format a LiteLLM content block.

         Args:
@@ -79,28 +77,13 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An
         Returns:
             LiteLLM formatted content block.
         """
-        if "image" in content:
-            mime_type = mimetypes.types_map.get(f".{content['image']['format']}", "application/octet-stream")
-            image_data = content["image"]["source"]["bytes"].decode("utf-8")
-            return {
-                "image_url": {
-                    "detail": "auto",
-                    "format": mime_type,
-                    "url": f"data:{mime_type};base64,{image_data}",
-                },
-                "type": "image_url",
-            }
-
         if "reasoningContent" in content:
             return {
                 "signature": content["reasoningContent"]["reasoningText"]["signature"],
                 "thinking": content["reasoningContent"]["reasoningText"]["text"],
                 "type": "thinking",
             }

-        if "text" in content:
-            return {"text": content["text"], "type": "text"}
-
         if "video" in content:
             return {
                 "type": "video_url",
@@ -110,232 +93,4 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An
                 },
             }

-        return {"text": json.dumps(content), "type": "text"}
-
-    def _format_request_message_tool_call(self, tool_use: ToolUse) -> dict[str, Any]:
-        """Format a LiteLLM tool call.
-
-        Args:
-            tool_use: Tool use requested by the model.
-
-        Returns:
-            LiteLLM formatted tool call.
-        """
-        return {
-            "function": {
-                "arguments": json.dumps(tool_use["input"]),
-                "name": tool_use["name"],
-            },
-            "id": tool_use["toolUseId"],
-            "type": "function",
-        }
-
-    def _format_request_tool_message(self, tool_result: ToolResult) -> dict[str, Any]:
-        """Format a LiteLLM tool message.
-
-        Args:
-            tool_result: Tool result collected from a tool execution.
-
-        Returns:
-            LiteLLM formatted tool message.
-        """
-        return {
-            "role": "tool",
-            "tool_call_id": tool_result["toolUseId"],
-            "content": json.dumps(
-                {
-                    "content": tool_result["content"],
-                    "status": tool_result["status"],
-                }
-            ),
-        }
-
-    def _format_request_messages(self, messages: Messages, system_prompt: Optional[str] = None) -> list[dict[str, Any]]:
-        """Format a LiteLLM messages array.
-
-        Args:
-            messages: List of message objects to be processed by the model.
-            system_prompt: System prompt to provide context to the model.
-
-        Returns:
-            A LiteLLM messages array.
-        """
-        formatted_messages: list[dict[str, Any]]
-        formatted_messages = [{"role": "system", "content": system_prompt}] if system_prompt else []
-
-        for message in messages:
-            contents = message["content"]
-
-            formatted_contents = [
-                self._format_request_message_content(content)
-                for content in contents
-                if not any(block_type in content for block_type in ["toolResult", "toolUse"])
-            ]
-            formatted_tool_calls = [
-                self._format_request_message_tool_call(content["toolUse"])
-                for content in contents
-                if "toolUse" in content
-            ]
-            formatted_tool_messages = [
-                self._format_request_tool_message(content["toolResult"])
-                for content in contents
-                if "toolResult" in content
-            ]
-
-            formatted_message = {
-                "role": message["role"],
-                "content": formatted_contents,
-                **({"tool_calls": formatted_tool_calls} if formatted_tool_calls else {}),
-            }
-            formatted_messages.append(formatted_message)
-            formatted_messages.extend(formatted_tool_messages)
-
-        return [message for message in formatted_messages if message["content"] or "tool_calls" in message]
-
-    @override
-    def format_request(
-        self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None
-    ) -> dict[str, Any]:
-        """Format a LiteLLM chat streaming request.
-
-        Args:
-            messages: List of message objects to be processed by the model.
-            tool_specs: List of tool specifications to make available to the model.
-            system_prompt: System prompt to provide context to the model.
-
-        Returns:
-            A LiteLLM chat streaming request.
-        """
-        return {
-            "messages": self._format_request_messages(messages, system_prompt),
-            "model": self.config["model_id"],
-            "stream": True,
-            "stream_options": {"include_usage": True},
-            "tools": [
-                {
-                    "type": "function",
-                    "function": {
-                        "name": tool_spec["name"],
-                        "description": tool_spec["description"],
-                        "parameters": tool_spec["inputSchema"]["json"],
-                    },
-                }
-                for tool_spec in tool_specs or []
-            ],
-            **(self.config.get("params") or {}),
-        }
-
-    @override
-    def format_chunk(self, event: dict[str, Any]) -> StreamEvent:
-        """Format the LiteLLM response events into standardized message chunks.
-
-        Args:
-            event: A response event from the LiteLLM model.
-
-        Returns:
-            The formatted chunk.
-
-        Raises:
-            RuntimeError: If chunk_type is not recognized.
-                This error should never be encountered as we control chunk_type in the stream method.
-        """
-        match event["chunk_type"]:
-            case "message_start":
-                return {"messageStart": {"role": "assistant"}}
-
-            case "content_start":
-                if event["data_type"] == "tool":
-                    return {
-                        "contentBlockStart": {
-                            "start": {
-                                "toolUse": {
-                                    "name": event["data"].function.name,
-                                    "toolUseId": event["data"].id,
-                                }
-                            }
-                        }
-                    }
-
-                return {"contentBlockStart": {"start": {}}}
-
-            case "content_delta":
-                if event["data_type"] == "tool":
-                    return {"contentBlockDelta": {"delta": {"toolUse": {"input": event["data"].function.arguments}}}}
-
-                return {"contentBlockDelta": {"delta": {"text": event["data"]}}}
-
-            case "content_stop":
-                return {"contentBlockStop": {}}
-
-            case "message_stop":
-                match event["data"]:
-                    case "tool_calls":
-                        return {"messageStop": {"stopReason": "tool_use"}}
-                    case "length":
-                        return {"messageStop": {"stopReason": "max_tokens"}}
-                    case _:
-                        return {"messageStop": {"stopReason": "end_turn"}}
-
-            case "metadata":
-                return {
-                    "metadata": {
-                        "usage": {
-                            "inputTokens": event["data"].prompt_tokens,
-                            "outputTokens": event["data"].completion_tokens,
-                            "totalTokens": event["data"].total_tokens,
-                        },
-                        "metrics": {
-                            "latencyMs": 0,  # TODO
-                        },
-                    },
-                }
-
-            case _:
-                raise RuntimeError(f"chunk_type=<{event['chunk_type']} | unknown type")
-
-    @override
-    def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]:
-        """Send the request to the LiteLLM model and get the streaming response.
-
-        Args:
-            request: The formatted request to send to the LiteLLM model.
-
-        Returns:
-            An iterable of response events from the LiteLLM model.
-        """
-        response = self.client.chat.completions.create(**request)
-
-        yield {"chunk_type": "message_start"}
-        yield {"chunk_type": "content_start", "data_type": "text"}
-
-        tool_calls: dict[int, list[Any]] = {}
-
-        for event in response:
-            choice = event.choices[0]
-            if choice.finish_reason:
-                break
-
-            if choice.delta.content:
-                yield {"chunk_type": "content_delta", "data_type": "text", "data": choice.delta.content}
-
-            for tool_call in choice.delta.tool_calls or []:
-                tool_calls.setdefault(tool_call.index, []).append(tool_call)
-
-        yield {"chunk_type": "content_stop", "data_type": "text"}
-
-        for tool_deltas in tool_calls.values():
-            tool_start, tool_deltas = tool_deltas[0], tool_deltas[1:]
-            yield {"chunk_type": "content_start", "data_type": "tool", "data": tool_start}
-
-            for tool_delta in tool_deltas:
-                yield {"chunk_type": "content_delta", "data_type": "tool", "data": tool_delta}
-
-            yield {"chunk_type": "content_stop", "data_type": "tool"}
-
-        yield {"chunk_type": "message_stop", "data": choice.finish_reason}
-
-        # Skip remaining events as we don't have use for anything except the final usage payload
-        for event in response:
-            _ = event
-
-        yield {"chunk_type": "metadata", "data": event.usage}
+        return OpenAIModel.format_request_message_content(content)
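
Taken together, the litellm.py changes strip the provider down to a thin subclass: LiteLLMModel now inherits request building, chunk formatting, and streaming from the new OpenAIModel base and only overrides format_request_message_content for the LiteLLM-specific "thinking" and "video_url" content blocks, delegating everything else to the parent. A rough usage sketch under that assumption; the client_args keys, model_id value, and params shown here are illustrative pass-through values, not defined by this diff:

from strands.models.litellm import LiteLLMModel

# client_args are forwarded to the underlying client; model_id and params are
# stored in config (the removed LiteLLM format_request merged them into each
# request, and the OpenAIModel base is assumed to do the same).
model = LiteLLMModel(
    client_args={"api_key": "<YOUR_KEY>"},
    model_id="openai/gpt-4o",
    params={"max_tokens": 1024},
)

# The override is now a staticmethod, so plain text blocks route straight
# through to OpenAIModel's formatter.
print(LiteLLMModel.format_request_message_content({"text": "hello"}))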

0 commit comments
