Describe the bug
Not long ago, a community PR made the Qwen3 MoE model a ZeRO3 leaf module by default: #7604
For a leaf module, ZeRO3 manages all parameters inside the module as one unit, so they must be gathered and released in a coordinated way.
When a model is initialized with DeepSpeed, hooks are registered on each module's forward and backward passes. fetch_sub_module, which shows up in the error traces, is one of these hooks: it gathers every parameter the module needs before the forward or backward computation runs. For a leaf module it collects the parameters of all submodules recursively, so it is reasonable to treat MoE blocks as leaf modules.
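For context, this is roughly how a module class gets marked as a ZeRO3 leaf module. The sketch below is illustrative only: it uses DeepSpeed's public set_z3_leaf_modules helper and assumes the usual transformers import path for Qwen3MoeSparseMoeBlock and the "Qwen/Qwen3-30B-A3B" checkpoint name; the PR above applies this by default inside DeepSpeed rather than requiring an explicit call.

from deepspeed.utils import set_z3_leaf_modules
from transformers import AutoModelForCausalLM
from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeSparseMoeBlock

# Assumed checkpoint name; any Qwen3 MoE model would do for illustration.
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-30B-A3B")

# Treat every Qwen3MoeSparseMoeBlock as a single unit: ZeRO3's fetch_sub_module
# gathers (and later releases) all of the block's expert parameters together.
set_z3_leaf_modules(model, [Qwen3MoeSparseMoeBlock])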
However, the forward pass of Qwen3-30B-A3B's MoE block returns two tensors, final_hidden_states and router_logits.
When autograd runs the backward pass with multiple threads, the hook registered for Qwen3MoeSparseMoeBlock is therefore triggered twice, and the two invocations can execute concurrently on different threads, which leads to abnormal internal state.
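As a rough illustration of why the two outputs matter (this is plain PyTorch, not DeepSpeed's actual hook machinery): if the same pre-backward callback is attached to every output tensor of a module, a module that returns two tensors has that callback invoked twice during backward, and nothing inside the callback serializes those invocations.

import torch
import torch.nn as nn

calls = []

class TwoOutputBlock(nn.Module):
    """Stand-in for Qwen3MoeSparseMoeBlock: returns two tensors."""
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(4, 4)

    def forward(self, x):
        h = self.lin(x)
        return h, h.sum()  # like (final_hidden_states, router_logits)

def fake_fetch_hook(grad):
    calls.append(1)        # in DeepSpeed this would be the fetch/gather work
    return grad

def attach_pre_backward(module, inputs, outputs):
    # Mimic attaching the same pre-backward callback to each output tensor.
    for t in outputs:
        if t.requires_grad:
            t.register_hook(fake_fetch_hook)
    return outputs

m = TwoOutputBlock()
m.register_forward_hook(attach_pre_backward)
out, aux = m(torch.randn(2, 4))
(out.sum() + aux).backward()
print(len(calls))  # 2: one invocation per output tensor

In the real case the two invocations touch shared ZeRO3 state (the parameter status and __inflight_param_registry), which is where the race shows up.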
To Reproduce
A simple demo:
Save the code below as test.py
Run it with: deepspeed --module test
import argparse

import deepspeed
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from transformers import get_scheduler
from transformers.integrations.deepspeed import HfDeepSpeedConfig

# torch.autograd.set_multithreading_enabled(False)


class Qwen3MoeMLP(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.gate_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        self.down_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)

    def forward(self, x):
        down_proj = self.down_proj(self.gate_proj(x) * self.up_proj(x))
        return down_proj


class Qwen3MoeSparseMoeBlock(nn.Module):
    def __init__(self, num_experts, num_experts_per_tok, hidden_size):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = num_experts_per_tok
        # gating
        self.gate = nn.Linear(hidden_size, num_experts, bias=False)
        self.experts = nn.ModuleList(
            [Qwen3MoeMLP(hidden_size) for _ in range(self.num_experts)]
        )
        self._z3_leaf = True  # mark this block as a ZeRO3 leaf module

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)
        # router_logits: (batch * sequence_length, n_experts)
        router_logits = self.gate(hidden_states)
        for exp_layer in self.experts:
            hidden_states = exp_layer(hidden_states)
        hidden_states = hidden_states.reshape(batch_size, sequence_length, hidden_dim)
        # two outputs, as in the real Qwen3MoeSparseMoeBlock
        return hidden_states, router_logits


class MoeModel(nn.Module):
    def __init__(self, num_experts, num_experts_per_tok, hidden_size):
        super().__init__()
        self.layers = nn.ModuleList(
            [Qwen3MoeSparseMoeBlock(num_experts, num_experts_per_tok, hidden_size) for _ in range(2)]
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            output = layer(hidden_states)
            hidden_states = output[0]
        return output


def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument('--local_rank', type=int, default=-1)
    parser = deepspeed.add_config_arguments(parser)
    args = parser.parse_args()

    hidden_size = 32
    model = MoeModel(4, 2, hidden_size)

    deepspeed.init_distributed()
    torch.cuda.set_device(args.local_rank)
    device = torch.device('cuda', args.local_rank)
    args.device = device
    args.global_rank = dist.get_rank()
    dist.barrier()

    ds_config = {
        'train_batch_size': None,
        'train_micro_batch_size_per_gpu': 8,
        'gradient_accumulation_steps': 1,
        'steps_per_print': 10,
        'zero_optimization': {
            'stage': 3,
            'offload_param': {
                'device': 'none',
            },
            'offload_optimizer': {
                'device': 'none',
            },
            'param_persistence_threshold': 1e4,
            'max_live_parameters': 3e7,
            'prefetch_bucket_size': 3e7,
            'memory_efficient_linear': False,
            'gather_16bit_weights_on_model_save': True,
        },
        'gradient_clipping': 1.0,
        'prescale_gradients': False,
        'wall_clock_breakdown': False,
    }
    _dstchf = HfDeepSpeedConfig(ds_config)

    optimizer = AdamW(
        [{'params': list(model.parameters()), 'weight_decay': 0.0}],
        lr=1e-3,
        betas=(0.9, 0.95),
    )
    lr_scheduler = get_scheduler(
        name='cosine',
        optimizer=optimizer,
        num_warmup_steps=5,
        num_training_steps=100,
    )
    model, *_ = deepspeed.initialize(
        model=model,
        optimizer=optimizer,
        args=args,
        config=ds_config,
        lr_scheduler=lr_scheduler,
        dist_init_required=True,
    )

    input_x = torch.randn(16, 128, hidden_size).to(device)
    predicts = torch.randn(16, 128, hidden_size).to(device)
    for i in range(10):
        outputs, idx = model(input_x)
        loss = nn.MSELoss()(outputs, predicts)
        model.backward(loss)
        model.step()


if __name__ == '__main__':
    main()
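Note: the commented-out line near the top of the script forces autograd to run its backward pass single-threaded, which should serialize the hook invocations; it can be used to check whether the problem is indeed thread-related (a diagnostic experiment, not a fix).

# Diagnostic only: disable autograd's multithreaded backward so the ZeRO3
# pre-backward hook cannot be entered from two threads at once.
torch.autograd.set_multithreading_enabled(False)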
Expected behavior
fetch_sub_module should be thread-safe: the parameter state transitions and the contents of __inflight_param_registry should be managed correctly even when the hook is invoked concurrently.
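Purely to illustrate the kind of guard this implies (a hypothetical sketch, not DeepSpeed's actual code and not the referenced fix): serialize the fetch so concurrent invocations cannot interleave the parameter state transitions or the in-flight registry updates, and make a duplicate trigger for the same module a no-op.

import threading

class ThreadSafeFetcher:
    """Hypothetical wrapper around a non-thread-safe fetch function."""

    def __init__(self, fetch_fn):
        self._fetch_fn = fetch_fn      # e.g. the existing fetch logic
        self._lock = threading.Lock()
        self._fetched = set()          # modules already handled in this pass

    def __call__(self, module):
        with self._lock:               # one invocation at a time touches shared state
            if id(module) in self._fetched:
                return                 # duplicate trigger (second output tensor) is a no-op
            self._fetch_fn(module)
            self._fetched.add(id(module))

    def reset(self):
        with self._lock:
            self._fetched.clear()      # call once per forward/backward pass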
A possible solution
ds_report output
- torch version .................... 2.5.0+cu124
- deepspeed info ................... 0.18.4, unknown, unknown
- torch cuda version ............... 12.4
- torch hip version ................ None
- nvcc version ..................... 12.4
- deepspeed wheel compiled w. ...... torch 0.0, cuda 0.0
- shared memory (/dev/shm) size .... 960.00 GB
Screenshots
System info (please complete the following information):
- OS: [Ubuntu 18.04]
- GPU count and types [8 machines with x8 H100s each]
- Python version [3.10.19]