diff --git a/src/diffusers/models/transformers/transformer_hidream_image.py b/src/diffusers/models/transformers/transformer_hidream_image.py
index 77902dcf5852..3a9d8c32cadc 100644
--- a/src/diffusers/models/transformers/transformer_hidream_image.py
+++ b/src/diffusers/models/transformers/transformer_hidream_image.py
@@ -389,7 +389,9 @@ def forward(self, x):
     def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
         expert_cache = torch.zeros_like(x)
         idxs = flat_expert_indices.argsort()
-        tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)
+        count_freq = torch.bincount(flat_expert_indices, minlength=self.num_activated_experts)
+        tokens_per_expert = count_freq.cumsum(dim=0)
+
         token_idxs = idxs // self.num_activated_experts
         for i, end_idx in enumerate(tokens_per_expert):
             start_idx = 0 if i == 0 else tokens_per_expert[i - 1]
diff --git a/tests/models/transformers/test_models_transformer_hidream.py b/tests/models/transformers/test_models_transformer_hidream.py
index fa0fa5123ac8..14336713a358 100644
--- a/tests/models/transformers/test_models_transformer_hidream.py
+++ b/tests/models/transformers/test_models_transformer_hidream.py
@@ -20,6 +20,10 @@
 from diffusers import HiDreamImageTransformer2DModel
 from diffusers.utils.testing_utils import (
     enable_full_determinism,
+    is_torch_compile,
+    require_torch_2,
+    require_torch_gpu,
+    slow,
     torch_device,
 )
 
@@ -94,3 +98,20 @@ def test_set_attn_processor_for_determinism(self):
     def test_gradient_checkpointing_is_applied(self):
         expected_set = {"HiDreamImageTransformer2DModel"}
         super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
+
+    @require_torch_gpu
+    @require_torch_2
+    @is_torch_compile
+    @slow
+    def test_torch_compile_recompilation_and_graph_break(self):
+        torch._dynamo.reset()
+        torch._dynamo.config.capture_dynamic_output_shape_ops = True
+
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+        model = self.model_class(**init_dict).to(torch_device)
+        model = torch.compile(model, fullgraph=True)
+
+        with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
+            _ = model(**inputs_dict)
+            _ = model(**inputs_dict)