support ascend w8a8 graph_mode (#3267) · InternLM/lmdeploy@c02dd78 · GitHub
Commit c02dd78

support ascend w8a8 graph_mode (#3267)
* support ascend w8a8 graph_mode
* support dlinfer smooth_quant
* update code
* add try_import_deeplink in utils
* remove pytorch module in serve
1 parent 0c1b6ee commit c02dd78
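Taken together, these changes let a W8A8 (smooth-quant) model run on Ascend through the PyTorch engine's graph mode. A minimal usage sketch, assuming a model already quantized into ./work_dir; the engine options below are standard lmdeploy settings and not part of this commit:

from lmdeploy import pipeline, PytorchEngineConfig

# Illustrative only: serve a smooth-quantized (W8A8) model on Ascend with
# graph mode enabled (eager_mode=False selects the graph execution path).
pipe = pipeline('./work_dir',
                backend_config=PytorchEngineConfig(device_type='ascend', eager_mode=False))
print(pipe('Hello, Ascend!'))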

14 files changed: +68 additions, -65 deletions

    lmdeploy/cli/lite.py

    Lines changed: 1 addition & 0 deletions
@@ -92,6 +92,7 @@ def add_parser_smooth_quant():
                             type=str,
                             default='./work_dir',
                             help='The working directory for outputs. defaults to "./work_dir"')
+        parser.add_argument('--device', type=str, default='cuda', help='Device for weight quantization (cuda or npu)')
         ArgumentHelper.calib_dataset(parser)
         ArgumentHelper.calib_samples(parser)
         ArgumentHelper.calib_seqlen(parser)
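With this flag, NPU-side weight quantization can be requested directly from the CLI, e.g. lmdeploy lite smooth_quant <model_path> --work-dir ./work_dir --device npu (the model path is a placeholder, and the --work-dir spelling is assumed from the option shown above).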

    lmdeploy/lite/apis/auto_awq.py

    Lines changed: 1 addition & 1 deletion
@@ -10,7 +10,7 @@
 
 from lmdeploy.lite.quantization.awq import FC_FCS_MAP, NORM_FCS_MAP, awq_layers, quant_weights, smooth_layers
 from lmdeploy.lite.utils import collect_target_modules
-from lmdeploy.pytorch.check_env import try_import_deeplink
+from lmdeploy.utils import try_import_deeplink
 
 from .calibrate import LAYER_TYPE_MAP, calibrate
 

    lmdeploy/lite/apis/smooth_quant.py

    Lines changed: 2 additions & 0 deletions
@@ -11,6 +11,7 @@
 from lmdeploy.lite.quantization.awq import FC_FCS_MAP, NORM_FCS_MAP, awq_layers, skipped_module, smooth_layers
 from lmdeploy.lite.utils import collect_target_modules
 from lmdeploy.pytorch.models import QLinear, QRMSNorm
+from lmdeploy.utils import try_import_deeplink
 
 
 def smooth_quant(model: str,
@@ -26,6 +27,7 @@ def smooth_quant(model: str,
                  quant_dtype: Literal['int8', 'fp8', 'float8_e4m3fn', 'float8_e5m2'] = 'int8',
                  revision: str = None,
                  download_dir: str = None):
+    try_import_deeplink(device)
     if quant_dtype == 'fp8':
         quant_dtype = 'float8_e4m3fn'
 
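Since device now reaches smooth_quant, the quantization step itself can run on an NPU. A hedged sketch of the corresponding Python call; argument names not visible in this diff, such as work_dir, are assumptions based on the CLI options:

from lmdeploy.lite.apis.smooth_quant import smooth_quant

# Sketch only: quantize on an Ascend NPU; try_import_deeplink(device) runs first
# and pulls in the dlinfer extension for that device.
smooth_quant('internlm/internlm2_5-7b-chat',   # placeholder model id
             work_dir='./work_dir',            # assumed to mirror the CLI --work-dir option
             device='npu',
             quant_dtype='int8')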

lmdeploy/pytorch/check_env/__init__.py

Lines changed: 0 additions & 13 deletions
@@ -1,14 +1 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from .base import BaseChecker  # noqa: F401
-
-
-def check_env_deeplink(device_type: str):
-    """check Deeplink environment."""
-    from .deeplink import DeeplinkChecker
-    checker = DeeplinkChecker(device_type)
-    checker.handle()
-
-
-def try_import_deeplink(device_type: str):
-    """check Deeplink environment."""
-    check_env_deeplink(device_type)
lmdeploy/pytorch/check_env/deeplink.py

Lines changed: 3 additions & 13 deletions
@@ -1,12 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from .base import BaseChecker
+from lmdeploy.utils import try_import_deeplink
 
-deeplink_device_type_list = [
-    'ascend',
-    'npu',
-    'maca',
-    'camb',
-]
+from .base import BaseChecker
 
 
 class DeeplinkChecker(BaseChecker):
@@ -18,9 +13,4 @@ def __init__(self, device_type: str, logger=None) -> None:
 
     def check(self):
         """check."""
-        device_type = self.device_type
-        if device_type in deeplink_device_type_list:
-            try:
-                import dlinfer.framework.lmdeploy_ext  # noqa: F401
-            except Exception as e:
-                self.log_and_exit(e, 'dlinfer', 'dlinfer is not available.')
+        try_import_deeplink(self.device_type)
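The relocated helper itself is not part of this excerpt. A plausible minimal form of lmdeploy.utils.try_import_deeplink, reconstructed from the logic deleted above; treat it as a sketch, not the actual implementation:

# Sketch: approximate shape of try_import_deeplink after the move to lmdeploy.utils,
# based on the device list and import logic removed from check_env above.
deeplink_device_type_list = ['ascend', 'npu', 'maca', 'camb']


def try_import_deeplink(device_type: str):
    """Import the dlinfer extension when targeting a dlinfer-backed device."""
    if device_type in deeplink_device_type_list:
        try:
            import dlinfer.framework.lmdeploy_ext  # noqa: F401
        except ImportError as e:
            raise RuntimeError('dlinfer is not available.') from e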

    lmdeploy/pytorch/kernels/cuda/__init__.py

    Lines changed: 2 additions & 3 deletions
@@ -1,5 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-
+from ..default.w8a8_kernels import per_channel_quant
 from .alibi_pagedattention import alibi_paged_attention_fwd
 from .apply_rotary_pos_emb import apply_rotary_pos_emb
 from .fill_kv_cache import fill_kv_cache
@@ -12,8 +12,7 @@
 from .pagedattention import paged_attention_fwd
 from .rms_norm import rms_norm
 from .w8a8_fused_moe import fused_moe_w8a8
-from .w8a8_triton_kernels import (matmul_kernel_dynamic_quant, per_channel_quant, per_token_quant_int8,
-                                  rms_norm_dynamic_quant)
+from .w8a8_triton_kernels import matmul_kernel_dynamic_quant, per_token_quant_int8, rms_norm_dynamic_quant
 
 __all__ = [
     'apply_rotary_pos_emb',

    lmdeploy/pytorch/kernels/cuda/w8a8_triton_kernels.py

    Lines changed: 1 addition & 28 deletions
@@ -5,6 +5,7 @@
 import triton.language as tl
 from packaging import version
 
+from ..default.w8a8_kernels import per_channel_quant
 from .triton_utils import get_kernel_meta
 
 TRITON_VERSION = version.parse(triton.__version__)
@@ -14,34 +15,6 @@
 tl_round = tl.math.round
 
 
-def per_channel_quant(x: torch.Tensor, dtype: torch.dtype):
-    """Quantize the input tensor 'x' channel-wise using the given number of
-    bits.
-
-    Args:
-        x (torch.Tensor): The input tensor to be quantized. Must be a
-            2-dimensional tensor.
-        dtype (torch.dtype): The data type to which the quantized tensor should
-            be converted.
-
-    Returns:
-        tuple: A tuple containing two items -- the quantized tensor and
-            the scale used for quantization.
-    """
-    assert x.ndim == 2
-    x = x.to(torch.float32)
-    x_absmax = x.view(x.shape[0], -1).abs().max(dim=1, keepdim=True)[0]
-    qtype_info = torch.finfo(dtype) if dtype.is_floating_point else torch.iinfo(dtype)
-    q_max = qtype_info.max
-    q_min = qtype_info.min
-    scale = x_absmax / q_max
-    x_q = x / scale
-    if not dtype.is_floating_point:
-        x_q = torch.round(x_q)
-    x_q = x_q.clamp(q_min, q_max).to(dtype)
-    return x_q, scale
-
-
 @triton.autotune(
     configs=[
         triton.Config({
lmdeploy/pytorch/kernels/default/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -1,6 +1,8 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .multinomial_sampling import multinomial_sampling
+from .w8a8_kernels import per_channel_quant
 
 __all__ = [
     'multinomial_sampling',
+    'per_channel_quant',
 ]
lmdeploy/pytorch/kernels/default/w8a8_kernels.py

Lines changed: 30 additions & 0 deletions
@@ -0,0 +1,30 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+
+
+def per_channel_quant(x: torch.Tensor, dtype: torch.dtype):
+    """Quantize the input tensor 'x' channel-wise using the given number of
+    bits.
+
+    Args:
+        x (torch.Tensor): The input tensor to be quantized. Must be a
+            2-dimensional tensor.
+        dtype (torch.dtype): The data type to which the quantized tensor should
+            be converted.
+
+    Returns:
+        tuple: A tuple containing two items -- the quantized tensor and
+            the scale used for quantization.
+    """
+    assert x.ndim == 2
+    x = x.to(torch.float32)
+    x_absmax = x.view(x.shape[0], -1).abs().max(dim=1, keepdim=True)[0]
+    qtype_info = torch.finfo(dtype) if dtype.is_floating_point else torch.iinfo(dtype)
+    q_max = qtype_info.max
+    q_min = qtype_info.min
+    scale = x_absmax / q_max
+    x_q = x / scale
+    if not dtype.is_floating_point:
+        x_q = torch.round(x_q)
+    x_q = x_q.clamp(q_min, q_max).to(dtype)
+    return x_q, scale
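A quick illustration of the function's contract; the import path is inferred from this commit and the shapes are made-up example values:

import torch
from lmdeploy.pytorch.kernels.default import per_channel_quant

w = torch.randn(4096, 11008)                    # any 2-D weight matrix
w_q, scale = per_channel_quant(w, torch.int8)   # int8 weights plus one float32 scale per row
assert w_q.dtype == torch.int8 and scale.shape == (4096, 1)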

    lmdeploy/pytorch/kernels/dispatcher.py

    Lines changed: 1 addition & 0 deletions
@@ -64,6 +64,7 @@ def __init__(self, func_name: str):
         self.func_name = func_name
         self.dispatched_func = self.load_and_call
         self.device_manager.register_context_callback(self.device_callback)
+        self.device_map = {'cuda': 'cuda', 'ascend': 'dlinfer', 'npu': 'dlinfer', 'maca': 'dlinfer', 'camb': 'dlinfer'}
 
     def device_callback(self, context: DeviceContext):
         """device context callback."""

    0 commit comments
