[ued] HF diffusers pipeline `enable_cpu_offload` errors or graph breaks with a `torch.compile`-ed transformer · Issue #150711 · pytorch/pytorch
Open
StrongerXi opened this issue Apr 4, 2025 · 1 comment
Labels: dynamo-triage-jan2025, module: dynamo, oncall: pt2, triaged

StrongerXi (Contributor) commented Apr 4, 2025

🐛 Describe the bug

No response

Error logs

Out-of-place `torch.compile` repro

```python
import torch
from diffusers import (
    AuraFlowPipeline,
    GGUFQuantizationConfig,
    AuraFlowTransformer2DModel,
)

# Load a GGUF-quantized AuraFlow transformer.
transformer = AuraFlowTransformer2DModel.from_single_file(
    "https://huggingface.co/city96/AuraFlow-v0.3-gguf/blob/main/aura_flow_0.3-Q2_K.gguf",
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    torch_dtype=torch.bfloat16,
)
pipeline = AuraFlowPipeline.from_pretrained(
    "fal/AuraFlow-v0.3",
    torch_dtype=torch.bfloat16,
    transformer=transformer,
).to("cuda")

# Out-of-place compile: the transformer is replaced by an `OptimizedModule` wrapper.
pipeline.transformer = torch.compile(pipeline.transformer, fullgraph=True)
# Model CPU offloading installs accelerate hooks on the pipeline's models.
pipeline.enable_model_cpu_offload()
pipeline("A cute pony", width=256, height=256, num_inference_steps=1)
pipeline("A cute pony", width=256, height=256, num_inference_steps=1)
pipeline("A cute pony", width=256, height=256, num_inference_steps=1)
```

Output:

```
Traceback (most recent call last):
  File "/home/ryanguo99/scratch/recompiles.py", line 21, in <module>
    pipeline("A cute pony", width=256, height=256, num_inference_steps=1)
  File "/home/ryanguo99/repos/pytorch/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/ryanguo99/repos/diffusers/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py", line 636, in __call__
    self.maybe_free_model_hooks()
  File "/home/ryanguo99/repos/diffusers/src/diffusers/pipelines/pipeline_utils.py", line 1189, in maybe_free_model_hooks
    self.enable_model_cpu_offload(device=getattr(self, "_offload_device", "cuda"))
  File "/home/ryanguo99/repos/diffusers/src/diffusers/pipelines/pipeline_utils.py", line 1111, in enable_model_cpu_offload
    self.remove_all_hooks()
  File "/home/ryanguo99/repos/diffusers/src/diffusers/pipelines/pipeline_utils.py", line 1076, in remove_all_hooks
    accelerate.hooks.remove_hook_from_module(model, recurse=True)
  File "/home/ryanguo99/.conda/envs/pt311/lib/python3.11/site-packages/accelerate/hooks.py", line 194, in remove_hook_from_module
    delattr(module, "_hf_hook")
  File "/home/ryanguo99/repos/pytorch/torch/nn/modules/module.py", line 2052, in __delattr__
    super().__delattr__(name)
AttributeError: 'OptimizedModule' object has no attribute '_hf_hook'
```
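The failure comes from the wrapper object: `torch.compile(module)` returns an `OptimizedModule`, and accelerate's `remove_hook_from_module` calls `delattr(module, "_hf_hook")` on it. Below is a minimal pure-Python sketch of the mechanism, assuming (as the commit message for #152741 describes) that the wrapper forwards attribute reads and writes to the wrapped module but not deletions. `Inner` and `ForwardingWrapper` are hypothetical stand-ins, not PyTorch classes.

```python
class Inner:
    """Hypothetical stand-in for the wrapped nn.Module."""


class ForwardingWrapper:
    """Hypothetical stand-in for an OptimizedModule-like wrapper that forwards
    attribute reads and writes to the wrapped object, but not deletions."""

    def __init__(self, orig_mod):
        # Bypass our own __setattr__ so `_orig_mod` is stored on the wrapper itself.
        object.__setattr__(self, "_orig_mod", orig_mod)

    def __getattr__(self, name):
        # Only called when normal lookup on the wrapper fails.
        return getattr(self._orig_mod, name)

    def __setattr__(self, name, value):
        setattr(self._orig_mod, name, value)

    # No __delattr__: `delattr(wrapper, name)` falls back to object.__delattr__,
    # which only looks in the wrapper's own __dict__.


inner = Inner()
wrapper = ForwardingWrapper(inner)

wrapper._hf_hook = "hook"             # stored on `inner`, not on `wrapper`
print(hasattr(wrapper, "_hf_hook"))   # True, found through __getattr__

try:
    delattr(wrapper, "_hf_hook")      # the kind of call remove_hook_from_module makes
    deleted = True
except AttributeError as exc:
    deleted = False
    print("AttributeError:", exc)
```

Reading the attribute succeeds through forwarding, so the hook appears to be present on the wrapper, yet deleting it fails because it was never stored there.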

In-place `torch.compile` repro

```python
import torch
from diffusers import (
    AuraFlowPipeline,
    GGUFQuantizationConfig,
    AuraFlowTransformer2DModel,
)

# Load a GGUF-quantized AuraFlow transformer.
transformer = AuraFlowTransformer2DModel.from_single_file(
    "https://huggingface.co/city96/AuraFlow-v0.3-gguf/blob/main/aura_flow_0.3-Q2_K.gguf",
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    torch_dtype=torch.bfloat16,
)
pipeline = AuraFlowPipeline.from_pretrained(
    "fal/AuraFlow-v0.3",
    torch_dtype=torch.bfloat16,
    transformer=transformer,
).to("cuda")

# In-place compile: `nn.Module.compile` keeps the same module object and
# compiles its call implementation instead of returning a wrapper.
pipeline.transformer.compile(fullgraph=True)
# Model CPU offloading installs accelerate hooks on the pipeline's models.
pipeline.enable_model_cpu_offload()
pipeline("A cute pony", width=256, height=256, num_inference_steps=1)
pipeline("A cute pony", width=256, height=256, num_inference_steps=1)
pipeline("A cute pony", width=256, height=256, num_inference_steps=1)
```

Output:

```
Traceback (most recent call last):
  File "/home/ryanguo99/scratch/recompiles.py", line 21, in <module>
    pipeline("A cute pony", width=256, height=256, num_inference_steps=1)
  File "/home/ryanguo99/repos/pytorch/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/ryanguo99/repos/diffusers/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py", line 593, in __call__
    noise_pred = self.transformer(
                 ^^^^^^^^^^^^^^^^^
  File "/home/ryanguo99/repos/pytorch/torch/nn/modules/module.py", line 1749, in _wrapped_call_impl
    return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ryanguo99/repos/pytorch/torch/_dynamo/eval_frame.py", line 667, in _fn
    raise e.with_traceback(None) from e.__cause__
torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor device call_function <built-in function getitem>

from user code:
  File "/home/ryanguo99/.conda/envs/pt311/lib/python3.11/site-packages/accelerate/hooks.py", line 161, in new_forward
    args, kwargs = module._hf_hook.pre_forward(module, *args, **kwargs)
  File "/home/ryanguo99/.conda/envs/pt311/lib/python3.11/site-packages/accelerate/hooks.py", line 690, in pre_forward
    self.prev_module_hook.offload()
  File "/home/ryanguo99/.conda/envs/pt311/lib/python3.11/site-packages/accelerate/hooks.py", line 706, in offload
    self.hook.init_hook(self.model)
  File "/home/ryanguo99/.conda/envs/pt311/lib/python3.11/site-packages/accelerate/hooks.py", line 686, in init_hook
    return module.to("cpu")
  File "/home/ryanguo99/.conda/envs/pt311/lib/python3.11/site-packages/transformers/modeling_utils.py", line 3162, in to
    return super().to(*args, **kwargs)
  File "/home/ryanguo99/repos/pytorch/torch/nn/modules/module.py", line 1314, in to
    device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (ple
```
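In the in-place case there is no wrapper object, so the `delattr` problem does not arise. Instead, the compiled function is the module's own call implementation (`_wrapped_call_impl` calling `_compiled_call_impl` in the traceback above), and that implementation calls whatever `forward` is installed on the instance, including accelerate's `new_forward`. Its `pre_forward` offloads the previously used model, so that model's `.to("cpu")` (ending in `torch._C._nn._parse_to`) runs inside the region Dynamo is tracing, where `fullgraph=True` reports it as `Unsupported`. The toy sketch below reproduces only that call order. `fake_compile`, `ToyModule`, `add_offload_hook`, and the `text_encoder` name are hypothetical stand-ins, not the real PyTorch or accelerate code, and `fake_compile` does no tracing at all.

```python
import functools

EVENTS = []


def fake_compile(fn):
    """Hypothetical stand-in for torch.compile: it does not trace anything,
    it only records when execution enters and leaves the 'compiled' region."""
    @functools.wraps(fn)
    def compiled(*args, **kwargs):
        EVENTS.append("enter compiled region")
        try:
            return fn(*args, **kwargs)
        finally:
            EVENTS.append("exit compiled region")
    return compiled


class ToyModule:
    """Loosely mimics nn.Module's call path: __call__ -> (compiled) _call_impl -> forward."""

    def __init__(self, name):
        self.name = name
        self.device = "cpu"
        self._compiled_call_impl = None

    def to(self, device):
        EVENTS.append(f"{self.name}.to({device!r})")
        self.device = device
        return self

    def forward(self, x):
        EVENTS.append(f"{self.name}.forward")
        return x

    def _call_impl(self, *args):
        # Looks up `forward` on the instance, so a hook-patched forward is used.
        return self.forward(*args)

    def __call__(self, *args):
        if self._compiled_call_impl is not None:
            return self._compiled_call_impl(*args)
        return self._call_impl(*args)

    def compile(self):
        # In-place compile: wrap this module's own call implementation.
        self._compiled_call_impl = fake_compile(self._call_impl)


def add_offload_hook(module, prev_module=None, execution_device="cuda"):
    """Hypothetical, heavily simplified version of the chain in the traceback:
    new_forward -> pre_forward -> prev_module_hook.offload() -> .to("cpu")."""
    old_forward = module.forward

    def new_forward(*args):
        if prev_module is not None:
            prev_module.to("cpu")        # offload the previously used model
        module.to(execution_device)      # onload this model
        return old_forward(*args)

    module.forward = new_forward


text_encoder = ToyModule("text_encoder")
transformer = ToyModule("transformer")
transformer.compile()                    # compile first, as in the repro
add_offload_hook(text_encoder)
add_offload_hook(transformer, prev_module=text_encoder)

text_encoder(0)
transformer(0)
print(EVENTS)
```

In the recorded events, `text_encoder.to('cpu')` falls between entering and leaving the compiled region, which mirrors where the real offload call sits relative to the compiled transformer call.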

Note that `fullgraph=False` works in this case and gives some speed-up (consistently at least 30%, though quite noisy in crude experiments), although I still see a few graph breaks and recompiles (see the tlparse output).

Versions

Commit bb98749, Python 3.11

cc @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @amjames

xmfan added the triaged and module: dynamo labels Apr 7, 2025
StrongerXi self-assigned this May 2, 2025
StrongerXi added a commit that referenced this issue May 2, 2025
This is essentially a follow-up on #122098, where we added support of
`getattr` and `setattr` on result of `torch.compile(module)`, but didn't
add support for `delattr`.

Fixes #150711.

ghstack-source-id: 26fd2bb
Pull Request resolved: #152741
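The fix described in this commit message amounts to making deletion mirror assignment on the compiled wrapper. Below is a pure-Python sketch of that idea using a hypothetical wrapper class; it is not PyTorch's actual `OptimizedModule` code, whose exact implementation may differ. `Inner`, `ForwardingWrapper`, and `_own_attrs` are illustrative names.

```python
class Inner:
    """Hypothetical stand-in for the wrapped nn.Module."""


class ForwardingWrapper:
    """Hypothetical OptimizedModule-like wrapper that forwards getattr,
    setattr and delattr to the wrapped object."""

    # Attributes that belong to the wrapper itself rather than the wrapped object.
    _own_attrs = frozenset({"_orig_mod"})

    def __init__(self, orig_mod):
        object.__setattr__(self, "_orig_mod", orig_mod)

    def __getattr__(self, name):
        # Only called when normal lookup on the wrapper fails.
        return getattr(self._orig_mod, name)

    def __setattr__(self, name, value):
        if name in self._own_attrs:
            object.__setattr__(self, name, value)
        else:
            setattr(self._orig_mod, name, value)

    def __delattr__(self, name):
        # Mirror __setattr__: delete the attribute where it was stored.
        if name in self._own_attrs:
            object.__delattr__(self, name)
        else:
            delattr(self._orig_mod, name)


inner = Inner()
wrapper = ForwardingWrapper(inner)
wrapper._hf_hook = "hook"         # stored on `inner`
delattr(wrapper, "_hf_hook")      # now removed from `inner` as well
print(hasattr(wrapper, "_hf_hook"), hasattr(inner, "_hf_hook"))  # False False
```

Keeping `__delattr__` symmetric with `__setattr__` means code that sets and later removes an attribute through the wrapper, as accelerate does with `_hf_hook`, round-trips cleanly.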
StrongerXi added a commit that referenced this issue May 3, 2025 (same message; ghstack-source-id: 1bc044f; Pull Request resolved: #152741)
StrongerXi (Contributor, Author) commented:

The out-of-place compile now passes, but the in-place compile (with `fullgraph=False`) seems to error; working on a repro.

StrongerXi reopened this May 6, 2025
StrongerXi added a commit that referenced this issue May 9, 2025 (same message; ghstack-source-id: efd1d6c; Pull Request resolved: #152741)
StrongerXi added a commit that referenced this issue May 12, 2025 (same message; ghstack-source-id: 6e0a7f9; Pull Request resolved: #152741)
StrongerXi added commits that referenced this issue May 13, 2025 (same message; ghstack-source-ids: 21a4df4, b2ab7f7; Pull Request resolved: #152741)
StrongerXi reopened this May 14, 2025