[WIP] support megatron by Jintao-Huang · Pull Request #2885 · modelscope/ms-swift · GitHub
[go: up one dir, main page]

Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] support megatron #2885

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
update
  • Loading branch information
Jintao-Huang committed Jan 9, 2025
commit f230e01875805b1aa5006b91b6a405bd4cdc7fbb
3 changes: 2 additions & 1 deletion swift/llm/argument/export_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
import os
from dataclasses import dataclass
from typing import Literal, Optional
import torch.distributed as dist

import torch
import torch.distributed as dist

from swift.utils import get_logger
from .base_args import BaseArguments, to_abspath
Expand Down
3 changes: 1 addition & 2 deletions swift/megatron/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates.


from .utils import init_megatron_env
from .utils import init_megatron_env

init_megatron_env()
4 changes: 2 additions & 2 deletions swift/megatron/argument.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from dataclasses import dataclass, field, asdict
from typing import Optional, Literal, Dict, Any, Tuple, List
import sys
from dataclasses import asdict, dataclass, field
from typing import Any, Dict, List, Literal, Optional, Tuple


@dataclass
Expand Down
9 changes: 2 additions & 7 deletions swift/megatron/convert/hf2megatron.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import torch
from swift.llm import get_model_tokenizer, ExportArguments

from swift.llm import ExportArguments, get_model_tokenizer
from ..model import get_megatron_model_meta


def convert_hf2megatron(
args: ExportArguments
) -> None:
def convert_hf2megatron(args: ExportArguments) -> None:

from megatron.training.initialize import initialize_megatron
from megatron.training import get_args
Expand All @@ -19,7 +17,6 @@ def convert_hf2megatron(
megatron_model_meta.get_model_provider()
megatron_model_meta.load_config(hf_model.model_info)


initialize_megatron(args_defaults=extra_args)
args = get_args()
model_provider, convert_module = get_megatron_model_convert(args.model_type)
Expand All @@ -28,5 +25,3 @@ def convert_hf2megatron(
if save_torch_dtype is not None:
mg_model.to(save_torch_dtype)
convert_module.save_mgmodel(mg_model, args)


6 changes: 1 addition & 5 deletions swift/megatron/model/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,2 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from .register import (
register_megatron_model, get_megatron_model_meta, MegatronModelMeta
)


from .register import MegatronModelMeta, get_megatron_model_meta, register_megatron_model
4 changes: 3 additions & 1 deletion swift/megatron/model/config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Any, Dict

from swift.llm import ModelInfo
from typing import Dict, Any

config_mapping = {
'num_layers': ['num_hidden_layers'],
Expand All @@ -14,6 +15,7 @@
'attention_dropout': ['attention_dropout']
}


def load_config(model_info: ModelInfo) -> Dict[str, Any]:
model_config = model_info.config
megatron_config = {}
Expand Down
1 change: 0 additions & 1 deletion swift/megatron/model/constant.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,2 @@

class MegatronModelType:
qwen = 'qwen'
22 changes: 10 additions & 12 deletions swift/megatron/model/qwen.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,28 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

from swift.llm import ModelInfo, ModelGroup, Model
from .register import register_megatron_model, MegatronModelMeta
from .utils import get_model_provider
from .constant import MegatronModelType
from swift.llm import Model, ModelGroup, ModelInfo
from .config import load_config
from .constant import MegatronModelType
from .register import MegatronModelMeta, register_megatron_model
from .utils import get_model_provider


def load_qwen_config(model_info: ModelInfo):
args_config = load_config(model_info)
args_config['swiglu'] = True
return args_config


def convert_megatron2hf():
pass


def convert_hf2megatron():
pass


register_megatron_model(MegatronModelMeta(
MegatronModelType.qwen,[
register_megatron_model(
MegatronModelMeta(MegatronModelType.qwen, [
ModelGroup([
Model('Qwen/Qwen2.5-0.5B-Instruct', 'Qwen/Qwen2.5-0.5B-Instruct'),
Model('Qwen/Qwen2.5-1.5B-Instruct', 'Qwen/Qwen2.5-1.5B-Instruct'),
Expand All @@ -29,9 +32,4 @@ def convert_hf2megatron():
Model('Qwen/Qwen2.5-32B-Instruct', 'Qwen/Qwen2.5-32B-Instruct'),
Model('Qwen/Qwen2.5-72B-Instruct', 'Qwen/Qwen2.5-72B-Instruct'),
]),
],
convert_megatron2hf,
convert_hf2megatron,
get_model_provider,
load_qwen_config
))
], convert_megatron2hf, convert_hf2megatron, get_model_provider, load_qwen_config))
5 changes: 3 additions & 2 deletions swift/megatron/model/register.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Callable, List, Optional
from dataclasses import dataclass
from typing import Callable, List, Optional

from swift.llm import ModelGroup
from swift.llm.model.register import _get_matched_model_meta

Expand All @@ -17,6 +18,7 @@ class MegatronModelMeta:
get_model_provider: Callable
load_config: Callable


def register_megatron_model(model_meta: MegatronModelMeta, *, exist_ok: bool = False):
megatron_model_type = model_meta.megatron_model_type
if not exist_ok and megatron_model_type in MEGATRON_MODEL_MAPPING:
Expand All @@ -27,4 +29,3 @@ def register_megatron_model(model_meta: MegatronModelMeta, *, exist_ok: bool = F

def get_megatron_model_meta(model_id_or_path: str) -> Optional[MegatronModelMeta]:
return _get_matched_model_meta(model_id_or_path, MEGATRON_MODEL_MAPPING)

4 changes: 3 additions & 1 deletion swift/megatron/model/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Copyright (c) Alibaba, Inc. and its affiliates.


def get_model_provider(gpt_model_cls, transformer_config_cls, layer_spec_module):

def model_provider(pre_process=True, post_process=True):
from megatron.training import get_args
from megatron.training.arguments import core_transformer_config_from_args
Expand All @@ -23,5 +25,5 @@ def model_provider(pre_process=True, post_process=True):
rotary_base=args.rotary_base,
seq_len_interpolation_factor=args.rotary_seq_len_interpolation_factor)
return model
return model_provider

return model_provider
13 changes: 5 additions & 8 deletions swift/megatron/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,8 @@
import torch.distributed as dist

from swift.llm import LazyLLMDataset, Template, git_clone_github
from swift.utils import (
append_to_jsonl, get_dist_setting, get_logger, is_master, subprocess_run,
is_megatron_available, safe_ddp_context)

from swift.utils import (append_to_jsonl, get_dist_setting, get_logger, is_master, is_megatron_available,
safe_ddp_context, subprocess_run)

logger = get_logger()

Expand All @@ -33,17 +31,15 @@ def _rename_files():

def init_megatron_env() -> None:
if 'MEGATRON_LM_PATH' not in os.environ:
megatron_path = git_clone_github(
'https://github.com/NVIDIA/Megatron-LM', branch='core_r0.10.0')
megatron_path = git_clone_github('https://github.com/NVIDIA/Megatron-LM', branch='core_r0.10.0')
else:
megatron_path = os.environ['MEGATRON_LM_PATH']
if not is_megatron_available():
subprocess_run(['pip', 'install', '-e', megatron_path])
sys.path.append(megatron_path)

if 'PAI_MEGATRON_PATCH_PATH' not in os.environ:
megatron_patch_path = git_clone_github(
'https://github.com/alibaba/Pai-Megatron-Patch', commit_hash='v0.10.1')
megatron_patch_path = git_clone_github('https://github.com/alibaba/Pai-Megatron-Patch', commit_hash='v0.10.1')
else:
megatron_patch_path = os.environ['PAI_MEGATRON_PATCH_PATH']
sys.path.append(megatron_patch_path)
Expand All @@ -52,6 +48,7 @@ def init_megatron_env() -> None:
with safe_ddp_context():
_rename_files()


def patch_megatron(tokenizer):

def build_tokenizer(args):
Expand Down
Loading