[inductor] cudagraph error for individually compiled transformer blocks #152887

Open
StrongerXi opened this issue May 6, 2025 · 4 comments
Labels
module: cuda graphs (Ability to capture and then replay streams of CUDA kernels) · module: inductor · oncall: pt2 · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

StrongerXi (Contributor) commented May 6, 2025

🐛 Describe the bug

This was first observed in #150706 (comment).

Note that if we uncomment the # Pass region, the error goes away.

import torch

def f(x):
    return x + 1

f = torch.compile(f, mode="reduce-overhead")

# Pass
#x = torch.ones(2, device='cuda')
#xx = torch.ones(2, device='cuda')
#y = f(x)
#z = f(xx)

# Fail
x = torch.ones(2, device='cuda')
y = f(x)
z = f(y)

Error logs

Traceback (most recent call last):
  File "/home/ryanguo99/scratch/cudagraph.py", line 17, in <module>
    z = f(y)
        ^^^^
  File "/home/ryanguo99/repos/pytorch/torch/_dynamo/eval_frame.py", line 678, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/ryanguo99/scratch/cudagraph.py", line 3, in f
    def f(x):
  File "/home/ryanguo99/repos/pytorch/torch/_dynamo/eval_frame.py", line 872, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/ryanguo99/repos/pytorch/torch/_functorch/aot_autograd.py", line 1221, in forward
    return compiled_fn(full_args)
           ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ryanguo99/repos/pytorch/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 338, in runtime_wrapper
    all_outs = call_func_at_runtime_with_args(
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ryanguo99/repos/pytorch/torch/_functorch/_aot_autograd/utils.py", line 126, in call_func_at_runtime_with_args
    out = normalize_as_list(f(args))
                            ^^^^^^^
  File "/home/ryanguo99/repos/pytorch/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 502, in wrapper
    return compiled_fn(runtime_args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ryanguo99/repos/pytorch/torch/_inductor/output_code.py", line 584, in __call__
    return self.current_callable(inputs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ryanguo99/repos/pytorch/torch/_inductor/compile_fx.py", line 1572, in run
    return compiled_fn(new_inputs)
           ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ryanguo99/repos/pytorch/torch/_inductor/cudagraph_trees.py", line 371, in deferred_cudagraphify
    return fn(inputs)
           ^^^^^^^^^^
  File "/home/ryanguo99/repos/pytorch/torch/_inductor/utils.py", line 2570, in run
    return model(new_inputs)
           ^^^^^^^^^^^^^^^^^
  File "/home/ryanguo99/repos/pytorch/torch/_inductor/cudagraph_trees.py", line 1992, in run
    out = self._run(new_inputs, function_id)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ryanguo99/repos/pytorch/torch/_inductor/cudagraph_trees.py", line 2162, in _run
    return self.record_function(new_inputs, function_id)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ryanguo99/repos/pytorch/torch/_inductor/cudagraph_trees.py", line 2196, in record_function
    node = CUDAGraphNode(
           ^^^^^^^^^^^^^^
  File "/home/ryanguo99/repos/pytorch/torch/_inductor/cudagraph_trees.py", line 950, in __init__
    recording_inputs = self._allocate_and_copy_recording_inputs(inputs)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ryanguo99/repos/pytorch/torch/_inductor/cudagraph_trees.py", line 1666, in _allocate_and_copy_recording_inputs
    self._copy_inputs_and_remove_from_src(recording_inputs, inputs)
  File "/home/ryanguo99/repos/pytorch/torch/_inductor/cudagraph_trees.py", line 1050, in _copy_inputs_and_remove_from_src
    torch._foreach_copy_(dst_tensors, src_tensors)
  File "/home/ryanguo99/repos/pytorch/torch/utils/_device.py", line 100, in __torch_function__
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Error: accessing tensor output of CUDAGraphs that has been overwritten by a subsequent run. Stack trace: File "/home/ryanguo99/scratch/cudagraph.py", line 31, in f
    return x + 1. To prevent overwriting, clone the tensor outside of torch.compile() or call torch.compiler.cudagraph_mark_step_begin() before each model invocation.

Versions

main aafe8a6, python 3.11

cc @mcarilli @ezyang @eellison @penguinwu @BoyuanFeng @chauhang @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @aakhundov

masnesral added the module: cuda graphs, module: inductor, and triage review labels May 6, 2025
eellison (Contributor) commented May 7, 2025

Hi, as per the error message: clone the tensor outside of torch.compile(), or call torch.compiler.cudagraph_mark_step_begin() before each model invocation. This is a fundamental limitation as of now.
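
For reference, a minimal sketch (not part of the original thread) of the clone workaround mentioned in the error message, applied to the failing repro above. The clone happens outside of torch.compile() and copies the cudagraph-owned output into a regular tensor before it is fed back into the compiled function:

import torch

def f(x):
    return x + 1

f = torch.compile(f, mode="reduce-overhead")

x = torch.ones(2, device='cuda')
y = f(x)
# Cloning outside of torch.compile() copies the data out of the
# cudagraph-managed output buffer, so the next compiled call no longer
# reads memory that a later replay may overwrite.
y = y.clone()
z = f(y)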

eellison closed this as completed May 7, 2025

eellison (Contributor) commented May 7, 2025:
import torch

def f(x):
    return x + 1

f = torch.compile(f, mode="reduce-overhead")

# Pass
#x = torch.ones(2, device='cuda')
#xx = torch.ones(2, device='cuda')
#y = f(x)
#z = f(xx)

# Previously failing case, now with cudagraph_mark_step_begin() before each step
for _ in range(3):
    torch.compiler.cudagraph_mark_step_begin()
    x = torch.ones(2, device='cuda')
    y = f(x)
    z = f(y)

StrongerXi (Contributor, Author) commented:

@eellison cudagraph_mark_step_begin() still fails (although clone works) in this repro, which more faithfully captures the original model's computation:

import torch

@torch.compile(mode="reduce-overhead")
def f(x):
    return x + 1

x = torch.ones(2, device='cuda')
for _ in range(2):
    torch.compiler.cudagraph_mark_step_begin()
    x = f(x)
    #x = x.clone()

The bigger question is around the original use case (#150706 (comment)):

The breakage happens in framework code, so users can't easily unblock themselves.

We'd hit this x = f(x) pattern in pretty much all transformer blocks, and it will break with this cudagraph limitation if we compile the individual blocks rather than the full model (see the sketch after this comment).

Moreover, are these cudagraph workarounds neutral for other cases? E.g., .clone() probably slows down eager mode, and does torch.compiler.cudagraph_mark_step_begin() have any effect if we compile the entire model (i.e., the for loop together with the cudagraph_mark_step_begin call)?

I think for now I'll just advise users against using cudagraph when compiling individual transformer blocks, and we can come back to investigate better workarounds if there are more asks.
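
For illustration, a minimal sketch (not part of the original thread) of the per-block compilation pattern described above, using a hypothetical Block module as a stand-in for a transformer block. Each compiled block receives the previous block's cudagraph-owned output as its input, which is exactly the x = f(x) pattern that runs into this limitation:

import torch
import torch.nn as nn

class Block(nn.Module):
    # Hypothetical stand-in for a transformer block.
    def __init__(self, dim):
        super().__init__()
        self.linear = nn.Linear(dim, dim)

    def forward(self, x):
        return x + self.linear(x)

blocks = nn.ModuleList([Block(16) for _ in range(4)]).cuda()
# Compile each block individually rather than the whole model.
compiled_blocks = [torch.compile(b, mode="reduce-overhead") for b in blocks]

x = torch.randn(2, 16, device="cuda")
for block in compiled_blocks:
    # Each call feeds the previous block's output straight back into the
    # next compiled function, i.e. the x = f(x) pattern discussed above.
    x = block(x)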

StrongerXi reopened this May 7, 2025
masnesral added the triaged label and removed the triage review label May 13, 2025