Update vllm.py · AhilanPonnusamy/sdk-python@344749a · GitHub
Commit 344749a

Update vllm.py
fixed bugs
1 parent c0e2639 commit 344749a

File tree

1 file changed: +97 -104 lines changed

src/strands/models/vllm.py

Lines changed: 97 additions & 104 deletions
@@ -1,8 +1,3 @@
-"""vLLM model provider.
-
-- Docs: https://github.com/vllm-project/vllm
-"""
-
 import json
 import logging
 from typing import Any, Iterable, Optional
@@ -19,29 +14,7 @@
 
 
 class VLLMModel(Model):
-    """vLLM model provider implementation.
-
-    Assumes OpenAI-compatible vLLM server at `http://<host>/v1/completions`.
-
-    The implementation handles vLLM-specific features such as:
-
-    - Local model invocation
-    - Streaming responses
-    - Tool/function calling
-    """
-
     class VLLMConfig(TypedDict, total=False):
-        """Configuration parameters for vLLM models.
-
-        Attributes:
-            additional_args: Any additional arguments to include in the request.
-            max_tokens: Maximum number of tokens to generate in the response.
-            model_id: vLLM model ID (e.g., "meta-llama/Llama-3.2-3B,microsoft/Phi-3-mini-128k-instruct").
-            options: Additional model parameters (e.g., top_k).
-            temperature: Controls randomness in generation (higher = more random).
-            top_p: Controls diversity via nucleus sampling (alternative to temperature).
-        """
-
         model_id: str
         temperature: Optional[float]
         top_p: Optional[float]
@@ -50,32 +23,16 @@ class VLLMConfig(TypedDict, total=False):
         additional_args: Optional[dict[str, Any]]
 
     def __init__(self, host: str, **model_config: Unpack[VLLMConfig]) -> None:
-        """Initialize provider instance.
-
-        Args:
-            host: The address of the vLLM server hosting the model.
-            **model_config: Configuration options for the vLLM model.
-        """
         self.config = VLLMModel.VLLMConfig(**model_config)
         self.host = host.rstrip("/")
-        logger.debug("Initializing vLLM provider with config: %s", self.config)
+        logger.debug("----Initializing vLLM provider with config: %s", self.config)
 
     @override
     def update_config(self, **model_config: Unpack[VLLMConfig]) -> None:
-        """Update the vLLM Model configuration with the provided arguments.
-
-        Args:
-            **model_config: Configuration overrides.
-        """
         self.config.update(model_config)
 
     @override
     def get_config(self) -> VLLMConfig:
-        """Get the vLLM model configuration.
-
-        Returns:
-            The vLLM model configuration.
-        """
         return self.config
 
     @override
@@ -85,78 +42,97 @@ def format_request(
         tool_specs: Optional[list[ToolSpec]] = None,
         system_prompt: Optional[str] = None,
     ) -> dict[str, Any]:
-        """Format an vLLM chat streaming request.
-
-        Args:
-            messages: List of message objects to be processed by the model.
-            tool_specs: List of tool specifications to make available to the model.
-            system_prompt: System prompt to provide context to the model.
-
-        Returns:
-            An vLLM chat streaming request.
-        """
-
-        # Concatenate messages to form a prompt string
-        prompt_parts = [
-            f"{msg['role']}: {content['text']}" for msg in messages for content in msg["content"] if "text" in content
-        ]
+        def format_message(message: dict[str, Any], content: dict[str, Any]) -> dict[str, Any]:
+            if "text" in content:
+                return {"role": message["role"], "content": content["text"]}
+            if "toolUse" in content:
+                return {
+                    "role": "assistant",
+                    "tool_calls": [
+                        {
+                            "id": content["toolUse"]["toolUseId"],
+                            "type": "function",
+                            "function": {
+                                "name": content["toolUse"]["name"],
+                                "arguments": json.dumps(content["toolUse"]["input"]),
+                            },
+                        }
+                    ],
+                }
+            if "toolResult" in content:
+                return {
+                    "role": "tool",
+                    "tool_call_id": content["toolResult"]["toolUseId"],
+                    "content": json.dumps(content["toolResult"]["content"]),
+                }
+            return {"role": message["role"], "content": json.dumps(content)}
+
+        chat_messages = []
         if system_prompt:
-            prompt_parts.insert(0, f"system: {system_prompt}")
-        prompt = "\n".join(prompt_parts) + "\nassistant:"
+            chat_messages.append({"role": "system", "content": system_prompt})
+        for msg in messages:
+            for content in msg["content"]:
+                chat_messages.append(format_message(msg, content))
 
         payload = {
             "model": self.config["model_id"],
-            "prompt": prompt,
+            "messages": chat_messages,
             "temperature": self.config.get("temperature", 0.7),
             "top_p": self.config.get("top_p", 1.0),
-            "max_tokens": self.config.get("max_tokens", 1024),
-            "stop": self.config.get("stop_sequences"),
-            "stream": False,  # Disable streaming
+            "max_tokens": self.config.get("max_tokens", 2048),
+            "stream": True,
         }
 
+        if self.config.get("stop_sequences"):
+            payload["stop"] = self.config["stop_sequences"]
+
+        if tool_specs:
+            payload["tools"] = [
+                {
+                    "type": "function",
+                    "function": {
+                        "name": tool["name"],
+                        "description": tool["description"],
+                        "parameters": tool["inputSchema"]["json"],
+                    },
+                }
+                for tool in tool_specs
+            ]
+
         if self.config.get("additional_args"):
             payload.update(self.config["additional_args"])
 
+        logger.debug("Formatted vLLM Request:\n%s", json.dumps(payload, indent=2))
         return payload
 
     @override
     def format_chunk(self, event: dict[str, Any]) -> StreamEvent:
-        """Format the vLLM response events into standardized message chunks.
-
-        Args:
-            event: A response event from the vLLM model.
-
-        Returns:
-            The formatted chunk.
-
-        """
         choice = event.get("choices", [{}])[0]
 
-        if "text" in choice:
-            return {"contentBlockDelta": {"delta": {"text": choice["text"]}}}
+        # Streaming delta (streaming mode)
+        if "delta" in choice:
+            delta = choice["delta"]
+            if "content" in delta:
+                return {"contentBlockDelta": {"delta": {"text": delta["content"]}}}
+            if "tool_calls" in delta:
+                return {"toolCall": delta["tool_calls"][0]}
 
+        # Non-streaming response
+        if "message" in choice:
+            return {"contentBlockDelta": {"delta": {"text": choice["message"].get("content", "")}}}
+
+        # Completion stop
         if "finish_reason" in choice:
             return {"messageStop": {"stopReason": choice["finish_reason"] or "end_turn"}}
 
         return {}
 
     @override
     def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]:
-        """Send the request to the vLLM model and get the streaming response.
-
-        This method calls the /v1/completions endpoint and returns the stream of response events.
-
-        Args:
-            request: The formatted request to send to the vLLM model.
-
-        Returns:
-            An iterable of response events from the vLLM model.
-        """
+        """Stream from /v1/chat/completions, print content, and yield chunks including tool calls."""
         headers = {"Content-Type": "application/json"}
-        url = f"{self.host}/v1/completions"
-        request["stream"] = True  # Enable streaming
-
-        full_output = ""
+        url = f"{self.host}/v1/chat/completions"
+        request["stream"] = True
 
         try:
             with requests.post(url, headers=headers, data=json.dumps(request), stream=True) as response:
@@ -179,30 +155,47 @@ def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]:
 
                     try:
                         data = json.loads(line)
-                        choice = data.get("choices", [{}])[0]
-                        text = choice.get("text", "")
-                        finish_reason = choice.get("finish_reason")
+                        delta = data.get("choices", [{}])[0].get("delta", {})
+                        content = delta.get("content", "")
+                        tool_calls = delta.get("tool_calls")
 
-                        if text:
-                            full_output += text
-                            print(text, end="", flush=True)  # Stream to stdout without newline
+                        if content:
+                            print(content, end="", flush=True)
                             yield {
                                 "chunk_type": "content_delta",
                                 "data_type": "text",
-                                "data": text,
+                                "data": content,
                             }
 
-                        if finish_reason:
-                            yield {"chunk_type": "content_stop", "data_type": "text"}
-                            yield {"chunk_type": "message_stop", "data": finish_reason}
-                            break
+                        if tool_calls:
+                            for tool_call in tool_calls:
+                                tool_call_id = tool_call.get("id")
+                                func = tool_call.get("function", {})
+                                tool_name = func.get("name", "")
+                                args_text = func.get("arguments", "")
+
+                                yield {
+                                    "toolCallStart": {
+                                        "toolCallId": tool_call_id,
+                                        "toolName": tool_name,
+                                        "type": "function",
+                                    }
+                                }
+                                yield {
+                                    "toolCallDelta": {
+                                        "toolCallId": tool_call_id,
+                                        "delta": {
+                                            "toolName": tool_name,
+                                            "argsText": args_text,
+                                        },
+                                    }
+                                }
 
                     except json.JSONDecodeError:
                         logger.warning("Failed to decode streamed line: %s", line)
 
-                else:
-                    yield {"chunk_type": "content_stop", "data_type": "text"}
-                    yield {"chunk_type": "message_stop", "data": "end_turn"}
+                yield {"chunk_type": "content_stop", "data_type": "text"}
+                yield {"chunk_type": "message_stop", "data": "end_turn"}
 
         except requests.RequestException as e:
             logger.error("Request to vLLM failed: %s", str(e))
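
For reviewers who want to try the change locally, here is a minimal usage sketch; it is not part of the commit. It assumes the module is importable as strands.models.vllm, a vLLM server exposing the OpenAI-compatible /v1/chat/completions endpoint at http://localhost:8000, and meta-llama/Llama-3.2-3B as the served model; the host, model ID, and example prompt are illustrative assumptions only.

# Minimal sketch (not part of this commit): exercise the reworked chat-completions path.
# The import path, host, and model_id below are assumptions for illustration.
from strands.models.vllm import VLLMModel

model = VLLMModel(
    host="http://localhost:8000",        # assumed address of a running vLLM server
    model_id="meta-llama/Llama-3.2-3B",  # assumed model served by that instance
    max_tokens=256,
)

# Messages use the content-block shape that format_request() iterates over.
messages = [{"role": "user", "content": [{"text": "Say hello in one sentence."}]}]

request = model.format_request(messages, system_prompt="You are a concise assistant.")
events = list(model.stream(request))     # prints streamed text and yields chunk dicts

# The trailing events should be the content_stop and message_stop chunks.
print()
print(events[-1])

Because stream() already prints text deltas to stdout, the collected events are mainly useful for checking the content_stop / message_stop markers and any toolCallStart / toolCallDelta chunks emitted for tool calls.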

0 commit comments