cuda.Event handling in dynamo is broken #153058
@bdhirsh

Description


Here's an example:

import torch

lst = []

@torch.compile(backend="eager", fullgraph=True)
def f(x):
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    start_event.record()
    out = torch.matmul(x, x)
    end_event.record()

    lst.append(start_event)
    lst.append(end_event)
    return out

x = torch.randn(5000, device='cuda')
out = f(x)
print(lst[0].elapsed_time(lst[1]))

Without compile, this prints the elapsed time between the two events:

55.96131134033203

With compile, this gives an error:

Traceback (most recent call last):
  File "/data/users/hirsheybar/a/pytorch/tmp6.py", line 20, in <module>
    print(lst[0].elapsed_time(lst[1]))
  File "/data/users/hirsheybar/a/pytorch/torch/cuda/streams.py", line 216, in elapsed_time
    return super().elapsed_time(end_event)
ValueError: Both events must be recorded before calculating elapsed time.
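
Until this is fixed, timing the compiled region from outside the compiled function works, since the events then never need to be traced. A minimal sketch (the explicit end_event.synchronize() is my addition, to guarantee both events have completed before reading the timer):

import torch

@torch.compile(backend="eager", fullgraph=True)
def f(x):
    return torch.matmul(x, x)

x = torch.randn(5000, device='cuda')

# Create and record the events outside the compiled function so dynamo
# never has to trace them.
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)

start_event.record()
out = f(x)
end_event.record()

# Wait for the recorded work to finish before reading the timer.
end_event.synchronize()
print(start_event.elapsed_time(end_event))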

Why? Here's the generated dynamo graph + residual bytecode below. It looks like:

(1) dynamo handles the cuda.Event() creation + list appends as compile-time constants, stashing the event objects as globals and writing them into the list via residual bytecode (see the decompiled form after the bytecode dump below)

(2) dynamo also proxies the cuda.Event() objects into the graph, even though it is treating them as constants. The graph's Event objects are unused, though, and get DCEd

(3) dynamo also proxies the cuda.Event.record() calls into the graph, but they are DCEd as well

(4) at runtime, none of the logic to record the events runs (and even if the record calls had survived DCE, they would have recorded the graph's proxied events, not the constant events that dynamo stashed as globals and put into the list)
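
One consequence of (1): because the events are baked in as compile-time globals, every call to the compiled function should install the same two Event objects into lst. A quick check (untested sketch, reusing the repro above):

out = f(x)
first_events = (lst[0], lst[1])
out = f(x)
# The residual bytecode reloads the same _event_*_c0 globals on every call,
# so lst should contain the identical stale event objects each time.
assert lst[0] is first_events[0] and lst[1] is first_events[1]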

# graph
 ===== __compiled_fn_1 =====
 /data/users/hirsheybar/a/pytorch/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self, L_x_: "f32[5000][1]cuda:0"):
        l_x_ = L_x_

         # File: /data/users/hirsheybar/a/pytorch/tmp6.py:7 in f, code: start_event = torch.cuda.Event(enable_timing=True)
        event = torch.cuda.streams.Event(enable_timing = True)

         # File: /data/users/hirsheybar/a/pytorch/tmp6.py:8 in f, code: end_event = torch.cuda.Event(enable_timing=True)
        event_1 = torch.cuda.streams.Event(enable_timing = True)

         # File: /data/users/hirsheybar/a/pytorch/tmp6.py:10 in f, code: start_event.record()
        record = event.record();  event = record = None

         # File: /data/users/hirsheybar/a/pytorch/tmp6.py:11 in f, code: out = torch.matmul(x, x)
        out: "f32[][]cuda:0" = torch.matmul(l_x_, l_x_);  l_x_ = None

         # File: /data/users/hirsheybar/a/pytorch/tmp6.py:12 in f, code: end_event.record()
        record_1 = event_1.record();  event_1 = record_1 = None
        return (out,)

# bytecode
DEBUG: MODIFIED BYTECODE f /data/users/hirsheybar/a/pytorch/tmp6.py line 5
  5           0 LOAD_GLOBAL              9 (__compiled_fn_1)
              2 LOAD_FAST                0 (x)
              4 DUP_TOP
              6 STORE_FAST               7 (tmp_3)
              8 CALL_FUNCTION            1
             10 STORE_FAST               4 (graph_out_0)
             12 LOAD_FAST                4 (graph_out_0)
             14 LOAD_CONST               3 (0)
             16 BINARY_SUBSCR
             18 LOAD_GLOBAL              7 (_event_140622680852736_c0)
             20 LOAD_GLOBAL              8 (_event_140622672089088_c0)
             22 BUILD_LIST               2
             24 LOAD_GLOBAL              5 (lst)
             26 DUP_TOP
             28 STORE_FAST               6 (tmp_2)
             30 LOAD_CONST               0 (None)
             32 LOAD_CONST               0 (None)
             34 BUILD_SLICE              2
             36 STORE_SUBSCR
             38 DELETE_FAST              4 (graph_out_0)
             40 RETURN_VALUE
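
Hand-decompiling the modified bytecode makes the failure mode explicit: the residual bytecode writes the compile-time constant events into lst, but nothing ever records them. A rough Python rendering (temporaries omitted; __compiled_fn_1 and the _event_*_c0 names are the generated globals from the dump above):

def transformed_f(x):
    graph_out_0 = __compiled_fn_1(x)
    out = graph_out_0[0]
    # lst[None:None] = [...] in the bytecode, i.e. replace the list's
    # contents with the two events stashed at compile time. No .record()
    # call ever runs on either of them.
    lst[:] = [_event_140622680852736_c0, _event_140622672089088_c0]
    del graph_out_0
    return out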

cc @ptrblck @msaroufim @eqy @jerryzh168 @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @amjames

Labels

module: cuda, module: dynamo, oncall: pt2, triaged
