8000 Dynamo compilation of SDXL, SDXL Refiner, SDXL ControlNet have no performance improvements after save/load · Issue #6686 · huggingface/diffusers · GitHub

Closed
Alexadar opened this issue Jan 23, 2024 · 9 comments
Labels
bug Something isn't working stale Issues that haven't received updates

Comments

@Alexadar
Alexadar commented Jan 23, 2024

Describe the bug

Dynamo compilation of SDXL, the SDXL Refiner, and SDXL ControlNet yields no performance improvement if you:

  1. Compile
  2. Save pretrained
  3. Load pretrained
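The three steps above can be made measurable with a small timing harness (a hypothetical stdlib-only helper, not part of the repro scripts below, which only print "Test 1"/"Test 2" markers). If compilation survived save/load, the steady-state call after reloading should be roughly as fast as before; in practice it falls back to eager speed:

```python
import time

def timed(fn, *args, **kwargs):
    """Run fn once and return (result, elapsed_seconds)."""
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    return result, time.perf_counter() - start

# Usage with any of the pipelines below, e.g.:
# _, warmup = timed(txt2img_pipeline, "hello world", width=1024, height=1024)
# _, steady = timed(txt2img_pipeline, "hello world", width=1024, height=1024)
# print(f"warmup {warmup:.1f}s, steady {steady:.1f}s")
```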

Reproduction

SDXL:

import torch
from diffusers import StableDiffusionXLPipeline
txt2img_pipeline = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", 
    use_safetensors=True,
    variant="fp16",
    torch_dtype=torch.float16).to("cuda")
txt2img_pipeline.unet = torch.compile(txt2img_pipeline.unet, mode="reduce-overhead", fullgraph=True)
txt2img_pipeline("hello world", width=1024, height=1024)
txt2img_pipeline.save_pretrained("test_model")
print("Test 1")
txt2img_pipeline("hello world", width=1024, height=1024)
print("Test 2")
txt2img_pipeline("hello world", width=1024, height=1024)
del txt2img_pipeline
import gc
gc.collect()
torch.cuda.empty_cache()
txt2img_pipeline = StableDiffusionXLPipeline.from_pretrained("test_model", torch_dtype=torch.float16, use_safetensors=True).to("cuda")
txt2img_pipeline.unet.to(memory_format=torch.channels_last)
print("Test 1")
txt2img_pipeline("hello world", width=1024, height=1024)
print("Test 2")
txt2img_pipeline("hello world", width=1024, height=1024)

SDXL REFINER

import torch
from PIL import Image
from diffusers import DiffusionPipeline
refiner = DiffusionPipeline.from_pretrained(
        'stabilityai/stable-diffusion-xl-refiner-1.0',
        torch_dtype=torch.float16,
        use_safetensors=True,
        variant="fp16",
    ).to(torch.device('cuda'))
refiner.unet = torch.compile(refiner.unet, mode="reduce-overhead", fullgraph=True)
image = Image.new('RGB', (1024, 1024), color='red')
for i in range(1024):
    for j in range(1024):
        image.putpixel((i, j), (i % 256, j % 256, (i + j) % 256))
print("Test 1")
refiner("hello world", image=image, num_inference_steps=50)
print("Test 2")
refiner("hello world", image=image, num_inference_steps=50)
refiner.save_pretrained("test_model")
del refiner
import gc
gc.collect()
torch.cuda.empty_cache()
refiner = DiffusionPipeline.from_pretrained(
        'test_model',
        torch_dtype=torch.float16
    ).to(torch.device('cuda'))
print("Test 1")
refiner("hello world", image=image, num_inference_steps=50)
print("Test 2")
refiner("hello world", image=image, num_inference_steps=50)

SDXL Controlnet

import torch
from PIL import Image
from diffusers import StableDiffusionXLPipeline, StableDiffusionXLControlNetPipeline, ControlNetModel
from diffusers.pipelines.controlnet import MultiControlNetModel

txt2img_pipeline =  StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", 
    use_safetensors=True,
    variant="fp16",
    torch_dtype=torch.float16).to("cuda")
controlnet_cn = ControlNetModel.from_pretrained(
    "diffusers/controlnet-canny-sdxl-1.0", 
    torch_dtype=torch.float16, 
    variant="fp16",
    use_safetensors=True,
).to("cuda")
controlnet_depth = ControlNetModel.from_pretrained(
    "diffusers/controlnet-depth-sdxl-1.0", 
    torch_dtype=torch.float16, 
    variant="fp16",
    use_safetensors=True,
).to("cuda")
controlnet = MultiControlNetModel(
    [controlnet_cn,
    controlnet_depth]
).to("cuda")
multi_controlnet_compiled = torch.compile(controlnet, mode="reduce-overhead", fullgraph=True)
cn_img2imgpipeline = StableDiffusionXLControlNetPipeline(
    **txt2img_pipeline.components,
    controlnet=multi_controlnet_compiled,
)
image = Image.new('RGB', (1024, 1024), color='red')
for i in range(1024):
    for j in range(1024):
        image.putpixel((i, j), (i % 256, j % 256, (i + j) % 256))
print("Test 1")
cn_img2imgpipeline("hello world", image=[image]*2, control_image=[image]*2, num_inference_steps=50)
print("Test 2")
cn_img2imgpipeline("hello world", image=[image]*2, control_image=[image]*2, num_inference_steps=50)
multi_controlnet_compiled.save_pretrained("controlnet_compiled")
del multi_controlnet_compiled
import gc
gc.collect()
torch.cuda.empty_cache()
controlnet_compiled = ControlNetModel.from_pretrained("controlnet_compiled", torch_dtype=torch.float16).to("cuda")
controlnet_compiled_1 = ControlNetModel.from_pretrained("controlnet_compiled_1", torch_dtype=torch.float16).to("cuda")
multi_controlnet_compiled = MultiControlNetModel(
    [controlnet_compiled,
    controlnet_compiled_1]
).to("cuda")

cn_img2imgpipeline = StableDiffusionXLControlNetPipeline(
    **txt2img_pipeline.components,
    controlnet=multi_controlnet_compiled,
).to("cuda")
print("Test 1")
cn_img2imgpipeline("hello world", image=[image]*2, control_image=[image]*2, num_inference_steps=50)
print("Test 2")
cn_img2imgpipeline("hello world", image=[image]*2, control_image=[image]*2, num_inference_steps=50)

None of the examples shows any performance gain after loading the saved weights of a compiled model.

Logs

No response

System Info

  • diffusers version: 0.26.0.dev0
  • Platform: Linux-6.5.0-14-generic-x86_64-with-glibc2.17
  • Python version: 3.8.16
  • PyTorch version (GPU?): 2.1.0 (True)
  • Huggingface_hub version: 0.20.2
  • Transformers version: 4.35.2
  • Accelerate version: 0.24.1
  • xFormers version: 0.0.22.post7
  • Using GPU in script?: NVIDIA 3090
  • Using distributed or parallel set-up in script?: No

Who can help?

@DN6 @yiyixuxu @sayakpaul @patrickvonplaten

@Alexadar Alexadar added the bug Something isn't working label Jan 23, 2024
@AmericanPresidentJimmyCarter
Contributor

Not a diffusers bug

pytorch/pytorch#101107

@Alexadar
Author

Not a diffusers bug

pytorch/pytorch#101107

It seems so. Plain SD does not work either:

import torch
import gc
from diffusers import StableDiffusionPipeline

pipe_txt2img = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    extract_ema=True,
    torch_dtype=torch.float16).to('cuda')
print("Test not compiled model")
pipe_txt2img(prompt="A photo of a cat", num_inference_steps=100)
print("Compile model")
pipe_txt2img.unet = torch.compile(pipe_txt2img.unet, mode="reduce-overhead", fullgraph=True)
pipe_txt2img(prompt="A photo of a cat", num_inference_steps=10)
print("Test compiled model")
pipe_txt2img(prompt="A photo of a cat", num_inference_steps=100)
print("save model")
pipe_txt2img.save_pretrained("test_model")

del pipe_txt2img
gc.collect()
torch.cuda.empty_cache()

print("Load model")
pipe_txt2img = StableDiffusionPipeline.from_pretrained(
    "test_model",
    torch_dtype=torch.float16).to('cuda')
print("Test compiled model")
pipe_txt2img(prompt="A photo of a cat", num_inference_steps=100)

Is there a workaround?

@sayakpaul
Member

Does upgrading to PyTorch nightly work?

@DN6
Collaborator
DN6 commented Jan 25, 2024

@Alexadar when calling save_pretrained, the compiled model is unwrapped before saving. Serializing an optimized model isn't currently supported:
https://pytorch.org/get-started/pytorch-2.0/#serialization

You will have to compile the module again I'm afraid.
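A minimal sketch of that workaround: re-apply torch.compile after every from_pretrained. The helper name recompile_unet is hypothetical; it works for any pipeline object exposing a .unet module:

```python
import torch

def recompile_unet(pipeline, mode="reduce-overhead"):
    """Re-wrap a pipeline's UNet with torch.compile after loading.

    save_pretrained only stores plain weights, so the Dynamo wrapper
    must be recreated on every load; the first call after loading
    then re-triggers compilation (compiled graphs are not serialized).
    """
    pipeline.unet = torch.compile(pipeline.unet, mode=mode, fullgraph=True)
    return pipeline

# e.g.:
# pipe = StableDiffusionXLPipeline.from_pretrained(
#     "test_model", torch_dtype=torch.float16).to("cuda")
# pipe = recompile_unet(pipe)  # pay the warmup cost once per process
```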

@Alexadar
Author

@Alexadar when calling save_pretrained, the compiled model is unwrapped before saving. Serializing an optimized model isn't currently supported: https://pytorch.org/get-started/pytorch-2.0/#serialization

You will have to compile the module again I'm afraid.

Maybe it's possible to pickle the compiled UNets and load them manually? I tried, with no luck; maybe that PyTorch bug needs a workaround first.
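Pickling the wrapper is unlikely to help: torch.compile returns an OptimizedModule that merely wraps the original module, and the compiled graphs live in per-process caches that are not part of the object's serializable state. A small sketch of what the wrapper actually holds (behavior as observed on PyTorch 2.x; no GPU needed here, since wrapping is lazy and nothing compiles until the first call):

```python
import torch

model = torch.nn.Linear(4, 4)
compiled = torch.compile(model)  # lazy: no compilation happens yet

# torch.compile returns a wrapper around the original module; the
# compiled graphs are built on first call and cached in-process only.
print(type(compiled).__name__)      # OptimizedModule
print(compiled._orig_mod is model)  # True

# state_dict round-trips only plain weights (note the prefix), which
# is why save/load drops the compilation.
print(list(compiled.state_dict())[0])  # _orig_mod.weight
```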


This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot added the stale Issues that haven't received updates label Feb 23, 2024
@sayakpaul
Member

Seems like the issue is on the PyTorch side?

@yiyixuxu
Collaborator

can we close this issue now?

@Alexadar Alexadar closed this as not planned (won't fix, can't repro, duplicate, stale) on Mar 1, 2024
@Alexadar
Author
Alexadar commented Mar 1, 2024

can we close this issue now?

It seems so, yep. Closed.
