-"""vLLM model provider.
-
-- Docs: https://github.com/vllm-project/vllm
-"""
-
 import json
 import logging
 from typing import Any, Iterable, Optional
 
 
 class VLLMModel(Model):
-    """vLLM model provider implementation.
-
-    Assumes OpenAI-compatible vLLM server at `http://<host>/v1/completions`.
-
-    The implementation handles vLLM-specific features such as:
-
-    - Local model invocation
-    - Streaming responses
-    - Tool/function calling
-    """
-
     class VLLMConfig(TypedDict, total=False):
-        """Configuration parameters for vLLM models.
-
-        Attributes:
-            additional_args: Any additional arguments to include in the request.
-            max_tokens: Maximum number of tokens to generate in the response.
-            model_id: vLLM model ID (e.g., "meta-llama/Llama-3.2-3B,microsoft/Phi-3-mini-128k-instruct").
-            options: Additional model parameters (e.g., top_k).
-            temperature: Controls randomness in generation (higher = more random).
-            top_p: Controls diversity via nucleus sampling (alternative to temperature).
-        """
-
         model_id: str
         temperature: Optional[float]
         top_p: Optional[float]
@@ -50,32 +23,16 @@ class VLLMConfig(TypedDict, total=False):
         additional_args: Optional[dict[str, Any]]
 
     def __init__(self, host: str, **model_config: Unpack[VLLMConfig]) -> None:
-        """Initialize provider instance.
-
-        Args:
-            host: The address of the vLLM server hosting the model.
-            **model_config: Configuration options for the vLLM model.
-        """
         self.config = VLLMModel.VLLMConfig(**model_config)
         self.host = host.rstrip("/")
-        logger.debug("Initializing vLLM provider with config: %s", self.config)
+        logger.debug("---- Initializing vLLM provider with config: %s", self.config)
 
     @override
     def update_config(self, **model_config: Unpack[VLLMConfig]) -> None:
-        """Update the vLLM Model configuration with the provided arguments.
-
-        Args:
-            **model_config: Configuration overrides.
-        """
         self.config.update(model_config)
 
     @override
     def get_config(self) -> VLLMConfig:
-        """Get the vLLM model configuration.
-
-        Returns:
-            The vLLM model configuration.
-        """
         return self.config
 
     @override
@@ -85,78 +42,97 @@ def format_request(
         tool_specs: Optional[list[ToolSpec]] = None,
         system_prompt: Optional[str] = None,
     ) -> dict[str, Any]:
-        """Format an vLLM chat streaming request.
-
-        Args:
-            messages: List of message objects to be processed by the model.
-            tool_specs: List of tool specifications to make available to the model.
-            system_prompt: System prompt to provide context to the model.
-
-        Returns:
-            An vLLM chat streaming request.
-        """
-
-        # Concatenate messages to form a prompt string
-        prompt_parts = [
-            f"{msg['role']}: {content['text']}" for msg in messages for content in msg["content"] if "text" in content
-        ]
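+        # Map one content block (text, toolUse, or toolResult) onto an OpenAI-style chat message.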
+        def format_message(message: dict[str, Any], content: dict[str, Any]) -> dict[str, Any]:
+            if "text" in content:
+                return {"role": message["role"], "content": content["text"]}
+            if "toolUse" in content:
+                return {
+                    "role": "assistant",
+                    "tool_calls": [
+                        {
+                            "id": content["toolUse"]["toolUseId"],
+                            "type": "function",
+                            "function": {
+                                "name": content["toolUse"]["name"],
+                                "arguments": json.dumps(content["toolUse"]["input"]),
+                            },
+                        }
+                    ],
+                }
+            if "toolResult" in content:
+                return {
+                    "role": "tool",
+                    "tool_call_id": content["toolResult"]["toolUseId"],
+                    "content": json.dumps(content["toolResult"]["content"]),
+                }
+            return {"role": message["role"], "content": json.dumps(content)}
+
+        chat_messages = []
         if system_prompt:
-            prompt_parts.insert(0, f"system: {system_prompt}")
-        prompt = "\n".join(prompt_parts) + "\nassistant:"
+            chat_messages.append({"role": "system", "content": system_prompt})
+        for msg in messages:
+            for content in msg["content"]:
+                chat_messages.append(format_message(msg, content))
 
         payload = {
             "model": self.config["model_id"],
-            "prompt": prompt,
+            "messages": chat_messages,
             "temperature": self.config.get("temperature", 0.7),
             "top_p": self.config.get("top_p", 1.0),
-            "max_tokens": self.config.get("max_tokens", 1024),
-            "stop": self.config.get("stop_sequences"),
-            "stream": False,  # Disable streaming
+            "max_tokens": self.config.get("max_tokens", 2048),
+            "stream": True,
         }
 
+        if self.config.get("stop_sequences"):
+            payload["stop"] = self.config["stop_sequences"]
+
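+        # Advertise tool specs in OpenAI function-calling format so the server can emit tool_calls.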
+        if tool_specs:
+            payload["tools"] = [
+                {
+                    "type": "function",
+                    "function": {
+                        "name": tool["name"],
+                        "description": tool["description"],
+                        "parameters": tool["inputSchema"]["json"],
+                    },
+                }
+                for tool in tool_specs
+            ]
+
         if self.config.get("additional_args"):
             payload.update(self.config["additional_args"])
 
+        logger.debug("Formatted vLLM Request:\n%s", json.dumps(payload, indent=2))
         return payload
 
     @override
     def format_chunk(self, event: dict[str, Any]) -> StreamEvent:
-        """Format the vLLM response events into standardized message chunks.
-
-        Args:
-            event: A response event from the vLLM model.
-
-        Returns:
-            The formatted chunk.
-
-        """
         choice = event.get("choices", [{}])[0]
 
-        if "text" in choice:
-            return {"contentBlockDelta": {"delta": {"text": choice["text"]}}}
+        # Streaming delta (streaming mode)
+        if "delta" in choice:
+            delta = choice["delta"]
+            if "content" in delta:
+                return {"contentBlockDelta": {"delta": {"text": delta["content"]}}}
+            if "tool_calls" in delta:
+                return {"toolCall": delta["tool_calls"][0]}
 
+        # Non-streaming response
+        if "message" in choice:
+            return {"contentBlockDelta": {"delta": {"text": choice["message"].get("content", "")}}}
+
+        # Completion stop
         if "finish_reason" in choice:
             return {"messageStop": {"stopReason": choice["finish_reason"] or "end_turn"}}
 
         return {}
 
     @override
     def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]:
-        """Send the request to the vLLM model and get the streaming response.
-
-        This method calls the /v1/completions endpoint and returns the stream of response events.
-
-        Args:
-            request: The formatted request to send to the vLLM model.
-
-        Returns:
-            An iterable of response events from the vLLM model.
-        """
+        """Stream from /v1/chat/completions, print content, and yield chunks including tool calls."""
         headers = {"Content-Type": "application/json"}
-        url = f"{self.host}/v1/completions"
-        request["stream"] = True  # Enable streaming
-
-        full_output = ""
+        url = f"{self.host}/v1/chat/completions"
+        request["stream"] = True
 
         try:
             with requests.post(url, headers=headers, data=json.dumps(request), stream=True) as response:
@@ -179,30 +155,47 @@ def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]:
 
                     try:
                         data = json.loads(line)
-                        choice = data.get("choices", [{}])[0]
-                        text = choice.get("text", "")
-                        finish_reason = choice.get("finish_reason")
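+                        # Each streamed line is an OpenAI-style chat completion chunk; read the first choice's delta.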
+                        delta = data.get("choices", [{}])[0].get("delta", {})
+                        content = delta.get("content", "")
+                        tool_calls = delta.get("tool_calls")
 
-                        if text:
-                            full_output += text
-                            print(text, end="", flush=True)  # Stream to stdout without newline
+                        if content:
+                            print(content, end="", flush=True)
                             yield {
                                 "chunk_type": "content_delta",
                                 "data_type": "text",
-                                "data": text,
+                                "data": content,
                             }
 
-                        if finish_reason:
-                            yield {"chunk_type": "content_stop", "data_type": "text"}
-                            yield {"chunk_type": "message_stop", "data": finish_reason}
-                            break
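+                        # Tool calls arrive as incremental fragments; surface each as start and delta events.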
+                        if tool_calls:
+                            for tool_call in tool_calls:
+                                tool_call_id = tool_call.get("id")
+                                func = tool_call.get("function", {})
+                                tool_name = func.get("name", "")
+                                args_text = func.get("arguments", "")
+
+                                yield {
+                                    "toolCallStart": {
+                                        "toolCallId": tool_call_id,
+                                        "toolName": tool_name,
+                                        "type": "function",
+                                    }
+                                }
+                                yield {
+                                    "toolCallDelta": {
+                                        "toolCallId": tool_call_id,
+                                        "delta": {
+                                            "toolName": tool_name,
+                                            "argsText": args_text,
+                                        },
+                                    }
+                                }
 
                     except json.JSONDecodeError:
                         logger.warning("Failed to decode streamed line: %s", line)
 
-                else:
-                    yield {"chunk_type": "content_stop", "data_type": "text"}
-                    yield {"chunk_type": "message_stop", "data": "end_turn"}
+                yield {"chunk_type": "content_stop", "data_type": "text"}
+                yield {"chunk_type": "message_stop", "data": "end_turn"}
 
         except requests.RequestException as e:
             logger.error("Request to vLLM failed: %s", str(e))
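
For orientation, a minimal usage sketch of the provider after this change. The host and model id below are illustrative placeholders, not part of the commit:

    model = VLLMModel(host="http://localhost:8000", model_id="meta-llama/Llama-3.2-3B", max_tokens=512)
    request = model.format_request(
        messages=[{"role": "user", "content": [{"text": "Say hello."}]}],
        system_prompt="You are a concise assistant.",
    )
    for event in model.stream(request):
        ...  # plain dicts: text deltas, tool-call start/delta events, then content_stop/message_stop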