@@ -79,6 +79,7 @@ def add_post_route(
79
79
input_model ,
80
80
output_model ,
81
81
input_doc_list_model = None ,
82
+ output_doc_list_model = None ,
82
83
):
83
84
from docarray .base_doc .docarray_response import DocArrayResponse
84
85
@@ -128,7 +129,7 @@ async def process(body) -> output_model:
128
129
if status .code == jina_pb2 .StatusProto .ERROR :
129
130
raise HTTPException (status_code = 499 , detail = status .description )
130
131
else :
131
- return { " predictions" : resp .docs }
132
+ return VertexAIResponse ( predictions = output_model ( data = resp .docs , parameters = resp . parameters ))
132
133
133
134
@app .api_route (** app_kwargs )
134
135
async def post (request : Request ):
@@ -151,6 +152,7 @@ async def post(request: Request):
151
152
for endpoint , input_output_map in request_models_map .items ():
152
153
if endpoint != '_jina_dry_run_' :
153
154
input_doc_model = input_output_map ['input' ]['model' ]
155
+ output_doc_model = input_output_map ['output' ]['model' ]
154
156
parameters_model = input_output_map ['parameters' ]['model' ] or Optional [Dict ]
155
157
default_parameters = (
156
158
... if input_output_map ['parameters' ]['model' ] else None
@@ -165,11 +167,19 @@ async def post(request: Request):
165
167
__config__ = _config ,
166
168
)
167
169
170
+ endpoint_output_model = pydantic .create_model (
171
+ f'{ endpoint .strip ("/" )} _output_model' ,
172
+ data = (Union [List [output_doc_model ], output_doc_model ], ...),
173
+ parameters = (Optional [Dict ], None ),
174
+ __config__ = _config ,
175
+ )
176
+
168
177
add_post_route (
169
178
endpoint ,
170
179
input_model = endpoint_input_model ,
171
- output_model = VertexAIResponse ,
180
+ output_model =
52B8
endpoint_output_model ,
172
181
input_doc_list_model = input_doc_model ,
182
+ output_doc_list_model = VertexAIResponse ,
173
183
)
174
184
175
185
from jina .serve .runtimes .gateway .health_model import JinaHealthModel
0 commit comments