Diffusers Transformer Pipeline Produces ComplexDouble Tensors on MPS, Causing Conversion Error · Issue #10986 · huggingface/diffusers · GitHub
Open
mozzipa opened this issue Mar 6, 2025 · 5 comments
Assignees
Labels
bug Something isn't working stale Issues that haven't received updates

Comments

mozzipa commented Mar 6, 2025

Describe the bug

When running the WanPipeline from diffusers on an MPS device, the pipeline fails with the error:

TypeError: Trying to convert ComplexDouble to the MPS backend but it does not have support for that dtype.
Investigation indicates that in the transformer component (specifically in the rotary positional embedding function of WanRotaryPosEmbed), frequency tensors are computed using torch.float64 (and then converted to complex via torch.view_as_complex). This produces a ComplexDouble tensor (i.e. torch.complex128), which the MPS backend does not support.
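The dtype promotion described above can be reproduced in isolation on CPU (a minimal sketch; on MPS the complex128 result is what triggers the conversion error):

```python
import torch

# view_as_complex reinterprets trailing (real, imag) float pairs, so the complex
# dtype tracks the float dtype: float64 -> complex128 ("ComplexDouble").
freqs64 = torch.arange(8, dtype=torch.float64).reshape(4, 2)
freqs32 = freqs64.to(torch.float32)

print(torch.view_as_complex(freqs64).dtype)  # torch.complex128 - unsupported on MPS
print(torch.view_as_complex(freqs32).dtype)  # torch.complex64  - supported on MPS
```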

Steps to Reproduce:

On an Apple Silicon Mac with MPS enabled, use diffusers (version 0.33.0.dev0) along with a recent PyTorch nightly (2.7.0.dev20250305).
Load a model pipeline as follows:

import torch
from diffusers import AutoencoderKLWan, WanPipeline

vae = AutoencoderKLWan.from_pretrained("<model_path>", subfolder="vae", torch_dtype=torch.float32).to("mps")
pipe = WanPipeline.from_pretrained("<model_path>", vae=vae, torch_dtype=torch.float32).to("mps")

Encode prompts and call the pipeline for inference.
The error occurs during the transformer’s forward pass in the rotary embedding function when it calls torch.cat([freqs_f, freqs_h, freqs_w], dim=-1).reshape(...).

Expected Behavior:
When running on MPS, all computations (including the construction of rotary positional encodings) should use single-precision floats (and their corresponding complex type, i.e. torch.cfloat). In other words, the pipeline should ensure that no operations create or convert to ComplexDouble (complex128) tensors on MPS.
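One way to express this device-conditional dtype choice (a hypothetical helper for illustration, not a diffusers API) is:

```python
import torch

def pick_freqs_dtype(device: torch.device) -> torch.dtype:
    # MPS has no float64/complex128 support, so fall back to single precision
    # there; elsewhere keep float64 for extra precision in the rotary frequencies.
    return torch.float32 if device.type == "mps" else torch.float64

print(pick_freqs_dtype(torch.device("mps")))  # torch.float32
print(pick_freqs_dtype(torch.device("cpu")))  # torch.float64
```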

Workaround:
A temporary fix is to patch the helper functions used in computing the rotary embeddings so that they force the use of torch.float32. For instance, one workaround is to override the function (e.g., get_1d_rotary_pos_embed) that computes the frequency tensor so that it uses freqs_dtype=torch.float32 regardless of defaults. Additionally, patching WanRotaryPosEmbed.forward to cast its output to float32 (if it’s float64) avoids this error.

Reproduction

import os
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import torch

# ------------------------------------------------------------------------------
# 1. Patch torch.view_as_complex to avoid creating ComplexDouble on MPS.
_orig_view_as_complex = torch.view_as_complex

def patched_view_as_complex(tensor):
    if tensor.device.type == "mps" and tensor.dtype == torch.float64:
        tensor = tensor.to(torch.float32)
    return _orig_view_as_complex(tensor)

torch.view_as_complex = patched_view_as_complex

# ------------------------------------------------------------------------------
# 2. Patch get_1d_rotary_pos_embed so that it always computes frequencies as float32.
try:
    from diffusers.models.transformers.embeddings import get_1d_rotary_pos_embed as original_get_1d_rotary_pos_embed
    import diffusers.models.transformers.embeddings as embeddings_mod
except ImportError:
    from diffusers.models.embeddings import get_1d_rotary_pos_embed as original_get_1d_rotary_pos_embed
    import diffusers.models.embeddings as embeddings_mod

def patched_get_1d_rotary_pos_embed(dim, max_seq_len, theta, use_real, repeat_interleave_real, freqs_dtype):
    return original_get_1d_rotary_pos_embed(
        dim, max_seq_len, theta, use_real, repeat_interleave_real, freqs_dtype=torch.float32
    )

embeddings_mod.get_1d_rotary_pos_embed = patched_get_1d_rotary_pos_embed

# ------------------------------------------------------------------------------
# 3. Patch WanRotaryPosEmbed.forward to ensure its output is float32.
from diffusers.models.transformers.transformer_wan import WanRotaryPosEmbed
_orig_rope_forward = WanRotaryPosEmbed.forward

def patched_rope_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
    result = _orig_rope_forward(self, hidden_states)
    if hidden_states.device.type == "mps" and result.dtype == torch.float64:
        result = result.to(torch.float32)
    return result

WanRotaryPosEmbed.forward = patched_rope_forward

# ------------------------------------------------------------------------------
# Model setup.
from diffusers import AutoencoderKLWan, WanPipeline
from diffusers.utils import export_to_video

model_id = "~/Wan2.1-T2V-1.3B-Diffusers"

vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
vae.to("mps")

pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.float32)
pipe.to("mps")

pipe.enable_attention_slicing()

# ------------------------------------------------------------------------------
# Define prompts.
prompt = "A cat walks on the grass, realistic"
negative_prompt = (
    "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, "
    "worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, "
    "deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
)

# ------------------------------------------------------------------------------
prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(
    prompt=prompt,
    negative_prompt=negative_prompt,
    do_classifier_free_guidance=True,
    num_videos_per_prompt=1,
    max_sequence_length=226,
    device="mps",
    dtype=torch.float32
)

# ------------------------------------------------------------------------------
# Generate video frames.
output = pipe(
    prompt_embeds=prompt_embeds,
    negative_prompt_embeds=negative_prompt_embeds,
    height=480,
    width=832,
    num_frames=81,
    guidance_scale=5.0
).frames[0]

# ------------------------------------------------------------------------------
# Export to video.
export_to_video(output, "output.mp4", fps=15)

Logs

System Info

diffusers: 0.33.0.dev0
PyTorch: 2.7.0.dev20250305 (nightly)
OS: macOS on Apple Silicon with MPS enabled
Other libraries: torchaudio 2.6.0.dev20250305, torchvision 0.22.0.dev20250305
Device: MPS

Who can help?

No response

@mozzipa mozzipa added the bug Something isn't working label Mar 6, 2025
Vargol commented Mar 7, 2025

Came to report this myself. Here's the traceback, if that helps:

  0%|                                                                                            | 0/50 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/Volumes/SSD2TB/AI/Diffusers/wan.py", line 14, in <module>
    output = pipe(
             ^^^^^
  File "/Volumes/SSD2TB/AI/Diffusers/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Volumes/SSD2TB/AI/Diffusers/lib/python3.11/site-packages/diffusers/pipelines/wan/pipeline_wan.py", line 524, in __call__
    noise_pred = self.transformer(
                 ^^^^^^^^^^^^^^^^^
  File "/Volumes/SSD2TB/AI/Diffusers/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Volumes/SSD2TB/AI/Diffusers/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Volumes/SSD2TB/AI/Diffusers/lib/python3.11/site-packages/diffusers/models/transformers/transformer_wan.py", line 418, in forward
    rotary_emb = self.rope(hidden_states)
                 ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Volumes/SSD2TB/AI/Diffusers/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Volumes/SSD2TB/AI/Diffusers/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Volumes/SSD2TB/AI/Diffusers/lib/python3.11/site-packages/diffusers/models/transformers/transformer_wan.py", line 205, in forward
    freqs = torch.cat([freqs_f, freqs_h, freqs_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: Trying to convert ComplexDouble to the MPS backend but it does not have support for that dtype.

bghira commented Mar 23, 2025

the 'fix' is to switch the rotary embeds to fp32 in the transformer class:

def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor):
    x_rotated = torch.view_as_complex(
        hidden_states.to(
            torch.float32 if torch.backends.mps.is_available() else torch.float64
        ).unflatten(3, (-1, 2))
    )

but it takes 200-600 seconds per step to generate, so expect hour-long generation times.

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot added the stale Issues that haven't received updates label Apr 17, 2025
bghira commented Apr 18, 2025

imo it's fine to leave this one broken since MPS lacks 3D conv acceleration anyway

@github-actions github-actions bot removed the stale Issues that haven't received updates label Apr 18, 2025

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot added the stale Issues that haven't received updates label May 12, 2025