+ """vLLM model provider.
+
+ - Docs: https://docs.vllm.ai/en/latest/index.html
+ """
import json
import logging
+ import re
+ from collections import namedtuple
from typing import Any, Iterable, Optional

import requests


class VLLMModel(Model):
+     """vLLM model provider implementation for the OpenAI-compatible /v1/chat/completions endpoint."""
+
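+     # A minimal usage sketch (hypothetical host and model_id; Model, format_request,
+     # stream, and format_chunk are defined in/around this class):
+     #
+     #     model = VLLMModel("http://localhost:8000", model_id="Qwen/Qwen3-4B")
+     #     request = model.format_request(messages)
+     #     for event in model.stream(request):
+     #         chunk = model.format_chunk(event)
+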
    class VLLMConfig(TypedDict, total=False):
+         """Configuration options for vLLM models.
+
+         Attributes:
+             model_id: Model ID (e.g., "Qwen/Qwen3-4B").
+             temperature: Sampling temperature for generation.
+             top_p: Nucleus sampling parameter.
+             max_tokens: Maximum number of tokens to generate.
+             stop_sequences: Sequences at which the model stops generating.
+             additional_args: Additional request arguments passed through to the server.
+         """
+
        model_id: str
        temperature: Optional[float]
        top_p: Optional[float]
@@ -23,16 +42,32 @@ class VLLMConfig(TypedDict, total=False):
        additional_args: Optional[dict[str, Any]]

    def __init__(self, host: str, **model_config: Unpack[VLLMConfig]) -> None:
+         """Initialize provider instance.
+
+         Args:
+             host: Host and port of the vLLM inference server.
+             **model_config: Configuration options for the vLLM model.
+         """
        self.config = VLLMModel.VLLMConfig(**model_config)
        self.host = host.rstrip("/")
-         logger.debug("---- Initializing vLLM provider with config: %s", self.config)
+         logger.debug("Initializing vLLM provider with config: %s", self.config)

    @override
    def update_config(self, **model_config: Unpack[VLLMConfig]) -> None:
+         """Update the vLLM model configuration with the provided arguments.
+
+         Args:
+             **model_config: Configuration overrides.
+         """
        self.config.update(model_config)

    @override
    def get_config(self) -> VLLMConfig:
+         """Get the vLLM model configuration.
+
+         Returns:
+             The vLLM model configuration.
+         """
        return self.config

    @override
@@ -42,9 +77,20 @@ def format_request(
        tool_specs: Optional[list[ToolSpec]] = None,
        system_prompt: Optional[str] = None,
    ) -> dict[str, Any]:
-         def format_message(message: dict[str, Any], content: dict[str, Any]) -> dict[str, Any]:
+         """Format a vLLM chat streaming request.
+
+         Args:
+             messages: List of message objects to be processed by the model.
+             tool_specs: List of tool specifications to make available to the model.
+             system_prompt: System prompt to provide context to the model.
+
+         Returns:
+             A vLLM chat streaming request.
+         """
+
+         def format_message(msg: dict[str, Any], content: dict[str, Any]) -> dict[str, Any]:
            if "text" in content:
-                 return {"role": message["role"], "content": content["text"]}
+                 return {"role": msg["role"], "content": content["text"]}
            if "toolUse" in content:
                return {
                    "role": "assistant",
@@ -65,7 +111,7 @@ def format_message(message: dict[str, Any], content: dict[str, Any]) -> dict[str
                    "tool_call_id": content["toolResult"]["toolUseId"],
                    "content": json.dumps(content["toolResult"]["content"]),
                }
-             return {"role": message["role"], "content": json.dumps(content)}
+             return {"role": msg["role"], "content": json.dumps(content)}

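+         # For example, a {"text": ...} content block maps to an OpenAI-style chat
+         # message such as {"role": "user", "content": "Hi"}, while toolResult blocks
+         # become {"role": "tool", "tool_call_id": ..., "content": ...} messages.
+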
        chat_messages = []
        if system_prompt:
@@ -107,32 +153,103 @@ def format_message(message: dict[str, Any], content: dict[str, Any]) -> dict[str

    @override
    def format_chunk(self, event: dict[str, Any]) -> StreamEvent:
-         choice = event.get("choices", [{}])[0]
-
-         # Streaming delta (streaming mode)
-         if "delta" in choice:
-             delta = choice["delta"]
-             if "content" in delta:
-                 return {"contentBlockDelta": {"delta": {"text": delta["content"]}}}
-             if "tool_calls" in delta:
-                 return {"toolCall": delta["tool_calls"][0]}
-
-         # Non-streaming response
-         if "message" in choice:
-             return {"contentBlockDelta": {"delta": {"text": choice["message"].get("content", "")}}}
-
-         # Completion stop
-         if "finish_reason" in choice:
-             return {"messageStop": {"stopReason": choice["finish_reason"] or "end_turn"}}
-
-         return {}
+         """Format the vLLM response events into standardized message chunks.
+
+         Args:
+             event: A response event from the vLLM model.
+
+         Returns:
+             The formatted chunk.
+
+         Raises:
+             RuntimeError: If chunk_type is not recognized.
+                 This error should never be encountered as we control chunk_type in the stream method.
+         """
+         Function = namedtuple("Function", ["name", "arguments"])
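+         # Tool events carry the Function(name, arguments) namedtuple produced by
+         # stream() in event["data"]; text events carry a plain string there.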
+
+         if event.get("chunk_type") == "message_start":
+             return {"messageStart": {"role": "assistant"}}
+
+         if event.get("chunk_type") == "content_start":
+             if event["data_type"] == "text":
+                 return {"contentBlockStart": {"start": {}}}
+
+             tool: Function = event["data"]
+             return {
+                 "contentBlockStart": {
+                     "start": {
+                         "toolUse": {
+                             "name": tool.name,
+                             "toolUseId": tool.name,
+                         }
+                     }
+                 }
+             }
+
191
+ if
57AE
event .get ("chunk_type" ) == "content_delta" :
192
+ if event ["data_type" ] == "text" :
193
+ return {"contentBlockDelta" : {"delta" : {"text" : event ["data" ]}}}
194
+
195
+ tool : Function = event ["data" ]
196
+ return {
197
+ "contentBlockDelta" : {
198
+ "delta" : {
199
+ "toolUse" : {
200
+ "input" : json .dumps (tool .arguments ) # This is already a dict
201
+ }
202
+ }
203
+ }
204
+ }
+
+         if event.get("chunk_type") == "content_stop":
+             return {"contentBlockStop": {}}
+
+         if event.get("chunk_type") == "message_stop":
+             reason = event["data"]
+             if reason == "tool_use":
+                 return {"messageStop": {"stopReason": "tool_use"}}
+             elif reason == "length":
+                 return {"messageStop": {"stopReason": "max_tokens"}}
+             else:
+                 return {"messageStop": {"stopReason": "end_turn"}}
+
+         if event.get("chunk_type") == "metadata":
+             usage = event.get("data", {})
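+             # The usage keys below (prompt_eval_count, eval_count, total_duration in
+             # nanoseconds) follow an Ollama-style payload; an OpenAI-compatible vLLM
+             # server reports prompt_tokens/completion_tokens in "usage" instead.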
+             return {
+                 "metadata": {
+                     "usage": {
+                         "inputTokens": usage.get("prompt_eval_count", 0),
+                         "outputTokens": usage.get("eval_count", 0),
+                         "totalTokens": usage.get("prompt_eval_count", 0) + usage.get("eval_count", 0),
+                     },
+                     "metrics": {
+                         "latencyMs": usage.get("total_duration", 0) / 1e6,
+                     },
+                 }
+             }
+
+         raise RuntimeError(f"chunk_type=<{event.get('chunk_type')}> | unknown type")

    @override
    def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]:
-         """Stream from /v1/chat/completions, print content, and yield chunks including tool calls."""
+         """Send the request to the vLLM model and get the streaming response.
+
+         Args:
+             request: The formatted request to send to the vLLM model.
+
+         Returns:
+             An iterable of response events from the vLLM model.
+         """
+
+         Function = namedtuple("Function", ["name", "arguments"])
+
        headers = {"Content-Type": "application/json"}
        url = f"{self.host}/v1/chat/completions"
-         request["stream"] = True
+
+         accumulated_content = []
+         tool_requested = False

        try:
            with requests.post(url, headers=headers, data=json.dumps(request), stream=True) as response:
@@ -144,59 +261,50 @@ def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]:
                yield {"chunk_type": "content_start", "data_type": "text"}

                for line in response.iter_lines(decode_unicode=True):
-                     if not line:
+                     if not line or not line.startswith("data: "):
                        continue
+                     line = line[len("data: "):].strip()

-                     if line.startswith("data: "):
-                         line = line[len("data: "):]
-
-                     if line.strip() == "[DONE]":
+                     if line == "[DONE]":
                        break

                    try:
-                         data = json.loads(line)
-                         delta = data.get("choices", [{}])[0].get("delta", {})
-                         content = delta.get("content", "")
-                         tool_calls = delta.get("tool_calls")
-
-                         if content:
-                             print(content, end="", flush=True)
-                             yield {
-                                 "chunk_type": "content_delta",
-                                 "data_type": "text",
-                                 "data": content,
-                             }
-
-                         if tool_calls:
-                             for tool_call in tool_calls:
-                                 tool_call_id = tool_call.get("id")
-                                 func = tool_call.get("function", {})
-                                 tool_name = func.get("name", "")
-                                 args_text = func.get("arguments", "")
-
-                                 yield {
-                                     "toolCallStart": {
-                                         "toolCallId": tool_call_id,
-                                         "toolName": tool_name,
-                                         "type": "function",
-                                     }
-                                 }
-                                 yield {
-                                     "toolCallDelta": {
-                                         "toolCallId": tool_call_id,
-                                         "delta": {
-                                             "toolName": tool_name,
-                                             "argsText": args_text,
-                                         },
-                                     }
-                                 }
+                         event = json.loads(line)
+                         choices = event.get("choices", [])
+                         if choices:
+                             delta = choices[0].get("delta", {})
+                             content = delta.get("content")
+                             if content:
+                                 accumulated_content.append(content)
+
+                             yield {"chunk_type": "content_delta", "data_type": "text", "data": content or ""}

                    except json.JSONDecodeError:
-                         logger.warning("Failed to decode streamed line: %s", line)
+                         logger.warning("Failed to parse line: %s", line)
+                         continue

                yield {"chunk_type": "content_stop", "data_type": "text"}
-                 yield {"chunk_type": "message_stop", "data": "end_turn"}
+
+                 full_content = "".join(accumulated_content)
+
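+                 # Some chat templates (e.g. Qwen/Hermes-style) emit tool calls inline as
+                 # <tool_call>{"name": "...", "arguments": {...}}</tool_call> blocks; scan
+                 # the accumulated text for them once the stream has finished.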
+                 tool_call_blocks = re.findall(r"<tool_call>(.*?)</tool_call>", full_content, re.DOTALL)
+                 for idx, block in enumerate(tool_call_blocks):
+                     try:
+                         tool_call_data = json.loads(block.strip())
+                         func = Function(name=tool_call_data["name"], arguments=tool_call_data.get("arguments", {}))
+
+                         yield {"chunk_type": "content_start", "data_type": "tool", "data": func}
+                         yield {"chunk_type": "content_delta", "data_type": "tool", "data": func}
+                         yield {"chunk_type": "content_stop", "data_type": "tool", "data": func}
+                         tool_requested = True
+
+                     except (json.JSONDecodeError, KeyError):
+                         logger.warning("Failed to parse tool_call block #%s: %s", idx, block)
+                         continue
+
+                 yield {"chunk_type": "message_stop", "data": "tool_use" if tool_requested else "end_turn"}

        except requests.RequestException as e:
-             logger.error("Request to vLLM failed: %s", str(e))
+             logger.error("Streaming request failed: %s", str(e))
            raise Exception("Failed to reach vLLM server") from e