8000 enlarge dummy input · grimoire/lmdeploy@ffa60a3 · GitHub
[go: up one dir, main page]

Skip to content

Commit ffa60a3

Browse files
committed
enlarge dummy input
1 parent a75d567 commit ffa60a3

File tree

1 file changed

+7
-6
lines changed

1 file changed

+7
-6
lines changed

lmdeploy/pytorch/engine/engine.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -374,6 +374,7 @@ def __init__(self,
374374
self.scheduler_config = scheduler_config
375375
self.cache_config = cache_config
376376
self.backend_config = backend_config
377+
self.dist_config = dist_config
377378
self.max_session_len = self._get_max_session_len()
378379

379380
self.req_manager = self._bind_request_manager()
@@ -851,9 +852,14 @@ def __make_dummy_inputs():
851852
"""make dummy inputs."""
852853
logger.info(f'make dummy forward inputs: prefill={prefill}.')
853854
num_loops = 1 if prefill else prefill_interval
855+
856+
batch_size = 2 if self.dist_config.enable_microbatch else 1
857+
batch_size = min(self.cache_config.max_batches, batch_size)
854858
return dict(
855859
running=[],
856-
inputs=ModelInputs.make_dummy(1, is_decoding=not prefill, vocab_size=self.model_config.vocab_size),
860+
inputs=ModelInputs.make_dummy(batch_size,
861+
is_decoding=not prefill,
862+
vocab_size=self.model_config.vocab_size),
857863
swap_in_map=dict(),
858864
swap_out_map=dict(),
859865
loop_count=num_loops,
@@ -1030,7 +1036,6 @@ async def _async_loop_main(
10301036
next_running = None
10311037

10321038
while True:
1033-
logger.info('begin loop')
10341039
if next_running is None:
10351040
await has_runable_event.wait()
10361041
scheduler.collect_migration_done()
@@ -1043,16 +1048,12 @@ async def _async_loop_main(
10431048
if idx >= num_loops - 1:
10441049
scheduler.collect_migration_done()
10451050
forward_inputs, next_running = await inputs_maker.prefetch_next_inputs()
1046-
logger.info('inputs forwarding done')
10471051
out = await self.executor.get_output_async()
1048-
logger.info('get_output_async done')
10491052
if len(out) > 0:
10501053
step_outputs = self._make_infer_outputs(**out, running=running)
10511054
resp_que.put_nowait(step_outputs)
1052-
logger.info('send response done')
10531055
scheduler.unlock_running(running)
10541056
has_runable_event.set()
1055-
logger.info('end loop')
10561057

10571058
@staticmethod
10581059
def _add_loop_tasks_done_callback(tasks: List[asyncio.Task]):

0 commit comments

Comments
 (0)
0