5
5
6
6
import json
7
7
import logging
8
- from typing import Any , Iterable , Optional , Union
8
+ from typing import Any , Iterable , Optional , cast
9
9
10
10
from ollama import Client as OllamaClient
11
11
from typing_extensions import TypedDict , Unpack , override
12
12
13
- from ..types .content import ContentBlock , Message , Messages
14
- from ..types .media import DocumentContent , ImageContent
13
+ from ..types .content import ContentBlock , Messages
15
14
from ..types .models import Model
16
15
from ..types .streaming import StopReason , StreamEvent
17
16
from ..types .tools import ToolSpec
@@ -92,35 +91,31 @@ def get_config(self) -> OllamaConfig:
92
91
"""
93
92
return self .config
94
93
95
- @override
96
- def format_request (
97
- self , messages : Messages , tool_specs : Optional [list [ToolSpec ]] = None , system_prompt : Optional [str ] = None
98
- ) -> dict [str , Any ]:
99
- """Format an Ollama chat streaming request.
94
+ def _format_request_message_contents (self , role : str , content : ContentBlock ) -> list [dict [str , Any ]]:
95
+ """Format Ollama compatible message contents.
96
+
97
+ Ollama doesn't support an array of contents, so we must flatten everything into separate message blocks.
100
98
101
99
Args:
102
- messages: List of message objects to be processed by the model.
103
- tool_specs: List of tool specifications to make available to the model.
104
- system_prompt: System prompt to provide context to the model.
100
+ role: E.g., user.
101
+ content: Content block to format.
105
102
106
103
Returns:
107
- An Ollama chat streaming request .
104
+ Ollama formatted message contents .
108
105
109
106
Raises:
110
- TypeError: If a message contains a content block type that cannot be converted to an Ollama-compatible
111
- format.
107
+ TypeError: If the content block type cannot be converted to an Ollama-compatible format.
112
108
"""
109
+ if "text" in content :
110
+ return [{"role" : role , "content" : content ["text" ]}]
113
111
114
- def format_message (message : Message , content : ContentBlock ) -> dict [str , Any ]:
115
- if "text" in content :
116
- return {"role" : message ["role" ], "content" : content ["text" ]}
112
+ if "image" in content :
113
+ return [{"role" : role , "images" : [content ["image" ]["source" ]["bytes" ]]}]
117
114
118
- if "image" in content :
119
- return {"role" : message ["role" ], "images" : [content ["image" ]["source" ]["bytes" ]]}
120
-
121
- if "toolUse" in content :
122
- return {
123
- "role" : "assistant" ,
115
+ if "toolUse" in content :
116
+ return [
117
+ {
118
+ "role" : role ,
124
119
"tool_calls" : [
125
120
{
126
121
"function" : {
@@ -130,45 +125,63 @@ def format_message(message: Message, content: ContentBlock) -> dict[str, Any]:
130
125
}
131
126
],
132
127
}
128
+ ]
129
+
130
+ if "toolResult" in content :
131
+ return [
132
+ formatted_tool_result_content
133
+ for tool_result_content in content ["toolResult" ]["content" ]
134
+ for formatted_tool_result_content in self ._format_request_message_contents (
135
+ "tool" ,
136
+ (
137
+ {"text" : json .dumps (tool_result_content ["json" ])}
138
+ if "json" in tool_result_content
139
+ else cast (ContentBlock , tool_result_content )
140
+ ),
141
+ )
142
+ ]
133
143
134
- if "toolResult" in content :
135
- result_content : Union [str , ImageContent , DocumentContent , Any ] = None
136
- result_images = []
137
- for tool_result_content in content ["toolResult" ]["content" ]:
138
- if "text" in tool_result_content :
139
- result_content = tool_result_content ["text" ]
140
- elif "json" in tool_result_content :
141
- result_content = tool_result_content ["json" ]
142
- elif "image" in tool_result_content :
143
- result_content = "see images"
144
- result_images .append (tool_result_content ["image" ]["source" ]["bytes" ])
145
- else :
146
- result_content = content ["toolResult" ]["content" ]
144
+ raise TypeError (f"content_type=<{ next (iter (content ))} > | unsupported type" )
147
145
148
- return {
149
- "role" : "tool" ,
150
- "content" : json .dumps (
151
- {
152
- "name" : content ["toolResult" ]["toolUseId" ],
153
- "result" : result_content ,
154
- "status" : content ["toolResult" ]["status" ],
155
- }
156
- ),
157
- ** ({"images" : result_images } if result_images else {}),
158
- }
146
def _format_request_messages(
    self, messages: Messages, system_prompt: Optional[str] = None
) -> list[dict[str, Any]]:
    """Format an Ollama compatible messages array.

    Every content block of every message is passed, together with the message's role, to
    ``self._format_request_message_contents``; the lists it returns are concatenated in the
    original message/content order. When ``system_prompt`` is truthy, a single
    ``{"role": "system", ...}`` entry is placed in front of the result.

    Args:
        messages: List of message objects to be processed by the model.
        system_prompt: System prompt to provide context to the model. ``None`` or an empty
            string means no system entry is emitted.

    Returns:
        An Ollama compatible messages array.

    Raises:
        TypeError: Propagated from ``_format_request_message_contents`` when a content block
            cannot be converted to an Ollama-compatible format.
    """
    # An empty prompt is treated the same as a missing one: no system entry at all.
    system_message = [{"role": "system", "content": system_prompt}] if system_prompt else []

    # One content block may expand into several Ollama messages, hence the third loop level.
    return system_message + [
        formatted_message
        for message in messages
        for content in message["content"]
        for formatted_message in self._format_request_message_contents(message["role"], content)
    ]
166
164
165
+ @override
166
+ def format_request (
167
+ self , messages : Messages , tool_specs : Optional [list [ToolSpec ]] = None , system_prompt : Optional [str ] = None
168
+ ) -> dict [str , Any ]:
169
+ """Format an Ollama chat streaming request.
170
+
171
+ Args:
172
+ messages: List of message objects to be processed by the model.
173
+ tool_specs: List of tool specifications to make available to the model.
174
+ system_prompt: System prompt to provide context to the model.
175
+
176
+ Returns:
177
+ An Ollama chat streaming request.
178
+
179
+ Raises:
180
+ TypeError: If a message contains a content block type that cannot be converted to an Ollama-compatible
181
+ format.
182
+ """
167
183
return {
168
- "messages" : [
169
- * ([{"role" : "system" , "content" : system_prompt }] if system_prompt else []),
170
- * formatted_messages ,
171
- ],
184
+ "messages" : self ._format_request_messages (messages , system_prompt ),
172
185
"model" : self .config ["model_id" ],
173
186
"options" : {
174
187
** (self .config .get ("options" ) or {}),
@@ -217,52 +230,54 @@ def format_chunk(self, event: dict[str, Any]) -> StreamEvent:
217
230
RuntimeError: If chunk_type is not recognized.
218
231
This error should never be encountered as we control chunk_type in the stream method.
219
232
"""
220
- if event ["chunk_type" ] == "message_start" :
221
- return {"messageStart" : {"role" : "assistant" }}
222
-
223
- if event ["chunk_type" ] == "content_start" :
224
- if event ["data_type" ] == "text" :
225
- return {"contentBlockStart" : {"start" : {}}}
226
-
227
- tool_name = event ["data" ].function .name
228
- return {"contentBlockStart" : {"start" : {"toolUse" : {"name" : tool_name , "toolUseId" : tool_name }}}}
229
-
230
- if event ["chunk_type" ] == "content_delta" :
231
- if event ["data_type" ] == "text" :
232
- return {"contentBlockDelta" : {"delta" : {"text" : event ["data" ]}}}
233
-
234
- tool_arguments = event ["data" ].function .arguments
235
- return {"contentBlockDelta" : {"delta" : {"toolUse" : {"input" : json .dumps (tool_arguments )}}}}
236
-
237
- if event ["chunk_type" ] == "content_stop" :
238
- return {"contentBlockStop" : {}}
239
-
240
- if event ["chunk_type" ] == "message_stop" :
241
- reason : StopReason
242
- if event ["data" ] == "tool_use" :
243
- reason = "tool_use"
244
- elif event ["data" ] == "length" :
245
- reason = "max_tokens"
246
- else :
247
- reason = "end_turn"
248
-
249
- return {"messageStop" : {"stopReason" : reason }}
250
-
251
- if event ["chunk_type" ] == "metadata" :
252
- return {
253
- "metadata" : {
254
- "usage" : {
255
- "inputTokens" : event ["data" ].eval_count ,
256
- "outputTokens" : event ["data" ].prompt_eval_count ,
257
- "totalTokens" : event ["data" ].eval_count + event ["data" ].prompt_eval_count ,
258
- },
259
- "metrics" : {
260
- "latencyMs" : event ["data" ].total_duration / 1e6 ,
233
+ match event ["chunk_type" ]:
234
+ case "message_start" :
235
+ return {"messageStart" : {"role" : "assistant" }}
236
+
237
+ case "content_start" :
238
+ if event ["data_type" ] == "text" :
239
+ return {"contentBlockStart" : {"start" : {}}}
240
+
241
+ tool_name = event ["data" ].function .name
242
+ return {"contentBlockStart" : {"start" : {"toolUse" : {"name" : tool_name , "toolUseId" : tool_name }}}}
243
+
244
+ case "content_delta" :
245
+ if event ["data_type" ] == "text" :
246
+ return {"contentBlockDelta" : {"delta" : {"text" : event ["data" ]}}}
247
+
248
+ tool_arguments = event ["data" ].function .arguments
249
+ return {"contentBlockDelta" : {"delta" : {"toolUse" : {"input" : json .dumps (tool_arguments )}}}}
250
+
251
+ case "content_stop" :
252
+ return {"contentBlockStop" : {}}
253
+
254
+ case "message_stop" :
255
+ reason : StopReason
256
+ if event ["data" ] == "tool_use" :
257
+ reason = "tool_use"
258
+ elif event ["data" ] == "length" :
259
+ reason = "max_tokens"
260
+ else :
261
+ reason = "end_turn"
262
+
263
+ return {"messageStop" : {"stopReason" : reason }}
264
+
265
+ case "metadata" :
266
+ return {
267
+ "metadata" : {
268
+ "usage" : {
269
+ "inputTokens" : event ["data" ].eval_count ,
270
+ "outputTokens" : event ["data" ].prompt_eval_count ,
271
+ "totalTokens" : event ["data" ].eval_count + event ["data" ].prompt_eval_count ,
272
+ },
273
+ "metrics" : {
274
+ "latencyMs" : event ["data" ].total_duration / 1e6 ,
275
+ },
261
276
},
262
- },
263
- }
277
+ }
264
278
265
- raise RuntimeError (f"chunk_type=<{ event ['chunk_type' ]} | unknown type" )
279
+ case _:
280
+ raise RuntimeError (f"chunk_type=<{ event ['chunk_type' ]} | unknown type" )
266
281
267
282
@override
268
283
def stream (self , request : dict [str , Any ]) -> Iterable [dict [str , Any ]]:
0 commit comments