Update vllm.py · strands-agents/sdk-python@7e85e87 · GitHub

Commit 7e85e87

Update vllm.py
Fixed tool usage error and updated comments
1 parent d31649e commit 7e85e87
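In short, the commit stops reading tool calls from OpenAI-style delta.tool_calls entries in the stream; instead stream() accumulates the streamed text and then extracts any <tool_call>{...}</tool_call> blocks the model emitted before signalling tool_use. A minimal, self-contained sketch of that extraction step (the sample output string below is illustrative, not taken from the commit):

import json
import re
from collections import namedtuple

Function = namedtuple("Function", ["name", "arguments"])

# Illustrative accumulated model output containing a <tool_call> block (hypothetical payload).
full_content = 'Let me check that. <tool_call>{"name": "get_weather", "arguments": {"city": "Berlin"}}</tool_call>'

# Same extraction the updated stream() performs over the accumulated text.
for block in re.findall(r"<tool_call>(.*?)</tool_call>", full_content, re.DOTALL):
    data = json.loads(block.strip())
    func = Function(name=data["name"], arguments=data.get("arguments", {}))
    print(func)  # Function(name='get_weather', arguments={'city': 'Berlin'})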

File tree

1 file changed

+173 -65 lines changed

src/strands/models/vllm.py

Lines changed: 173 additions & 65 deletions
@@ -1,5 +1,11 @@
+"""vLLM model provider.
+
+- Docs: https://docs.vllm.ai/en/latest/index.html
+"""
 import json
 import logging
+import re
+from collections import namedtuple
 from typing import Any, Iterable, Optional
 
 import requests
@@ -14,7 +20,20 @@
 
 
 class VLLMModel(Model):
+    """vLLM model provider implementation for OpenAI compatible /v1/chat/completions endpoint."""
+
     class VLLMConfig(TypedDict, total=False):
+        """Configuration options for vLLM models.
+
+        Attributes:
+            model_id: Model ID (e.g., "Qwen/Qwen3-4B").
+            temperature: Optional[float]
+            top_p: Optional[float]
+            max_tokens: Optional[int]
+            stop_sequences: Optional[list[str]]
+            additional_args: Optional[dict[str, Any]]
+        """
+
         model_id: str
         temperature: Optional[float]
         top_p: Optional[float]
@@ -23,16 +42,32 @@ class VLLMConfig(TypedDict, total=False):
         additional_args: Optional[dict[str, Any]]
 
     def __init__(self, host: str, **model_config: Unpack[VLLMConfig]) -> None:
+        """Initialize provider instance.
+
+        Args:
+            host: Host and port of the vLLM Inference Server
+            **model_config: Configuration options for the LiteLLM model.
+        """
         self.config = VLLMModel.VLLMConfig(**model_config)
         self.host = host.rstrip("/")
-        logger.debug("----Initializing vLLM provider with config: %s", self.config)
+        logger.debug("Initializing vLLM provider with config: %s", self.config)
 
     @override
     def update_config(self, **model_config: Unpack[VLLMConfig]) -> None:
+        """Update the vLLM model configuration with the provided arguments.
+
+        Args:
+            **model_config: Configuration overrides.
+        """
         self.config.update(model_config)
 
     @override
     def get_config(self) -> VLLMConfig:
+        """Get the vLLM model configuration.
+
+        Returns:
+            The vLLM model configuration.
+        """
         return self.config
 
     @override
@@ -42,9 +77,20 @@ def format_request(
         tool_specs: Optional[list[ToolSpec]] = None,
         system_prompt: Optional[str] = None,
     ) -> dict[str, Any]:
-        def format_message(message: dict[str, Any], content: dict[str, Any]) -> dict[str, Any]:
+        """Format a vLLM chat streaming request.
+
+        Args:
+            messages: List of message objects to be processed by the model.
+            tool_specs: List of tool specifications to make available to the model.
+            system_prompt: System prompt to provide context to the model.
+
+        Returns:
+            A vLLM chat streaming request.
+        """
+
+        def format_message(msg: dict[str, Any], content: dict[str, Any]) -> dict[str, Any]:
             if "text" in content:
-                return {"role": message["role"], "content": content["text"]}
+                return {"role": msg["role"], "content": content["text"]}
             if "toolUse" in content:
                 return {
                     "role": "assistant",
@@ -65,7 +111,7 @@ def format_message(message: dict[str, Any], content: dict[str, Any]) -> dict[str
                     "tool_call_id": content["toolResult"]["toolUseId"],
                     "content": json.dumps(content["toolResult"]["content"]),
                 }
-            return {"role": message["role"], "content": json.dumps(content)}
+            return {"role": msg["role"], "content": json.dumps(content)}
 
         chat_messages = []
         if system_prompt:
@@ -107,32 +153,103 @@ def format_message(message: dict[str, Any], content: dict[str, Any]) -> dict[str
 
     @override
     def format_chunk(self, event: dict[str, Any]) -> StreamEvent:
-        choice = event.get("choices", [{}])[0]
+        """Format the vLLM response events into standardized message chunks.
 
-        # Streaming delta (streaming mode)
-        if "delta" in choice:
-            delta = choice["delta"]
-            if "content" in delta:
-                return {"contentBlockDelta": {"delta": {"text": delta["content"]}}}
-            if "tool_calls" in delta:
-                return {"toolCall": delta["tool_calls"][0]}
+        Args:
+            event: A response event from the vLLM model.
 
-        # Non-streaming response
-        if "message" in choice:
-            return {"contentBlockDelta": {"delta": {"text": choice["message"].get("content", "")}}}
+        Returns:
+            The formatted chunk.
 
-        # Completion stop
-        if "finish_reason" in choice:
-            return {"messageStop": {"stopReason": choice["finish_reason"] or "end_turn"}}
+        Raises:
+            RuntimeError: If chunk_type is not recognized.
+                This error should never be encountered as we control chunk_type in the stream method.
+        """
+        from collections import namedtuple
 
-        return {}
+        Function = namedtuple("Function", ["name", "arguments"])
+
+        if event.get("chunk_type") == "message_start":
+            return {"messageStart": {"role": "assistant"}}
+
+        if event.get("chunk_type") == "content_start":
+            if event["data_type"] == "text":
+                return {"contentBlockStart": {"start": {}}}
+
+            tool: Function = event["data"]
+            return {
+                "contentBlockStart": {
+                    "start": {
+                        "toolUse": {
+                            "name": tool.name,
+                            "toolUseId": tool.name,
+                        }
+                    }
+                }
+            }
+
+        if event.get("chunk_type") == "content_delta":
+            if event["data_type"] == "text":
+                return {"contentBlockDelta": {"delta": {"text": event["data"]}}}
+
+            tool: Function = event["data"]
+            return {
+                "contentBlockDelta": {
+                    "delta": {
+                        "toolUse": {
+                            "input": json.dumps(tool.arguments)  # This is already a dict
+                        }
+                    }
+                }
+            }
+
+        if event.get("chunk_type") == "content_stop":
+            return {"contentBlockStop": {}}
+
+        if event.get("chunk_type") == "message_stop":
+            reason = event["data"]
+            if reason == "tool_use":
+                return {"messageStop": {"stopReason": "tool_use"}}
+            elif reason == "length":
+                return {"messageStop": {"stopReason": "max_tokens"}}
+            else:
+                return {"messageStop": {"stopReason": "end_turn"}}
+
+        if event.get("chunk_type") == "metadata":
+            usage = event.get("data", {})
+            return {
+                "metadata": {
+                    "usage": {
+                        "inputTokens": usage.get("prompt_eval_count", 0),
+                        "outputTokens": usage.get("eval_count", 0),
+                        "totalTokens": usage.get("prompt_eval_count", 0) + usage.get("eval_count", 0),
+                    },
+                    "metrics": {
+                        "latencyMs": usage.get("total_duration", 0) / 1e6,
+                    },
+                }
+            }
+
+        raise RuntimeError(f"chunk_type=<{event.get('chunk_type')}> | unknown type")
 
     @override
     def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]:
-        """Stream from /v1/chat/completions, print content, and yield chunks including tool calls."""
+        """Send the request to the vLLM model and get the streaming response.
+
+        Args:
+            request: The formatted request to send to the vLLM model.
+
+        Returns:
+            An iterable of response events from the vLLM model.
+        """
+
+        Function = namedtuple("Function", ["name", "arguments"])
+
         headers = {"Content-Type": "application/json"}
         url = f"{self.host}/v1/chat/completions"
-        request["stream"] = True
+
+        accumulated_content = []
+        tool_requested = False
 
         try:
             with requests.post(url, headers=headers, data=json.dumps(request), stream=True) as response:
@@ -144,59 +261,50 @@ def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]:
                 yield {"chunk_type": "content_start", "data_type": "text"}
 
                 for line in response.iter_lines(decode_unicode=True):
-                    if not line:
+                    if not line or not line.startswith("data: "):
                         continue
+                    line = line[len("data: ") :].strip()
 
-                    if line.startswith("data: "):
-                        line = line[len("data: ") :]
-
-                    if line.strip() == "[DONE]":
+                    if line == "[DONE]":
                         break
 
                     try:
-                        data = json.loads(line)
-                        delta = data.get("choices", [{}])[0].get("delta", {})
-                        content = delta.get("content", "")
-                        tool_calls = delta.get("tool_calls")
-
-                        if content:
-                            print(content, end="", flush=True)
-                            yield {
-                                "chunk_type": "content_delta",
-                                "data_type": "text",
-                                "data": content,
-                            }
-
-                        if tool_calls:
-                            for tool_call in tool_calls:
-                                tool_call_id = tool_call.get("id")
-                                func = tool_call.get("function", {})
-                                tool_name = func.get("name", "")
-                                args_text = func.get("arguments", "")
-
-                                yield {
-                                    "toolCallStart": {
-                                        "toolCallId": tool_call_id,
-                                        "toolName": tool_name,
-                                        "type": "function",
-                                    }
-                                }
-                                yield {
-                                    "toolCallDelta": {
-                                        "toolCallId": tool_call_id,
-                                        "delta": {
-                                            "toolName": tool_name,
-                                            "argsText": args_text,
-                                        },
-                                    }
-                                }
+                        event = json.loads(line)
+                        choices = event.get("choices", [])
+                        if choices:
+                            delta = choices[0].get("delta", {})
+                            content = delta.get("content")
+                            if content:
+                                accumulated_content.append(content)
+
+                            yield {"chunk_type": "content_delta", "data_type": "text", "data": content or ""}
 
                     except json.JSONDecodeError:
-                        logger.warning("Failed to decode streamed line: %s", line)
+                        logger.warning("Failed to parse line: %s", line)
+                        continue
 
                 yield {"chunk_type": "content_stop", "data_type": "text"}
-                yield {"chunk_type": "message_stop", "data": "end_turn"}
+
+                full_content = "".join(accumulated_content)
+
+                tool_call_blocks = re.findall(r"<tool_call>(.*?)</tool_call>", full_content, re.DOTALL)
+                for idx, block in enumerate(tool_call_blocks):
+                    try:
+                        tool_call_data = json.loads(block.strip())
+                        func = Function(name=tool_call_data["name"], arguments=tool_call_data.get("arguments", {}))
+                        func_str = f"function=Function(name='{func.name}', arguments={func.arguments})"
+
+                        yield {"chunk_type": "content_start", "data_type": "tool", "data": func}
+                        yield {"chunk_type": "content_delta", "data_type": "tool", "data": func}
+                        yield {"chunk_type": "content_stop", "data_type": "tool", "data": func}
+                        tool_requested = True
+
+                    except json.JSONDecodeError:
+                        logger.warning(f"Failed to parse tool_call block #{idx}: {block}")
+                        continue
+
+                yield {"chunk_type": "message_stop", "data": "tool_use" if tool_requested else "end_turn"}
 
         except requests.RequestException as e:
-            logger.error("Request to vLLM failed: %s", str(e))
+            logger.error("Streaming request failed: %s", str(e))
            raise Exception("Failed to reach vLLM server") from e
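Assuming the unchanged parts of the module (imports, the Model base class, and the rest of format_request) stay as in the parent commit, the updated provider would be exercised roughly like this; the import path, host, model ID, and prompt below are placeholders, and the message shape mirrors the content blocks handled by format_message() above:

from strands.models.vllm import VLLMModel

# Placeholder server and model; any OpenAI-compatible vLLM /v1/chat/completions endpoint fits the same shape.
model = VLLMModel(host="http://localhost:8000", model_id="Qwen/Qwen3-4B", max_tokens=512)

messages = [{"role": "user", "content": [{"text": "What is the weather in Berlin?"}]}]
request = model.format_request(messages, tool_specs=None, system_prompt="You are a helpful assistant.")

for event in model.stream(request):      # raw chunk_type events from the streaming response
    chunk = model.format_chunk(event)    # standardized message/content-block chunks
    print(chunk)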

0 commit comments