[dynamo] Constant fold torch.autograd._profiler_enabled (#158482) · pytorch/pytorch@d7e1b8b · GitHub

Commit d7e1b8b

anijain2305 authored and pytorchmergebot committed
[dynamo] Constant fold torch.autograd._profiler_enabled (#158482)
Pull Request resolved: #158482
Approved by: https://github.com/williamwen42, https://github.com/StrongerXi
1 parent b6454a9 commit d7e1b8b
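In effect, dynamo now evaluates torch.autograd._profiler_enabled() once at trace time and bakes the taken branch into the graph, instead of skipping the call. A minimal sketch of the resulting behavior (assuming a build that includes this change; the backend and input are arbitrary):

import torch

def fn(x):
    # Before this commit this call was mapped to SkipFunctionVariable,
    # which is a hard error under fullgraph=True; now it is folded to a
    # constant at trace time.
    if torch.autograd._profiler_enabled():
        return torch.cos(x)
    return torch.sigmoid(x)

opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
x = torch.randn(4)
print(opt_fn(x))  # profiler off at trace time, so the sigmoid branch is compiled

with torch.autograd.profiler.profile():
    print(opt_fn(x))  # folded value flips, the guard fails, cos branch is recompiled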

File tree

3 files changed: +43 −1 lines changed


test/dynamo/test_profiler.py

Lines changed: 41 additions & 0 deletions
@@ -192,6 +192,47 @@ def fn(x, y):
             ],
         )
 
+    def test_profiler_enabled(self):
+        def fn(x):
+            x = torch.sin(x)
+            if torch.autograd._profiler_enabled():
+                return torch.cos(x)
+            else:
+                return torch.sigmoid(x)
+
+        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
+        x = torch.randn(4)
+
+        ref = fn(x)
+        res = opt_fn(x)
+        self.assertEqual(ref, res)
+
+        with torch.autograd.profiler.profile():
+            ref = fn(x)
+            res = opt_fn(x)
+            self.assertEqual(ref, res)
+
+    def test_profiler_record_function_ignore(self):
+        def fn(x):
+            x = torch.sin(x)
+            if torch.autograd._profiler_enabled():
+                with torch.autograd.profiler.record_function("dummy"):
+                    return torch.cos(x)
+            else:
+                return torch.sigmoid(x)
+
+        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
+        x = torch.randn(4)
+
+        ref = fn(x)
+        res = opt_fn(x)
+        self.assertEqual(ref, res)
+
+        with torch.autograd.profiler.profile():
+            ref = fn(x)
+            res = opt_fn(x)
+            self.assertEqual(ref, res)
+
 
 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests
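Both tests exercise the two profiler states, and the second additionally checks that a record_function context inside the folded branch still compiles cleanly. One way to double-check that the branch no longer causes a graph break is torch._dynamo.explain (a debugging helper; treating zero breaks as the expected outcome is an inference from this change, not something the test asserts directly):

import torch

def fn(x):
    x = torch.sin(x)
    if torch.autograd._profiler_enabled():
        return torch.cos(x)
    return torch.sigmoid(x)

# explain() traces fn and reports, among other things, graph break counts.
explanation = torch._dynamo.explain(fn)(torch.randn(4))
print(explanation.graph_break_count)  # expected: 0 with this change applied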

torch/_dynamo/trace_rules.py

Lines changed: 1 addition & 1 deletion
@@ -176,7 +176,6 @@
     "torch.compiler.is_compiling": TorchInGraphFunctionVariable,
     "torch.compiler.is_dynamo_compiling": TorchInGraphFunctionVariable,
     "torch.compiler.is_exporting": TorchInGraphFunctionVariable,
-    "torch.autograd._profiler_enabled": SkipFunctionVariable,
     "torch._C._to_dlpack": SkipFunctionVariable,
     "torch.to_dlpack": SkipFunctionVariable,
     # We graph break on RNG state setters or getters like
@@ -2434,6 +2433,7 @@
     "torch.atleast_3d",
     "torch.autograd._calculate_shape",
     "torch.autograd._is_checkpoint_valid",
+    "torch.autograd._profiler_enabled",
     "torch.autograd._make_grads",
     "torch.autograd._register_py_tensor_class_for_device",
     "torch.autograd._tensor_or_tensors_to_tuple",

torch/_dynamo/variables/torch.py

Lines changed: 1 addition & 0 deletions
@@ -142,6 +142,7 @@
     torch.cuda.is_initialized,
     torch.xpu.current_device,
     torch.xpu.is_initialized,
+    torch.autograd._profiler_enabled,
 ]
 
 constant_fold_functions = [
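Judging by the neighboring entries (torch.cuda.is_initialized, torch.xpu.is_initialized, and so on), the list being extended here holds constant-folded functions whose folded value must be guarded: if the profiler state changes after compilation, the guard fails and dynamo recompiles rather than replaying the stale branch. One way to observe the recompile (torch._logging.set_logs is a public logging knob; the exact log wording is not guaranteed):

import torch

torch._logging.set_logs(recompiles=True)  # print the guard failure that triggers a recompile

@torch.compile(backend="eager", fullgraph=True)
def fn(x):
    return torch.cos(x) if torch.autograd._profiler_enabled() else torch.sigmoid(x)

x = torch.randn(4)
fn(x)  # first compile: profiler off, sigmoid branch baked in
with torch.autograd.profiler.profile():
    fn(x)  # guard on the folded _profiler_enabled() value fails -> recompile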

0 commit comments
