Zmz/prefill without permute by dlblas by hellozmz · Pull Request #3430 · InternLM/lmdeploy · GitHub


Zmz/prefill without permute by dlblas #3430


Closed
Changes from 1 commit
add args prefill_without_permute
hellozmz committed Apr 14, 2025
commit 47d9980ff9a5105cfc9168d5da95df09ba6c052d
11 changes: 11 additions & 0 deletions lmdeploy/cli/utils.py
@@ -512,6 +512,17 @@ def eager_mode(parser):
                                   help='Whether to enable eager mode. '
                                   'If True, cuda graph would be disabled')

    @staticmethod
    def prefill_without_permute(parser):
        """Add argument prefill_without_permute to parser."""
Collaborator comment: Regarding dlblas' option, I recommend using env variables.


        return parser.add_argument('--prefill-without-permute',
                                   action='store_true',
                                   default=False,
                                   help='Whether to enable prefill_without_permute. '
                                   'If True, the moe layer would not permute the input, '
                                   'and would not unpermute the output')
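A hedged sketch of the reviewer's env-variable suggestion, not part of this PR: the same switch could be read from the environment instead of a CLI flag. The variable name LMDEPLOY_PREFILL_WITHOUT_PERMUTE and the helper below are hypothetical.

import os

def _env_to_bool(name: str, default: bool = False) -> bool:
    # Hypothetical helper: parse a boolean switch from an environment variable.
    value = os.getenv(name)
    if value is None:
        return default
    return value.strip().lower() in ('1', 'true', 'yes', 'on')

# Hypothetical alternative to --prefill-without-permute:
prefill_without_permute = _env_to_bool('LMDEPLOY_PREFILL_WITHOUT_PERMUTE')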

    @staticmethod
    def communicator(parser):
        return parser.add_argument('--communicator',
3 changes: 3 additions & 0 deletions lmdeploy/messages.py
@@ -297,6 +297,8 @@ class PytorchEngineConfig:
            bit, set it to 4 or 8, respectively
        distributed_executor_backend (str): backend of distributed backend,
            options: ['uni', 'mp', 'ray']
        prefill_without_permute(bool): whether to use moe without permute.
            Default to False.
"""
dtype: str = 'auto'
tp: int = 1
Expand All @@ -321,6 +323,7 @@ class PytorchEngineConfig:
revision: str = None
quant_policy: Literal[0, 4, 8] = 0
distributed_executor_backend: str = None
prefill_without_permute: bool = False

    def __post_init__(self):
        """Check input validation."""
2 changes: 2 additions & 0 deletions lmdeploy/pytorch/backends/cuda/moe.py
@@ -6,6 +6,7 @@
import torch.distributed as dist

from lmdeploy.pytorch.backends.cuda.token_dispatcher import DeepEPTokenDispatcherLowLatency, TokenDispatcherBuilder
from lmdeploy.pytorch.distributed import prefill_without_permute
from lmdeploy.pytorch.kernels.cuda import fused_moe, fused_moe_w8a8
from lmdeploy.pytorch.kernels.cuda.blocked_fp8_fused_moe import fused_moe_blocked_fp8
from lmdeploy.pytorch.kernels.cuda.blocked_gemm_fp8 import quant_fp8
@@ -20,6 +21,7 @@
from ..moe import (FusedMoEBlockedF8Builder, FusedMoEBlockedF8Impl, FusedMoEBuilder, FusedMoEImpl, FusedMoEW8A8Builder,
                   FusedMoEW8A8Impl)

is_prefill_without_permute = prefill_without_permute()
logger = get_logger('lmdeploy')


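For context, a minimal standalone sketch (not lmdeploy code) of the permute/unpermute steps that the new option is meant to skip during prefill, following the help text of --prefill-without-permute:

import torch

tokens = torch.randn(6, 4)                   # [num_tokens, hidden]
topk_ids = torch.tensor([2, 0, 1, 0, 2, 1])  # routed expert per token (top-1 for brevity)

# Usual MoE path: permute, i.e. group the tokens by expert, before the per-expert GEMMs.
sort_idx = torch.argsort(topk_ids)
permuted = tokens[sort_idx]

expert_out = permuted * 2.0                  # stand-in for the per-expert computation

# ...then unpermute the outputs back to the original token order.
out = torch.empty_like(expert_out)
out[sort_idx] = expert_out

# With prefill_without_permute enabled, the dlblas prefill kernels are expected to
# consume the tokens in their original order, skipping both reorderings.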
1 change: 1 addition & 0 deletions lmdeploy/pytorch/config.py
@@ -97,6 +97,7 @@ class DistConfig:
    dp_rank: int = 0
    world_size: int = None
    attn_config: 'DistConfig' = None
    prefill_without_permute: bool = False

    def __post_init__(self):
        """post init."""
8 changes: 8 additions & 0 deletions lmdeploy/pytorch/distributed.py
@@ -29,6 +29,7 @@ class DistContext:
    ep_gpu_group: dist.ProcessGroup = None
    ep_gpu_groups: List[dist.ProcessGroup] = None
    dist_config: DistConfig = None
    prefill_without_permute: bool = False

    @classmethod
    def build(cls, rank: int = 0, dist_config: DistConfig = None, ccl_backend: str = 'nccl'):
@@ -44,6 +45,7 @@ def build(cls, rank: int = 0, dist_config: DistConfig = None, ccl_backend: str =
        ep = dist_config.ep
        world_size = dist_config.world_size
        dp_rank = dist_config.dp_rank
        prefill_without_permute = dist_config.prefill_without_permute

        if world_size == 1:
            return DistContext(dist_config=dist_config)
@@ -104,6 +106,7 @@ def build(cls, rank: int = 0, dist_config: DistConfig = None, ccl_backend: str =
            ep_gpu_group=ep_gpu_group,
            ep_gpu_groups=ep_gpu_groups,
            dist_config=dist_config,
            prefill_without_permute=prefill_without_permute,
        )
        return context

@@ -181,6 +184,11 @@ def get_ep_world_rank():
    return ctx.ep, ctx.ep_rank


def prefill_without_permute():
    ctx = get_dist_manager().current_context()
    return ctx.prefill_without_permute


def _check_group_device(device: str):
    """check group device."""
    assert (device in ['cpu', 'gpu']), ('Expect process group device in ("cpu", "gpu"), '
1 change: 1 addition & 0 deletions lmdeploy/pytorch/engine/engine.py
@@ -89,6 +89,7 @@ def _build_dist_config(engine_config: PytorchEngineConfig):
        tp=engine_config.tp,
        ep=engine_config.ep,
        dp_rank=engine_config.dp_rank,
        prefill_without_permute=engine_config.prefill_without_permute,
    )
    return dist_config
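A minimal usage sketch of the new option, assuming this branch is installed; the model path and ep value are placeholders.

from lmdeploy import pipeline, PytorchEngineConfig

# Enable the dlblas prefill-without-permute MoE path added in this PR.
engine_config = PytorchEngineConfig(
    ep=8,                         # expert parallel size; placeholder
    prefill_without_permute=True,
)
pipe = pipeline('/path/to/moe-model', backend_config=engine_config)
print(pipe('Hello'))

On the command line, the equivalent switch would be the new --prefill-without-permute flag, assuming the PR wires ArgumentHelper.prefill_without_permute into the relevant subcommands.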