Diffusers Transformer Pipeline Produces ComplexDouble Tensors on MPS, Causing Conversion Error #10986
Comments
Came to report this myself. Here's the traceback, if that helps:
  0%|          | 0/50 [00:00<?, ?it/s]
Traceback (most recent call last):
File "/Volumes/SSD2TB/AI/Diffusers/wan.py", line 14, in <module>
output = pipe(
^^^^^
File "/Volumes/SSD2TB/AI/Diffusers/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/Volumes/SSD2TB/AI/Diffusers/lib/python3.11/site-packages/diffusers/pipelines/wan/pipeline_wan.py", line 524, in __call__
noise_pred = self.transformer(
^^^^^^^^^^^^^^^^^
File "/Volumes/SSD2TB/AI/Diffusers/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Volumes/SSD2TB/AI/Diffusers/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Volumes/SSD2TB/AI/Diffusers/lib/python3.11/site-packages/diffusers/models/transformers/transformer_wan.py", line 418, in forward
rotary_emb = self.rope(hidden_states)
^^^^^^^^^^^^^^^^^^^^^^^^
File "/Volumes/SSD2TB/AI/Diffusers/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Volumes/SSD2TB/AI/Diffusers/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Volumes/SSD2TB/AI/Diffusers/lib/python3.11/site-packages/diffusers/models/transformers/transformer_wan.py", line 205, in forward
freqs = torch.cat([freqs_f, freqs_h, freqs_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: Trying to convert ComplexDouble to the MPS backend but it does not have support for that dtype.
The 'fix' is to switch the rotary embeds to fp32 in the transformer class:

def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor):
    x_rotated = torch.view_as_complex(
        hidden_states.to(
            torch.float32
            if torch.backends.mps.is_available()
            else torch.float64
        ).unflatten(3, (-1, 2))
    )

but it takes 200-600 seconds per step to generate, so expect hour-long generation times.
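For context, the traceback points at WanRotaryPosEmbed.forward rather than apply_rotary_emb, so the precomputed frequency tensors presumably need the same dtype guard where they are first built. A minimal sketch of that guard, assuming the frequencies are generated from a configurable freqs_dtype as in current diffusers code (an assumption, not part of the patch above):

# Hypothetical companion change inside WanRotaryPosEmbed: pick the frequency dtype per
# backend so torch.view_as_complex yields complex64 (supported on MPS) instead of
# complex128 ("ComplexDouble", unsupported on MPS).
freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64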
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
IMO it's fine to leave this one broken, since MPS lacks accelerated 3D convolutions anyway.
Describe the bug
When running the WanPipeline from diffusers on an MPS device, the pipeline fails with the error:
TypeError: Trying to convert ComplexDouble to the MPS backend but it does not have support for that dtype.
Investigation indicates that in the transformer component (specifically in the rotary positional embedding function of WanRotaryPosEmbed), frequency tensors are computed using torch.float64 (and then converted to complex via torch.view_as_complex). This produces a ComplexDouble tensor (i.e. torch.complex128), which the MPS backend does not support.
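A minimal snippet, independent of diffusers, illustrates the failing dtype path (this is an illustrative example, not code from the library):

import torch

# float64 halves -> view_as_complex yields complex128 ("ComplexDouble"),
# which the MPS backend cannot represent.
x = torch.randn(4, 2, dtype=torch.float64)
c = torch.view_as_complex(x)   # dtype: torch.complex128
c.to("mps")                    # TypeError: Trying to convert ComplexDouble to the MPS backend ...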
Steps to Reproduce:
On an Apple Silicon Mac with MPS enabled, use diffusers (version 0.33.0.dev0) along with a recent PyTorch nightly (2.7.0.dev20250305).
Load a model pipeline (the loading code was omitted here; a sketch is given under Reproduction below).
Encode prompts and call the pipeline for inference.
The error occurs during the transformer's forward pass, in the rotary embedding function, when it calls
torch.cat([freqs_f, freqs_h, freqs_w], dim=-1).reshape(...).
Expected Behavior:
When running on MPS, all computations (including the construction of rotary positional encodings) should use single-precision floats (and their corresponding complex type, i.e. torch.cfloat). In other words, the pipeline should ensure that no operations create or convert to ComplexDouble (complex128) tensors on MPS.
Workaround:
A temporary fix is to patch the helper functions used in computing the rotary embeddings so that they force the use of torch.float32. For instance, one workaround is to override the function (e.g., get_1d_rotary_pos_embed) that computes the frequency tensor so that it uses freqs_dtype=torch.float32 regardless of defaults. Additionally, patching WanRotaryPosEmbed.forward to cast its output to float32 (if it’s float64) avoids this error.
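A sketch of such a patch, assuming the Wan transformer module imports get_1d_rotary_pos_embed by name and that the function accepts a freqs_dtype keyword (both appear to hold in recent diffusers, but verify against your installed version):

import functools
import torch
import diffusers.models.transformers.transformer_wan as transformer_wan

_orig_get_1d_rotary_pos_embed = transformer_wan.get_1d_rotary_pos_embed

@functools.wraps(_orig_get_1d_rotary_pos_embed)
def _get_1d_rotary_pos_embed_fp32(*args, **kwargs):
    # Force single-precision frequencies so view_as_complex produces complex64,
    # which MPS supports, instead of complex128. Assumes freqs_dtype is passed
    # as a keyword argument by the caller.
    kwargs["freqs_dtype"] = torch.float32
    return _orig_get_1d_rotary_pos_embed(*args, **kwargs)

transformer_wan.get_1d_rotary_pos_embed = _get_1d_rotary_pos_embed_fp32

The patch has to be applied before the pipeline is constructed, since WanRotaryPosEmbed precomputes its frequency tensors at initialization.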
Reproduction
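The reproduction script was not attached; a minimal sketch that exercises the same code path might look like the following (the model ID, prompt, and generation settings are assumptions based on the usual WanPipeline text-to-video example, not taken from the original report):

import torch
from diffusers import AutoencoderKLWan, WanPipeline
from diffusers.utils import export_to_video

# Assumed checkpoint; any Wan 2.1 text-to-video checkpoint in diffusers format
# should hit the same rotary-embedding code path.
model_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"

vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
pipe.to("mps")

# Fails inside WanRotaryPosEmbed.forward with:
# TypeError: Trying to convert ComplexDouble to the MPS backend ...
frames = pipe(
    prompt="A cat walks on the grass, realistic",
    negative_prompt="blurry, low quality",
    height=480,
    width=832,
    num_frames=81,
    guidance_scale=5.0,
).frames[0]
export_to_video(frames, "output.mp4", fps=15)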
Logs
System Info
diffusers: 0.33.0.dev0
PyTorch: 2.7.0.dev20250305 (nightly)
OS: macOS on Apple Silicon with MPS enabled
Other libraries: torchaudio 2.6.0.dev20250305, torchvision 0.22.0.dev20250305
Device: MPS
Who can help?
No response