[ued] HF diffusers pipeline `enable_model_cpu_offload` errors or graph breaks with a `torch.compile`-ed transformer #150711
Closed
@StrongerXi

Description

🐛 Describe the bug

No response

Error logs

Non-inplace torch.compile repro

import torch
from diffusers import (
    AuraFlowPipeline,
    GGUFQuantizationConfig,
    AuraFlowTransformer2DModel,
)

transformer = AuraFlowTransformer2DModel.from_single_file(
    "https://huggingface.co/city96/AuraFlow-v0.3-gguf/blob/main/aura_flow_0.3-Q2_K.gguf",
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    torch_dtype=torch.bfloat16,
)
pipeline = AuraFlowPipeline.from_pretrained(
    "fal/AuraFlow-v0.3",
    torch_dtype=torch.bfloat16,
    transformer=transformer,
).to("cuda")

# Non-inplace compile: replaces the transformer with an OptimizedModule wrapper
pipeline.transformer = torch.compile(pipeline.transformer, fullgraph=True)
# Offload hooks now attach to (and must later be removed from) the wrapper
pipeline.enable_model_cpu_offload()
pipeline("A cute pony", width=256, height=256, num_inference_steps=1)
pipeline("A cute pony", width=256, height=256, num_inference_steps=1)
pipeline("A cute pony", width=256, height=256, num_inference_steps=1)

Output:

Traceback (most recent call last):
  File "/home/ryanguo99/scratch/recompiles.py", line 21, in <module>
    pipeline("A cute pony", width=256, height=256, num_inference_steps=1)
  File "/home/ryanguo99/repos/pytorch/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/ryanguo99/repos/diffusers/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py", line 636, in __call__
    self.maybe_free_model_hooks()
  File "/home/ryanguo99/repos/diffusers/src/diffusers/pipelines/pipeline_utils.py", line 1189, in maybe_free_model_hooks
    self.enable_model_cpu_offload(device=getattr(self, "_offload_device", "cuda"))
  File "/home/ryanguo99/repos/diffusers/src/diffusers/pipelines/pipeline_utils.py", line 1111, in enable_model_cpu_offload
    self.remove_all_hooks()
  File "/home/ryanguo99/repos/diffusers/src/diffusers/pipelines/pipeline_utils.py", line 1076, in remove_all_hooks
    accelerate.hooks.remove_hook_from_module(model, recurse=True)
  File "/home/ryanguo99/.conda/envs/pt311/lib/python3.11/site-packages/accelerate/hooks.py", line 194, in remove_hook_from_module
    delattr(module, "_hf_hook")
  File "/home/ryanguo99/repos/pytorch/torch/nn/modules/module.py", line 2052, in __delattr__
    super().__delattr__(name)
AttributeError: 'OptimizedModule' object has no attribute '_hf_hook'
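
For what it's worth, the AttributeError looks like an asymmetry in OptimizedModule's attribute handling: attribute reads (and, in recent versions, writes) are forwarded to the wrapped module, but delattr runs nn.Module.__delattr__ on the wrapper itself. A minimal sketch of that asymmetry, independent of diffusers/accelerate (the _hf_hook name is only borrowed for illustration):

import torch
import torch.nn as nn

mod = nn.Linear(4, 4)
opt = torch.compile(mod)  # returns an OptimizedModule wrapping mod

opt._hf_hook = object()           # forwarded: the attribute lands on mod
print(hasattr(opt, "_hf_hook"))   # True, resolved via __getattr__ forwarding
print("_hf_hook" in vars(mod))    # True: it lives on the wrapped module
delattr(opt, "_hf_hook")          # AttributeError, as in the traceback above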

In-place torch.compile repro

import torch
from diffusers import (
    AuraFlowPipeline,
    GGUFQuantizationConfig,
    AuraFlowTransformer2DModel,
)

transformer = AuraFlowTransformer2DModel.from_single_file(
    "https://huggingface.co/city96/AuraFlow-v0.3-gguf/blob/main/aura_flow_0.3-Q2_K.gguf",
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    torch_dtype=torch.bfloat16,
)
pipeline = AuraFlowPipeline.from_pretrained(
    "fal/AuraFlow-v0.3",
    torch_dtype=torch.bfloat16,
    transformer=transformer,
).to("cuda")

# In-place compile: the original module is kept and its __call__ is compiled
pipeline.transformer.compile(fullgraph=True)
# The offload hook's pre_forward now runs inside the compiled region
pipeline.enable_model_cpu_offload()
pipeline("A cute pony", width=256, height=256, num_inference_steps=1)
pipeline("A cute pony", width=256, height=256, num_inference_steps=1)
pipeline("A cute pony", width=256, height=256, num_inference_steps=1)

Output:

Traceback (most recent call last):
  File "/home/ryanguo99/scratch/recompiles.py", line 21, in <module>
    pipeline("A cute pony", width=256, height=256, num_inference_steps=1)
  File "/home/ryanguo99/repos/pytorch/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/ryanguo99/repos/diffusers/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py", line 593, in __call__
    noise_pred = self.transformer(
                 ^^^^^^^^^^^^^^^^^
  File "/home/ryanguo99/repos/pytorch/torch/nn/modules/module.py", line 1749, in _wrapped_call_impl
    return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ryanguo99/repos/pytorch/torch/_dynamo/eval_frame.py", line 667, in _fn
    raise e.with_traceback(None) from e.__cause__
torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor device call_function <built-in function getitem>

from user code:
   File "/home/ryanguo99/.conda/envs/pt311/lib/python3.11/site-packages/accelerate/hooks.py", line 161, in new_forward
    args, kwargs = module._hf_hook.pre_forward(module, *args, **kwargs)
  File "/home/ryanguo99/.conda/envs/pt311/lib/python3.11/site-packages/accelerate/hooks.py", line 690, in pre_forward
    self.prev_module_hook.offload()
  File "/home/ryanguo99/.conda/envs/pt311/lib/python3.11/site-packages/accelerate/hooks.py", line 706, in offload
    self.hook.init_hook(self.model)
  File "/home/ryanguo99/.conda/envs/pt311/lib/python3.11/site-packages/accelerate/hooks.py", line 686, in init_hook
    return module.to("cpu")
  File "/home/ryanguo99/.conda/envs/pt311/lib/python3.11/site-packages/transformers/modeling_utils.py", line 3162, in to
    return super().to(*args, **kwargs)
  File "/home/ryanguo99/repos/pytorch/torch/nn/modules/module.py", line 1314, in to
    device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch)
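
The Unsupported error, in turn, comes from accelerate's pre_forward hook being traced into the compiled region: the hook offloads the previous module with module.to("cpu"), and torch._C._nn._parse_to returns a device rather than a Tensor. A standalone sketch that appears to hit the same limitation (Toy is a hypothetical stand-in for the hooked transformer):

import torch
import torch.nn as nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(4, 4)

    def forward(self, x):
        # Stand-in for accelerate's pre_forward: a module device move
        # executed inside the region Dynamo is tracing.
        self.lin.to("cpu")
        return self.lin(x)

compiled = torch.compile(Toy(), fullgraph=True)
try:
    compiled(torch.randn(2, 4))
except Exception as e:  # expected: torch._dynamo.exc.Unsupported
    print(type(e).__name__, e)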

Note that fullgraph=False works in this case and gives some speed-up (consistently >= 30%, though quite variable in my crude experiments), although I still see a few graph breaks and recompiles: tlparse.
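
To pin down where the remaining breaks come from under fullgraph=False, torch._dynamo.explain gives a quick summary without a full tlparse run (toy function for illustration only; the exact output fields vary across versions):

import torch

def f(x):
    x = x + 1
    torch._dynamo.graph_break()  # deliberate break, standing in for the hook's device move
    return x * 2

# Prints graph count, graph-break count, and break reasons
print(torch._dynamo.explain(f)(torch.randn(3)))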

Versions

PyTorch commit bb98749, Python 3.11

cc @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @amjames
