Refactor engine 2505 by grimoire · Pull Request #3541 · InternLM/lmdeploy · GitHub
Refactor engine 2505 #3541


Closed · wants to merge 6 commits
4 changes: 2 additions & 2 deletions lmdeploy/messages.py
@@ -293,7 +293,7 @@ class PytorchEngineConfig:
enable_prefix_caching (bool): Enable token match and sharing caches.
device_type (str): The inference device type, options ['cuda']
eager_mode (bool): Enable "eager" mode or not
- custom_module_map (Dict): nn module map customized by users. Once
+ custom_module_map (str): nn module map customized by users. Once
provided, the original nn modules of the model will be
substituted by the mapping ones
download_dir (str): Directory to download and load the weights,
@@ -331,7 +331,7 @@ class PytorchEngineConfig:
enable_prefix_caching: bool = False
device_type: str = 'cuda'
eager_mode: bool = False
- custom_module_map: Dict[str, str] = None
+ custom_module_map: str = None
download_dir: str = None
revision: str = None
quant_policy: Literal[0, 4, 8] = 0
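With this change, `custom_module_map` is a path to a user-supplied module-map file rather than an in-memory dict. A minimal sketch of how a caller might now pass it (the map file path and model name below are hypothetical examples, not part of this PR):

```python
from lmdeploy import PytorchEngineConfig, pipeline

# custom_module_map is now a string path to the user's module-map file,
# not a Dict[str, str] built in Python (hypothetical path and model below).
engine_config = PytorchEngineConfig(custom_module_map='/path/to/my_module_map.py')
pipe = pipeline('internlm/internlm2_5-7b-chat', backend_config=engine_config)
print(pipe(['Hello']))
```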
4 changes: 2 additions & 2 deletions lmdeploy/pytorch/chat.py
@@ -42,7 +42,7 @@ def run_chat(model_path: str,
from lmdeploy import pipeline

if gen_config is None:
- gen_config = GenerationConfig(do_sample=True)
+ gen_config = GenerationConfig(max_new_tokens=4096, do_sample=True)

adapter_name = None
if engine_config.adapters is not None:
@@ -131,7 +131,7 @@ def main(model_path: str,
if adapter is not None:
adapters = dict(default=adapter)
engine_config = PytorchEngineConfig(tp=tp, adapters=adapters)
- gen_config = GenerationConfig(max_new_tokens=512,
+ gen_config = GenerationConfig(max_new_tokens=4096,
top_k=top_k,
top_p=top_p,
temperature=temperature,
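For reference, the two hunks above raise the default generation budget of the PyTorch chat CLI from 512 to 4096 new tokens, with sampling enabled. A rough equivalent when calling the module directly; the model name is an illustrative placeholder, and the keyword names assume the run_chat signature shown above:

```python
from lmdeploy import GenerationConfig, PytorchEngineConfig
from lmdeploy.pytorch.chat import run_chat

# Same settings the new defaults produce when gen_config is left as None.
gen_config = GenerationConfig(max_new_tokens=4096, do_sample=True)
run_chat('internlm/internlm2_5-7b-chat',
         engine_config=PytorchEngineConfig(),
         gen_config=gen_config)
```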
4 changes: 2 additions & 2 deletions lmdeploy/pytorch/config.py
@@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from dataclasses import dataclass
- from typing import Any, Dict, List, Literal
+ from typing import Any, List, Literal

import torch

@@ -143,7 +143,7 @@ class ModelConfig:
vocab_size: int = 40000
hf_config: Any = None
cogvlm_style: bool = False
- custom_module_map: Dict[str, setattr] = None
+ custom_module_map: str = None
use_flash_mla: bool = False

def get_head_size(self):
9 changes: 7 additions & 2 deletions lmdeploy/pytorch/engine/engine.py
@@ -459,7 +459,9 @@ def _response(self, resp: Response, resp_type: ResponseType, data: Any = None, e
def _get_max_session_len(self):
"""get max session len."""
session_len = self.scheduler_config.max_session_len
- max_tokens = (self.cache_config.num_gpu_blocks * self.cache_config.block_size)
+ block_size = self.cache_config.block_size
+ # leave one block to avoid over allocate
+ max_tokens = (self.cache_config.num_gpu_blocks * block_size - block_size)
window_size = self.cache_config.window_size
if window_size > 0 and window_size <= max_tokens:
max_tokens = (1 << 63) - 1
@@ -600,7 +602,10 @@ def model_config(self) -> ModelConfig:

@property
def gpu_count(self):
- return self.tp * self.dp
+ dist_config = self.dist_config
+ if dist_config.dp > 1:
+     return 1
+ return max(dist_config.tp, dist_config.ep)

@property
def torch_int_dtype(self):
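To illustrate the reserved-block arithmetic introduced above, here is a standalone sketch of the same computation (the function name and numbers are illustrative, not part of the engine API):

```python
def max_session_tokens(num_gpu_blocks: int, block_size: int, window_size: int = -1) -> int:
    """Mirror of _get_max_session_len above: hold one block back so a single
    session can never claim every GPU cache block."""
    max_tokens = num_gpu_blocks * block_size - block_size
    if 0 < window_size <= max_tokens:
        # with sliding-window attention the cache is bounded, so the session
        # length is effectively unlimited
        max_tokens = (1 << 63) - 1
    return max_tokens

# e.g. 1000 blocks of 64 tokens each: 64000 total, 63936 usable for a session
assert max_session_tokens(1000, 64) == 999 * 64
```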
232 changes: 79 additions & 153 deletions lmdeploy/pytorch/engine/model_agent.py
@@ -27,11 +27,6 @@
logger = get_logger('lmdeploy')


def msg_with_rank(rank: int, msg: str):
"""return message with rank."""
return f'rank[{rank}] - {msg}'


def cache_swapping(cache_engine: CacheEngine, swap_in_map: dict, swap_out_map: dict):
"""perform cache swapping."""
issued_cache_op = False
@@ -125,30 +120,60 @@ async def async_wait(self, timeout: float = 0.001):
SwapMap = Dict[int, int]


class AutoModelAgent:
"""Base model agent."""
class BaseModelAgent:
"""Base model agent.

def __init__(
self,
model_config: ModelConfig,
cache_config: CacheConfig,
tokenizer: Any,
dist_ctx: DistContext,
device_ctx: DeviceContext,
):
load model on local gpu

Args:
model_path (str): The hugging face model path.
model_config (ModelConfig): The config of the model.
cache_config (CacheConfig): The config of the cache info.
trust_remote_code (bool): Trust remote code
"""

def __init__(self,
model_path: str,
model_config: ModelConfig,
cache_config: CacheConfig,
backend_config: BackendConfig,
misc_config: MiscConfig,
tokenizer: Any,
dist_ctx: DistContext,
device_ctx: DeviceContext,
adapters: Dict[str, str] = None):
# input args
self.model_path = model_path
self.model_config = model_config
self.cache_config = cache_config
self.backend_config = backend_config
self.misc_config = misc_config
self.tokenizer = tokenizer
self.dist_ctx = dist_ctx
self.device_ctx = device_ctx
self.adapters = adapters

# async fields
self._in_que = None
self._out_que = None
self._background_task = None

# cuda streams
self.stream = torch.cuda.Stream()
self.out_stream = torch.cuda.Stream()

self.dist_ctx = dist_ctx
self.device_ctx = device_ctx
# dist info
self.device = 'cuda'
self.rank = dist_ctx.rank
self.tp_rank = dist_ctx.tp_rank

# model and cache
self.patched_model = None
self.cache_engine = None

def warp_msg(self, msg: str):
"""return message with rank."""
return f'rank[{self.rank}] - {msg}'

@contextmanager
def all_context(self):
@@ -157,42 +182,51 @@ def all_context(self):
with device_mgr.context(self.device_ctx), dist_mgr.context(self.dist_ctx):
yield

def _forward_impl(self, inputs: ModelInputs, swap_in_map: SwapMap, swap_out_map: SwapMap):
raise NotImplementedError('NotImplemented.')

async def async_forward(self, inputs: ModelInputs, swap_in_map: SwapMap, swap_out_map: SwapMap):
"""model forward.

Args:
inputs (Dict): The input data comes from _make_inputs.
swap_in_map (SwapMap): Cache maps to swap in.
swap_out_map (SwapMap): Cache maps to swap out.
"""
raise NotImplementedError('Not implemented.')

def get_logits(self, hidden_states: torch.Tensor):
"""get logits of model output."""
raise NotImplementedError('Not implemented.')

def get_input_processor(self):
"""get input processor."""
raise NotImplementedError('Not implemented.')
def _build_model(self):
"""build patched model."""
model_path = self.model_path
adapters = self.adapters
device = self.device
custom_module_map = self.model_config.custom_module_map
if custom_module_map is not None:
update_custom_module_map(custom_module_map)
logger.debug(self.warp_msg('build model.'))
patched_model = build_patched_model(self.model_config, device=device)
logger.debug(self.warp_msg('loading weights.'))
if not self.misc_config.empty_init:
load_model_weights(patched_model, model_path, device=device)
if adapters is not None:
logger.debug(self.warp_msg('loading adapters.'))
add_adapters(patched_model, adapters, dtype=self.model_config.dtype, device=device)
self.patched_model = patched_model

def build_model(self):
"""build model."""
raise NotImplementedError('Not implemented.')
"""build model api."""
with self.all_context():
self._build_model()

def build_graph_runner(self):
"""build graph runner."""
raise NotImplementedError('Not Implemented.')
with self.all_context():
backend = get_backend()
self.patched_model = backend.build_graph_runner(self.patched_model,
model_config=self.model_config,
cache_config=self.cache_config,
backend_config=self.backend_config,
device=self.device)

def build_cache_engine(self):
"""build cache engine."""
raise NotImplementedError('Not Implemented.')
with self.all_context():
dist_ctx = self.dist_ctx
attn_dist_cfg = dist_ctx.dist_config.attn_config
tp = attn_dist_cfg.tp

def release(self):
"""release."""
raise NotImplementedError('Not Implemented.')
self.cache_engine = CacheEngine(self.cache_config,
self.model_config,
rank=self.rank,
tp_rank=self.tp_rank,
world_size=tp)

def set_cache_config(self, cache_config: CacheConfig):
"""set all cache config."""
@@ -620,103 +654,6 @@ async def get_output_async(self):
out['logits'] = out['logits'].cpu()
return out


class BaseModelAgent(AutoModelAgent):
"""Base model agent.

load model on local gpu

Args:
model_path (str): The hugging face model path.
model_config (ModelConfig): The config of the model.
cache_config (CacheConfig): The config of the cache info.
trust_remote_code (bool): Trust remote code
"""

def __init__(self,
model_path: str,
model_config: ModelConfig,
cache_config: CacheConfig,
backend_config: BackendConfig,
misc_config: MiscConfig,
tokenizer: Any,
dist_ctx: DistContext,
device_ctx: DeviceContext,
adapters: Dict[str, str] = None):
super().__init__(
model_config=model_config,
cache_config=cache_config,
tokenizer=tokenizer,
dist_ctx=dist_ctx,
device_ctx=device_ctx,
)
device = 'cuda'
self.backend_config = backend_config
self.misc_config = misc_config
rank = dist_ctx.rank

self.model_path = model_path
self.adapters = adapters
self.device = device
self.rank = rank

tp_rank = dist_ctx.tp_rank
tp = dist_ctx.tp
world_size = dist_ctx.world_size
self.tp = tp
self.world_size = world_size
self.tp_rank = tp_rank

self.patched_model = None
self.cache_engine = None

def _build_model(self):
"""build patched model."""
model_path = self.model_path
adapters = self.adapters
device = self.device
rank = self.rank
custom_module_map = self.model_config.custom_module_map
if custom_module_map is not None:
update_custom_module_map(custom_module_map)
logger.debug(msg_with_rank(rank, 'build model.'))
patched_model = build_patched_model(self.model_config, device=device)
logger.debug(msg_with_rank(rank, 'loading weights.'))
if not self.misc_config.empty_init:
load_model_weights(patched_model, model_path, device=device)
if adapters is not None:
logger.debug(msg_with_rank(rank, 'loading adapters.'))
add_adapters(patched_model, adapters, dtype=self.model_config.dtype, device=device)
self.patched_model = patched_model

def build_model(self):
"""build model api."""
with self.all_context():
self._build_model()

def build_graph_runner(self):
"""build graph runner."""
with self.all_context():
backend = get_backend()
self.patched_model = backend.build_graph_runner(self.patched_model,
model_config=self.model_config,
cache_config=self.cache_config,
backend_config=self.backend_config,
device=self.device)

def build_cache_engine(self):
"""build cache engine."""
with self.all_context():
dist_ctx = self.dist_ctx
attn_dist_cfg = dist_ctx.dist_config.attn_config
tp = attn_dist_cfg.tp

self.cache_engine = CacheEngine(self.cache_config,
self.model_config,
rank=self.rank,
tp_rank=self.tp_rank,
world_size=tp)

def _forward_impl(self, inputs: ModelInputs, swap_in_map: SwapMap, swap_out_map: SwapMap):
cache_swapping(self.cache_engine, swap_in_map=swap_in_map, swap_out_map=swap_out_map)
output = model_forward(
@@ -797,18 +734,7 @@ def build_model_agent(model_path: str,
dist_ctx: DistContext = None,
device_ctx: DeviceContext = None,
adapters: Dict[str, str] = None):
"""create model agent.

Args:
model_path (str): the path of the input model
cache_config (CacheConfig): config of kv cache
backend_config (BackendConfig): config of backend devices
trust_remote_code (bool): To use the remote modeling code or not
adapters (Dict): lora adapters
tp (int): the number of devices to be used in tensor parallelism
dtype (str): the data type of model weights and activations
custom_module_map (str): customized nn module map
"""
"""create model agent."""
if device_ctx is None:
device_mgr = get_device_manager()
device_ctx = device_mgr.current_context()
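To summarize the refactor at a glance: the old AutoModelAgent base class and its BaseModelAgent subclass are merged into a single BaseModelAgent, and setup now happens through three explicit build steps before the cache engine is ready. A hedged sketch of the intended call order, mirroring the constructor and methods in the diff (the helper function and its arguments are placeholders for whatever the engine actually wires in through build_model_agent):

```python
from lmdeploy.pytorch.engine.model_agent import BaseModelAgent

def build_agent_sketch(model_path, model_config, cache_config, backend_config,
                       misc_config, tokenizer, dist_ctx, device_ctx, adapters=None):
    """Illustrative only: shows the construction and build order, not engine internals."""
    agent = BaseModelAgent(model_path,
                           model_config=model_config,
                           cache_config=cache_config,
                           backend_config=backend_config,
                           misc_config=misc_config,
                           tokenizer=tokenizer,
                           dist_ctx=dist_ctx,
                           device_ctx=device_ctx,
                           adapters=adapters)
    agent.build_model()         # patch the model; weight loading is skipped when misc_config.empty_init is set
    agent.build_graph_runner()  # wrap the patched model with the backend's graph runner
    agent.build_cache_engine()  # allocate the KV cache using the attention TP world size
    return agent
```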