Code for the paper "Merge, Then Compress: Demystify Efficient SMoE with Hints from Its Routing Policy"
- Authors: Pingzhi Li, Zhenyu Zhang, Prateek Yadav, Yi-Lin Sung, Yu Cheng, Mohit Bansal, and Tianlong Chen
- Paper: arXiv:2310.01334
Update:
- 🚀 MC-SMoE now supports Mixtral-8x7B!
Sparsely activated Mixture-of-Experts (SMoE) has shown promise in scaling up the learning capacity of neural networks. However, SMoE models suffer from (a) high memory usage, because the FFN layers are duplicated into multiple copies as experts, and (b) redundancy among experts, since common learning-based routing policies are prone to representational collapse. MC-SMoE therefore first merges experts into fewer but more knowledgeable ones, guided by their routing policies, and then further compresses the merged experts with a low-rank plus sparse decomposition.
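The merging stage decides which experts to consolidate from routing statistics; in the commands below this is controlled by `--similarity_base="router-logits"` and `--num_groups`. As a rough, hedged sketch of the idea (not the repository's implementation, which lives in `mcsmoe/merging/grouping.py`; the function name and tensor shapes are assumptions), experts can be grouped by how similar their router logits are over a batch of tokens:

```python
import torch

def group_experts_by_router_logits(router_logits: torch.Tensor, num_groups: int):
    """Group experts whose routing behaviour is similar.

    router_logits: [num_tokens, num_experts] logits collected from the trained router.
    Returns a list of expert-index groups, one dominant expert per group.
    Illustrative sketch only; the real grouping logic is in mcsmoe/merging/grouping.py.
    """
    num_experts = router_logits.shape[1]

    # Cosine similarity between experts, measured over the same batch of tokens.
    per_expert = torch.nn.functional.normalize(router_logits.t(), dim=1)  # [E, T]
    sim = per_expert @ per_expert.t()                                     # [E, E]

    # Treat the most frequently selected experts as group centers ("dominant" experts).
    usage = router_logits.argmax(dim=1).bincount(minlength=num_experts)
    centers = usage.topk(num_groups).indices.tolist()

    # Assign every remaining expert to the center it is most similar to.
    groups = {c: [c] for c in centers}
    for e in range(num_experts):
        if e not in centers:
            best = max(centers, key=lambda c: sim[e, c].item())
            groups[best].append(e)
    return list(groups.values())
```

Choosing the most frequently routed experts as group centers reflects the premise in the title: the routing policy hints at which experts are worth keeping.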
Set up the environment:

```bash
conda create -n mcsmoe python=3.9 -y && conda activate mcsmoe
pip install -r requirements.txt
```
Fine-tune a full Switch Transformers SMoE checkpoint (here with 32 experts) on a downstream task, e.g. COPA:

```bash
accelerate launch --config_file static/finetune_config.yaml \
  mcsmoe/finetune-switch-transformers.py \
  --per_device_train_batch_size=8 \
  --per_device_eval_batch_size=64 \
  --gradient_accumulation_steps=1 \
  --num_epochs=20 \
  --no_eval_until_epochs=1 \
  --save_each_epoch=False \
  --preprocessing_num_workers=8 \
  --num_experts=32 \
  --task="copa" \
  --learning_rate=3e-5 \
  --warmup_steps=16 \
  --output_dir="results/copa/switch-32e"
```
Permute (align) the experts of the fine-tuned checkpoint before merging:

```bash
python -u mcsmoe/permute-model.py \
  --checkpoint="results/copa/switch-32e" \
  --save_dir="results/copa/switch-32e-permuted"
```
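Permutation matters because merging averages weights element-wise, which is only meaningful once the experts' hidden units are put in a common order. A hedged sketch of one way to compute such an alignment with an assignment solver (the function and the two-matrix FFN layout are illustrative assumptions, not the logic of `mcsmoe/permute-model.py`):

```python
import torch
from scipy.optimize import linear_sum_assignment

def align_expert_to_reference(w_in_ref, w_in, w_out):
    """Permute one FFN expert's hidden units to best match a reference expert.

    w_in_ref, w_in: [d_ff, d_model] input projections (reference / expert to align).
    w_out:          [d_model, d_ff] output projection of the expert to align.
    Returns the permuted (w_in, w_out); the expert still computes the same function.
    Illustrative only; see mcsmoe/permute-model.py for the actual procedure.
    """
    # Similarity between the hidden units of the two experts.
    a = torch.nn.functional.normalize(w_in_ref, dim=1)
    b = torch.nn.functional.normalize(w_in, dim=1)
    cost = -(a @ b.t()).detach().cpu().numpy()   # maximizing similarity = minimizing cost

    _, col = linear_sum_assignment(cost)         # col[i]: unit of `w_in` matched to reference unit i
    perm = torch.as_tensor(col, dtype=torch.long)
    return w_in[perm], w_out[:, perm]
```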
Merge the experts with M-SMoE, using the full permuted SMoE model as both the KD teacher and the initialization of the student:

```bash
accelerate launch --config_file static/finetune_config.yaml \
  mcsmoe/msmoe-merging.py \
  --per_device_train_batch_size=16 \ # ======== training arguments from here ========
  --per_device_eval_batch_size=256 \
  --gradient_accumulation_steps=1 \
  --preprocessing_num_workers=8 \
  --num_epochs=10 \
  --num_eval_steps=100 \
  --learning_rate=3e-5 \
  --warmup_steps=16 \
  --weight_decay=0.01 \
  --kd_temperature=2 \
  --mlm_lambda=1.0 \
  --kd_lambda=0.2 \
  --task="copa" \ # ======== merging arguments from here ========
  --num_samples_for_merging=256 \
  --similarity_base="router-logits" \ # for all available options, refer to LEGAL_SIMILARITY_BASES in mcsmoe/merging/grouping.py
  --num_groups=8 \ # average number of experts per SMoE layer after merging
  --globally_group=True \ # if True, apply an adaptive merging ratio for each SMoE layer
  --save_stable_rank=False \ # whether to save the stable rank of each expert for analysis
  --encoder_merging_layers="3,5,7,9,11" \ # encoder layer indices to be merged
  --decoder_merging_layers="1,3,5,7,9,11" \ # decoder layer indices to be merged
  --output_dir="results/copa/merged/" \ # the M-SMoE checkpoint will be saved here
  --teacher_checkpoint="results/copa/switch-32e-permuted" \ # KD teacher checkpoint (full SMoE)
  --student_checkpoint="results/copa/switch-32e-permuted" # KD student checkpoint, to be merged by M-SMoE
```
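Within each group found by the similarity step, the member experts are collapsed into a single expert. A minimal sketch of a usage-weighted average of the experts' FFN weights, in line with the routing-policy hints; the weight names and the exact merging coefficients are assumptions rather than the repository's code:

```python
import torch

def merge_expert_group(expert_weights, usage_counts):
    """Merge a group of experts into one by usage-weighted averaging.

    expert_weights: list of dicts, each {"wi": [d_ff, d_model], "wo": [d_model, d_ff]}.
    usage_counts:   how often the router selected each expert (same order).
    Returns one merged weight dict. Illustrative sketch, not the exact M-SMoE code.
    """
    w = torch.tensor(usage_counts, dtype=torch.float32)
    w = w / w.sum()                                    # normalized merging coefficients
    merged = {}
    for name in expert_weights[0]:
        stacked = torch.stack([e[name] for e in expert_weights])     # [group_size, ...]
        coeffs = w.view(-1, *([1] * (stacked.dim() - 1)))            # broadcast over weight dims
        merged[name] = (coeffs * stacked).sum(dim=0)
    return merged
```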
Finally, compress the merged model into the MC-SMoE checkpoint, again distilling from the full SMoE teacher:

```bash
accelerate launch --config_file static/finetune_config.yaml \
  --main_process_port 29510 mcsmoe/losparse-downstream.py \
  --per_device_train_batch_size=16 \ # ======== training arguments from here ========
  --per_device_eval_batch_size=256 \
  --gradient_accumulation_steps=1 \
  --preprocessing_num_workers=8 \
  --num_epochs=50 \
  --num_eval_steps=100 \
  --learning_rate=3e-5 \
  --warmup_steps=50 \
  --weight_decay=0.01 \
  --kd_temperature=2 \
  --mlm_lambda=1.0 \
  --kd_lambda=0.2 \
  --hd_lambda=0.0 \
  --task="copa" \ # ======== compression arguments from here ========
  --output_dir="results/copa/switch-32e-merged-8e-compressed/" \ # the MC-SMoE checkpoint will be saved here
  --teacher_checkpoint="results/copa/switch-32e-permuted" \ # KD teacher checkpoint (full SMoE)
  --student_checkpoint="results/copa/switch-32e-merged-8e" \ # M-SMoE checkpoint, to be further compressed by MC-SMoE
  --final_threshold=0.10 \ # average remaining ratio of the S matrices after compression
  --low_rank_factor=32 # low-rank factor for the U, V matrices in the decomposition
```
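The compression stage factorizes each remaining weight matrix into a low-rank product plus a sparse residual, which is what `--low_rank_factor` and `--final_threshold` (the kept fraction of the S matrix) control. A hedged sketch of such a decomposition follows; how the flag maps to a concrete rank is an assumption here:

```python
import torch

def low_rank_plus_sparse(weight: torch.Tensor, low_rank_factor: int, final_threshold: float):
    """Decompose `weight` (d_out x d_in) into U @ V + S.

    U @ V is a rank-r approximation with r = min(d_out, d_in) // low_rank_factor (assumed
    mapping of the flag), and S keeps only the largest `final_threshold` fraction of the
    residual entries. Illustrative sketch of the idea behind losparse-downstream.py.
    """
    rank = max(1, min(weight.shape) // low_rank_factor)

    # Low-rank part via truncated SVD.
    u, s, vh = torch.linalg.svd(weight, full_matrices=False)
    U = u[:, :rank] * s[:rank]            # fold singular values into U
    V = vh[:rank, :]

    # Sparse residual: keep the top `final_threshold` fraction of entries by magnitude.
    residual = weight - U @ V
    k = int(final_threshold * residual.numel())
    S = torch.zeros_like(residual)
    if k > 0:
        idx = residual.abs().flatten().topk(k).indices
        S.view(-1)[idx] = residual.view(-1)[idx]
    return U, V, S
```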
Please refer to `scripts/t5` and `scripts/gpt` for more examples (e.g., baselines and ablations).
- Optimizer: AdamW
- Adam $\epsilon$: 1e-6
- Adam $\beta$: (0.9, 0.98)
- Warm-up steps: 16
- Weight decay: 0.01
- LR scheduler: linear decay
- KD $\alpha$: 0.2
- KD $T$: 2.0
| Task | Batch size | Learning rate |
|---|---|---|
| SST-2 | | |
| MRPC | | |
| MultiRC | | |
| COPA | | |
| WinoGrande | | |
| SQuAD | | |
| WikiQA | | |
| HotpotQA | | |
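For reference, the distillation objective implied by the `--mlm_lambda`, `--kd_lambda`, and `--kd_temperature` flags combines the task loss with a temperature-scaled KL term against the frozen teacher. A minimal sketch with the default values from the list above (the function signature and the plain cross-entropy task loss are illustrative):

```python
import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels,
                      kd_lambda=0.2, mlm_lambda=1.0, kd_temperature=2.0):
    """Task loss plus temperature-scaled KL divergence to the frozen teacher.

    student_logits, teacher_logits: [batch, num_classes]; labels: [batch] class indices.
    Illustrative sketch of the KD objective configured by the flags above.
    """
    task_loss = F.cross_entropy(student_logits, labels)
    T = kd_temperature
    kd_loss = F.kl_div(
        F.log_softmax(student_logits / T, dim=-1),
        F.softmax(teacher_logits.detach() / T, dim=-1),
        reduction="batchmean",
    ) * (T ** 2)                          # rescale by T^2, standard KD practice
    return mlm_lambda * task_loss + kd_lambda * kd_loss
```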
Citation:

```bibtex
@misc{li2023merge,
      title={Merge, Then Compress: Demystify Efficient SMoE with Hints from Its Routing Policy},
      author={Pingzhi Li and Zhenyu Zhang and Prateek Yadav and Yi-Lin Sung and Yu Cheng and Mohit Bansal and Tianlong Chen},
      year={2023},
      eprint={2310.01334},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}
```