@@ -374,6 +374,7 @@ def __init__(self,
374
374
self .scheduler_config = scheduler_config
375
375
self .cache_config = cache_config
376
376
self .backend_config = backend_config
377
+ self .dist_config = dist_config
377
378
self .max_session_len = self ._get_max_session_len ()
378
379
379
380
self .req_manager = self ._bind_request_manager ()
@@ -851,9 +852,14 @@ def __make_dummy_inputs():
851
852
"""make dummy inputs."""
852
853
logger .info (f'make dummy forward inputs: prefill={ prefill } .' )
853
854
num_loops = 1 if prefill else prefill_interval
855
+
856
+ batch_size = 2 if self .dist_config .enable_microbatch else 1
857
+ batch_size = min (self .cache_config .max_batches , batch_size )
854
858
return dict (
855
859
running = [],
856
- inputs = ModelInputs .make_dummy (1 , is_decoding = not prefill , vocab_size = self .model_config .vocab_size ),
860
+ inputs = ModelInputs .make_dummy (batch_size ,
861
+ is_decoding = not prefill ,
862
+ vocab_size = self .model_config .vocab_size ),
857
863
swap_in_map = dict (),
858
864
swap_out_map = dict (),
859
865
loop_count = num_loops ,
@@ -1030,7 +1036,6 @@ async def _async_loop_main(
1030
1036
next_running = None
1031
1037
1032
1038
while True :
1033
- logger .info ('begin loop' )
1034
1039
if next_running is None :
1035
1040
await has_runable_event .wait ()
1036
1041
scheduler .collect_migration_done ()
@@ -1043,16 +1048,12 @@ async def _async_loop_main(
1043
1048
if idx >= num_loops - 1 :
1044
1049
scheduler .collect_migration_done ()
1045
1050
forward_inputs , next_running = await inputs_maker .prefetch_next_inputs ()
1046
- logger .info ('inputs forwarding done' )
1047
1051
out = await self .executor .get_output_async ()
1048
- logger .info ('get_output_async done' )
1049
1052
if len (out ) > 0 :
1050
1053
step_outputs = self ._make_infer_outputs (** out , running = running )
1051
1054
resp_que .put_nowait (step_outputs )
1052
- logger .info ('send response done' )
1053
1055
scheduler .unlock_running (running )
1054
1056
has_runable_event .set ()
1055
- logger .info ('end loop' )
1056
1057
1057
1058
@staticmethod
1058
1059
def _add_loop_tasks_done_callback (tasks : List [asyncio .Task ]):
0 commit comments