Optimize InternViT · InternLM/lmdeploy@e057f52 · GitHub
Commit e057f52

Optimize InternViT
2 parents 2b641b8 + 0222d47 commit e057f52

4 files changed: +123 -9 lines changed

lmdeploy/pytorch/backends/cuda/graph_runner.py

Lines changed: 15 additions & 0 deletions

@@ -3,6 +3,7 @@

 import torch

+from lmdeploy.pytorch.backends.selector import get_backend
 from lmdeploy.pytorch.config import BackendConfig, CacheConfig, ModelConfig
 from lmdeploy.pytorch.model_inputs import StepContext
 from lmdeploy.pytorch.models.utils.cudagraph import CudaGraphMeta
@@ -116,6 +117,7 @@ def __init__(self, model: torch.nn.Module, model_config: ModelConfig, cache_conf

         self.graph_pool_handle = torch.cuda.graph_pool_handle()
         self._runner_map: Dict[Any, CUDASingleGraphRunner] = dict()
+        self.has_try_compile_model: bool = False

     def check_enable_graph(self):
         """check enable graph."""
@@ -124,6 +126,16 @@ def check_enable_graph(self):

         return getattr(self.model, 'support_cuda_graph', _false)

+    def _try_compile_model_once(self):
+        if self.has_try_compile_model:
+            return
+
+        if hasattr(self.model, 'compile_model'):
+            method = getattr(self.model, 'compile_model')
+            method()
+
+        self.has_try_compile_model = True
+
     def get_graph_key(self, input_ids: torch.Tensor, position_ids: torch.Tensor, past_key_values: List,
                       attn_metadata: Any, inputs_embeds: torch.Tensor, **kwargs):
         """get graph key."""
@@ -135,6 +147,9 @@ def get_graph_key(self, input_ids: torch.Tensor, position_ids: torch.Tensor, pas

     def __call__(self, **kwargs):
         """call."""
+        if not self.backend_config.eager_mode and get_backend().get_name() == 'cuda':
+            self._try_compile_model_once()
+
         enable_graph = self.enable_graph(**kwargs)

         if not enable_graph:
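
What the graph_runner change does, in short: on the CUDA backend in non-eager mode, the runner now looks for an optional compile_model() method on the wrapped model and invokes it exactly once before running. A minimal sketch of that hook contract, using a hypothetical toy model that is not part of this commit:

import torch


class ToyVisionLanguageModel(torch.nn.Module):
    """Hypothetical model that opts into the one-time compile hook."""

    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(8, 8)

    def compile_model(self):
        # One-time heavy setup (e.g. torch.compile); the runner's
        # has_try_compile_model flag ensures this is only called once.
        self.forward = torch.compile(self.forward)

    def forward(self, x):
        return self.proj(x)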

lmdeploy/pytorch/models/internvl.py

Lines changed: 44 additions & 6 deletions

@@ -4,11 +4,14 @@

 import torch
 import torch.nn.functional as F
+from packaging import version
 from torch import nn
 from transformers.configuration_utils import PretrainedConfig

+from lmdeploy.pytorch.distributed import get_world_rank
 from lmdeploy.pytorch.engine.input_process import BaseModelInputProcessor, PreprocessInputResult
 from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
+from lmdeploy.pytorch.models.utils.micro_batch import enable_micro_batch, split_batch
 from lmdeploy.pytorch.multimodal.data_type import MultiModalTensor
 from lmdeploy.pytorch.nn import LayerNorm, RMSNorm
 from lmdeploy.pytorch.nn.linear import build_colwise_linear, build_qkv_proj, build_rowwise_linear
@@ -205,15 +208,22 @@ def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device:
         self.ls1 = nn.Parameter(torch.empty(self.embed_dim, dtype=dtype, device=device))
         self.ls2 = nn.Parameter(torch.empty(self.embed_dim, dtype=dtype, device=device))

-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-    ):
-        """forward."""
-        hidden_states = hidden_states + self.attn(self.norm1(hidden_states).to(hidden_states.dtype)) * self.ls1
+    @enable_micro_batch(param_name='hidden_states', index=0)
+    def _attn(self, hidden_states):
+        hidden_states = hidden_states + self.attn(self.norm1(hidden_states).to(hidden_states[0].dtype)) * self.ls1
+        return hidden_states

+    @enable_micro_batch(param_name='hidden_states', index=0)
+    def _mlp(self, hidden_states):
         hidden_states = hidden_states + self.mlp(self.norm2(hidden_states).to(hidden_states.dtype)) * self.ls2
+        return hidden_states

+    def forward(
+        self,
+        hidden_states,
+    ):
+        hidden_states = self._attn(hidden_states)
+        hidden_states = self._mlp(hidden_states)
         return hidden_states


@@ -306,6 +316,33 @@ def __init__(self,

         self.input_processor = InternVLInputProcessor(self.config, dtype)

+        self.compile_vit = False
+
+    def compile_model(self):
+        torch_version = version.parse(torch.__version__)
+        if torch_version < version.parse('2.5.0'):
+            return
+
+        world_size, _ = get_world_rank()
+        if torch_version >= version.parse('2.6.0') and world_size > 1:
+            torch._inductor.config.reorder_for_compute_comm_overlap = True
+            if isinstance(self.vision_model, InternVisionModel):
+                self.vision_model.encoder.forward = split_batch(self.vision_model.encoder.forward,
+                                                                'inputs_embeds',
+                                                                index=0)
+
+        self.extract_feature = torch.compile(self.extract_feature, mode='max-autotune')
+        self.compile_vit = True
+        self.has_compiled_vit = False
+
+    def _mark_dynamic_once(self, pixel_values, dims):
+        """call torch._dynamo.mark_dynamic to avoid recompile."""
+        if not self.compile_vit or self.has_compiled_vit or pixel_values is None:
+            return
+
+        torch._dynamo.mark_dynamic(pixel_values, dims)
+        self.has_compiled_vit = True
+
     def pixel_shuffle(self, x, scale_factor=0.5):
         n, w, h, c = x.size()
         # N, W, H, C --> N, W, H * scale, C // scale
@@ -350,6 +387,7 @@ def forward(
     ):
         if inputs_embeds is None and pixel_values is not None:
             # extract feature
+            self._mark_dynamic_once(pixel_values, [0])
            vit_embeds = self.extract_feature(pixel_values)
            lang_embeds = self.language_model.get_input_embeddings()(input_ids)
            lang_embeds.masked_scatter_(image_mask[..., None], vit_embeds)
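
Why _mark_dynamic_once exists: the number of image tiles in pixel_values varies per request, so dim 0 is marked dynamic on the first call to keep torch.compile from specializing on one tile count and recompiling later. A rough standalone illustration of that behavior; the extract function and shapes below are invented for the example and are not the commit's code:

import torch


def extract(pixel_values):
    # Stand-in for a compiled vision feature path; any shape-polymorphic fn works.
    return pixel_values.flatten(1).mean(-1)


compiled_extract = torch.compile(extract)

x = torch.randn(7, 3, 448, 448)      # first request: 7 image tiles
torch._dynamo.mark_dynamic(x, [0])   # declare dim 0 as varying between requests
out1 = compiled_extract(x)

y = torch.randn(13, 3, 448, 448)     # later request: 13 tiles
out2 = compiled_extract(y)           # reuses the dynamic-shape graph instead of recompiling
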
lmdeploy/pytorch/models/utils/micro_batch.py (new file; path per the import added in internvl.py)

Lines changed: 61 additions & 0 deletions

@@ -0,0 +1,61 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import functools
+
+import torch
+
+
+def enable_micro_batch(param_name, index=-1):
+    """Decorator factory to enable micro-batch computation."""
+
+    def decorator(func):
+
+        @functools.wraps(func)
+        def wrapper(self, *args, **kwargs):
+            if index != -1 and len(args) > index:
+                inputs = args[index]
+            else:
+                inputs = kwargs.get(param_name, None)
+
+            if isinstance(inputs, list):
+                # Apply forward computation to each micro-batch
+                results = []
+                for input in inputs:
+                    if index != -1 and len(args) > index:
+                        args = args[0:index] + (input, ) + args[index + 1:]
+                    else:
+                        kwargs[param_name] = input
+                    result = func(self, *args, **kwargs)
+                    results.append(result)
+                return results
+            else:
+                # If not a list, directly apply the forward computation
+                return func(self, *args, **kwargs)
+
+        return wrapper
+
+    return decorator
+
+
+def split_batch(func, param_name, index=-1, num_splits=2):
+    """Decorator to split along the 0th dimension into a specified number of
+    chunks."""
+
+    def wrapper(*args, **kwargs):
+        if index != -1 and len(args) > index:
+            inputs = args[index]
+        else:
+            inputs = kwargs.get(param_name, None)
+
+        if inputs is not None:
+            split_inputs = list(torch.chunk(inputs, num_splits, dim=0))
+            if index != -1 and len(args) > index:
+                args = args[0:index] + (split_inputs, ) + args[index + 1:]
+            else:
+                kwargs[param_name] = split_inputs
+
+            results = func(*args, **kwargs)
+            return torch.cat(results, dim=0)
+        else:
+            return func(*args, **kwargs)
+
+    return wrapper
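
How the two helpers compose in practice: split_batch wraps a caller-facing forward so a tensor argument is chunked into a list of micro-batches along dim 0, enable_micro_batch makes an inner method loop over that list, and the per-chunk results are concatenated back together. A small self-contained sketch; ToyEncoder is hypothetical, only the helper names and the forward-patching pattern come from this commit:

import torch

from lmdeploy.pytorch.models.utils.micro_batch import enable_micro_batch, split_batch


class ToyEncoder(torch.nn.Module):
    """Hypothetical encoder used only to demonstrate the helpers."""

    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(16, 16)

    @enable_micro_batch(param_name='inputs_embeds', index=0)
    def _layer(self, inputs_embeds):
        # Receives one micro-batch at a time when the input is a list of chunks.
        return self.proj(inputs_embeds)

    def forward(self, inputs_embeds):
        return self._layer(inputs_embeds)


encoder = ToyEncoder()
# Patch forward so a (B, ...) tensor is split into 2 micro-batches along dim 0
# and the per-chunk outputs are concatenated back into a single tensor.
encoder.forward = split_batch(encoder.forward, 'inputs_embeds', index=0)

out = encoder(torch.randn(4, 16))
assert out.shape == (4, 16)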

requirements/runtime_cuda.txt

Lines changed: 3 additions & 3 deletions

@@ -16,8 +16,8 @@ safetensors
 sentencepiece
 shortuuid
 tiktoken
-torch<=2.5.1,>=2.0.0
-torchvision<=0.20.1,>=0.15.0
+torch<=2.6.0,>=2.0.0
+torchvision<=0.21.0,>=0.15.0
 transformers
-triton<=3.1.0,>=3.0.0; sys_platform == "linux"
+triton<=3.2.0,>=3.0.0; sys_platform == "linux"
 uvicorn

0 commit comments