Description
Here's an example:
```python
import torch

lst = []

@torch.compile(backend="eager", fullgraph=True)
def f(x):
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()
    out = torch.matmul(x, x)
    end_event.record()
    lst.append(start_event)
    lst.append(end_event)
    return out

x = torch.randn(5000, device='cuda')
out = f(x)
print(lst[0].elapsed_time(lst[1]))
```
Without `torch.compile`, this prints the elapsed time between the two events:

```
55.96131134033203
```

With `torch.compile`, it raises an error instead:
```
Traceback (most recent call last):
  File "/data/users/hirsheybar/a/pytorch/tmp6.py", line 20, in <module>
    print(lst[0].elapsed_time(lst[1]))
  File "/data/users/hirsheybar/a/pytorch/torch/cuda/streams.py", line 216, in elapsed_time
    return super().elapsed_time(end_event)
ValueError: Both events must be recorded before calculating elapsed time.
```
Why? The generated Dynamo graph and residual bytecode are below. It looks like:

(1) Dynamo handles the `cuda.Event()` creation and the list appends as compile-time constants, stashing the events as globals and appending them to the list via residual bytecode.

(2) Dynamo also proxies the `cuda.Event()` objects into the graph, even though it is treating them as constants. The proxied `Event` objects are unused, though, and get DCE'd.

(3) Dynamo also proxies the `cuda.Event.record()` calls into the graph, but they are DCE'd as well.

(4) At runtime, none of the logic to record the events actually runs (and even if it did, it wouldn't matter, because the events proxied into the graph are not the ones handed back to the user; the list gets the stashed constants).

A quick check of (1)/(4) and a possible workaround are sketched after the dumps below.
```
# graph
===== __compiled_fn_1 =====
/data/users/hirsheybar/a/pytorch/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self, L_x_: "f32[5000][1]cuda:0"):
        l_x_ = L_x_
        # File: /data/users/hirsheybar/a/pytorch/tmp6.py:7 in f, code: start_event = torch.cuda.Event(enable_timing=True)
        event = torch.cuda.streams.Event(enable_timing = True)
        # File: /data/users/hirsheybar/a/pytorch/tmp6.py:8 in f, code: end_event = torch.cuda.Event(enable_timing=True)
        event_1 = torch.cuda.streams.Event(enable_timing = True)
        # File: /data/users/hirsheybar/a/pytorch/tmp6.py:10 in f, code: start_event.record()
        record = event.record(); event = record = None
        # File: /data/users/hirsheybar/a/pytorch/tmp6.py:11 in f, code: out = torch.matmul(x, x)
        out: "f32[][]cuda:0" = torch.matmul(l_x_, l_x_); l_x_ = None
        # File: /data/users/hirsheybar/a/pytorch/tmp6.py:12 in f, code: end_event.record()
        record_1 = event_1.record(); event_1 = record_1 = None
        return (out,)
```
```
# bytecode
DEBUG: MODIFIED BYTECODE f /data/users/hirsheybar/a/pytorch/tmp6.py line 5
  5           0 LOAD_GLOBAL              9 (__compiled_fn_1)
              2 LOAD_FAST                0 (x)
              4 DUP_TOP
              6 STORE_FAST               7 (tmp_3)
              8 CALL_FUNCTION            1
             10 STORE_FAST               4 (graph_out_0)
             12 LOAD_FAST                4 (graph_out_0)
             14 LOAD_CONST               3 (0)
             16 BINARY_SUBSCR
             18 LOAD_GLOBAL              7 (_event_140622680852736_c0)
             20 LOAD_GLOBAL              8 (_event_1406226720890887533_c0)
             22 BUILD_LIST               2
             24 LOAD_GLOBAL              5 (lst)
             26 DUP_TOP
             28 STORE_FAST               6 (tmp_2)
             30 LOAD_CONST               0 (None)
             32 LOAD_CONST               0 (None)
             34 BUILD_SLICE              2
             36 STORE_SUBSCR
             38 DELETE_FAST              4 (graph_out_0)
             40 RETURN_VALUE
```
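To double-check (1)/(4): the residual bytecode loads the events via `LOAD_GLOBAL` using `_event_<id>_c0` names, so after calling the compiled `f` those names should be visible in the calling module's globals, and the objects in `lst` should be exactly those stashed globals (on which `record()` never actually ran). A minimal inspection sketch, assuming it runs in the same module as the repro above:

```python
# Run after `out = f(x)` from the repro above, in the same module.
# Assumption: Dynamo installs the stashed events as module-level globals with
# names like "_event_<id>_c0", which is where the residual LOAD_GLOBALs resolve.
event_globals = [name for name in globals() if name.startswith("_event_")]
print(event_globals)  # expect two "_event_<id>_c0" entries

# The list built by the residual bytecode should hold those exact objects,
# i.e. events that were never recorded.
print(any(lst[0] is globals()[name] for name in event_globals))
```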
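For completeness, a workaround sketch (not a fix for the underlying issue, and it also measures compile overhead on the first call): hoist the event creation and `record()` calls out of the compiled region so they execute as ordinary eager CUDA calls, with an explicit `torch.cuda.synchronize()` before reading the timer.

```python
import torch

@torch.compile(backend="eager", fullgraph=True)
def f(x):
    return torch.matmul(x, x)

x = torch.randn(5000, device="cuda")

# Create and record the events outside the compiled region, so the record()
# calls run eagerly instead of being constant-folded / DCE'd by Dynamo.
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)

start_event.record()
out = f(x)
end_event.record()

torch.cuda.synchronize()  # make sure both events have completed before reading the timer
print(start_event.elapsed_time(end_event))
```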
cc @ptrblck @msaroufim @eqy @jerryzh168 @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @amjames