@@ -4,11 +4,14 @@

 import torch
 import torch.nn.functional as F
+from packaging import version
 from torch import nn
 from transformers.configuration_utils import PretrainedConfig

+from lmdeploy.pytorch.distributed import get_world_rank
 from lmdeploy.pytorch.engine.input_process import BaseModelInputProcessor, PreprocessInputResult
 from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
+from lmdeploy.pytorch.models.utils.micro_batch import enable_micro_batch, split_batch
 from lmdeploy.pytorch.multimodal.data_type import MultiModalTensor
 from lmdeploy.pytorch.nn import LayerNorm, RMSNorm
 from lmdeploy.pytorch.nn.linear import build_colwise_linear, build_qkv_proj, build_rowwise_linear
@@ -205,15 +208,22 @@ def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device:
         self.ls1 = nn.Parameter(torch.empty(self.embed_dim, dtype=dtype, device=device))
         self.ls2 = nn.Parameter(torch.empty(self.embed_dim, dtype=dtype, device=device))

-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-    ):
-        """forward."""
-        hidden_states = hidden_states + self.attn(self.norm1(hidden_states).to(hidden_states.dtype)) * self.ls1
+    @enable_micro_batch(param_name='hidden_states', index=0)
+    def _attn(self, hidden_states):
+        hidden_states = hidden_states + self.attn(self.norm1(hidden_states).to(hidden_states[0].dtype)) * self.ls1
+        return hidden_states

+    @enable_micro_batch(param_name='hidden_states', index=0)
+    def _mlp(self, hidden_states):
         hidden_states = hidden_states + self.mlp(self.norm2(hidden_states).to(hidden_states.dtype)) * self.ls2
+        return hidden_states

+    def forward(
+        self,
+        hidden_states,
+    ):
+        hidden_states = self._attn(hidden_states)
+        hidden_states = self._mlp(hidden_states)
         return hidden_states


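The `_attn` / `_mlp` split above exists so that `enable_micro_batch` can run each residual branch on smaller slices of the vision batch. The actual decorator lives in `lmdeploy.pytorch.models.utils.micro_batch` and is not part of this diff; the sketch below is only a guess at the splitting behaviour it implies (chunk the named tensor argument along dim 0, run the wrapped function per chunk, concatenate the outputs). It ignores `index` by resolving the argument by name, and the real implementation presumably also gates splitting on whether micro-batching is actually enabled.

```python
# Hedged sketch only: the real enable_micro_batch / split_batch are in
# lmdeploy.pytorch.models.utils.micro_batch and are not shown in this PR.
import functools
import inspect

import torch


def split_batch(fn, param_name, index=0, num_splits=2):
    """Chunk the tensor argument `param_name` along dim 0 and re-concatenate
    the per-chunk outputs (assumes `fn` returns a tensor)."""
    sig = inspect.signature(fn)

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        bound = sig.bind(*args, **kwargs)
        batch = bound.arguments[param_name]
        outputs = []
        for chunk in torch.chunk(batch, num_splits, dim=0):
            bound.arguments[param_name] = chunk
            outputs.append(fn(*bound.args, **bound.kwargs))
        return torch.cat(outputs, dim=0)

    return wrapper


def enable_micro_batch(param_name, index=0):
    """Decorator-factory form used on _attn/_mlp above; `index` is kept only
    to mirror the real call sites, this sketch looks the argument up by name."""
    def decorator(fn):
        return split_batch(fn, param_name, index)
    return decorator
```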
@@ -306,6 +316,33 @@ def __init__(self,

         self.input_processor = InternVLInputProcessor(self.config, dtype)

+        self.compile_vit = False
+
+    def compile_model(self):
+        torch_version = version.parse(torch.__version__)
+        if torch_version < version.parse('2.5.0'):
+            return
+
+        world_size, _ = get_world_rank()
+        if torch_version >= version.parse('2.6.0') and world_size > 1:
+            torch._inductor.config.reorder_for_compute_comm_overlap = True
+        if isinstance(self.vision_model, InternVisionModel):
+            self.vision_model.encoder.forward = split_batch(self.vision_model.encoder.forward,
+                                                            'inputs_embeds',
+                                                            index=0)
+
+        self.extract_feature = torch.compile(self.extract_feature, mode='max-autotune')
+        self.compile_vit = True
+        self.has_compiled_vit = False
+
+    def _mark_dynamic_once(self, pixel_values, dims):
+        """call torch._dynamo.mark_dynamic to avoid recompile."""
+        if not self.compile_vit or self.has_compiled_vit or pixel_values is None:
+            return
+
+        torch._dynamo.mark_dynamic(pixel_values, dims)
+        self.has_compiled_vit = True
+
     def pixel_shuffle(self, x, scale_factor=0.5):
         n, w, h, c = x.size()
         # N, W, H, C --> N, W, H * scale, C // scale
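`_mark_dynamic_once` is needed because `torch.compile` specializes the compiled graph on the first shapes it traces; since the number of image tiles in `pixel_values` varies per request, marking dim 0 dynamic once avoids a recompile for every new tile count. Below is a minimal, self-contained illustration of the two real APIs involved (`torch.compile` and `torch._dynamo.mark_dynamic`); the function and shapes are invented for the example, and the PR itself only enables compilation on PyTorch >= 2.5.0.

```python
# Illustrative only: shows why marking the tile dimension dynamic avoids
# recompilation when the number of image tiles changes between requests.
import torch
import torch.nn.functional as F


def vit_like(x):
    # Stand-in for extract_feature: any shape-polymorphic tensor op works here.
    return F.gelu(x @ x.transpose(-1, -2))


compiled = torch.compile(vit_like, mode='max-autotune')

x = torch.randn(3, 16, 16)        # first request: 3 image tiles
torch._dynamo.mark_dynamic(x, 0)  # mark the tile dim as dynamic, done once
compiled(x)

y = torch.randn(7, 16, 16)        # later request with 7 tiles: no recompile
compiled(y)
```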
@@ -350,6 +387,7 @@ def forward(
     ):
         if inputs_embeds is None and pixel_values is not None:
             # extract feature
+            self._mark_dynamic_once(pixel_values, [0])
             vit_embeds = self.extract_feature(pixel_values)
             lang_embeds = self.language_model.get_input_embeddings()(input_ids)
             lang_embeds.masked_scatter_(image_mask[..., None], vit_embeds)
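Downstream of the compiled `extract_feature`, the forward path splices the vision embeddings into the language embeddings with `masked_scatter_`. A toy example with made-up shapes showing that in-place scatter (each True position in `image_mask` receives the next row of `vit_embeds`):

```python
# Toy shapes only; in the model these come from the tokenizer / ViT outputs.
import torch

lang_embeds = torch.zeros(1, 6, 4)                                     # (batch, seq, hidden)
image_mask = torch.tensor([[False, True, True, False, True, False]])   # 3 image tokens
vit_embeds = torch.arange(12, dtype=torch.float).reshape(3, 4)         # one row per image token

# The mask broadcasts to (1, 6, 4); True positions are filled in order from vit_embeds.
lang_embeds.masked_scatter_(image_mask[..., None], vit_embeds)
print(lang_embeds[0, 1])  # tensor([0., 1., 2., 3.])
```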