Return ConstantVariable(None) from WithExitFunctionVariable.exit to prevent NoneType crash inside autocast exception path · pytorch/pytorch@15a3f58 · GitHub

Commit 15a3f58

jansel and wdziurdz authored and committed
Return ConstantVariable(None) from WithExitFunctionVariable.exit to prevent NoneType crash inside autocast exception path (#152503)
Copy of #152013 with PR time benchmarks updated (regressions seem unrelated)

Pull Request resolved: #152503
Approved by: https://github.com/anijain2305, https://github.com/Skylion007

Co-authored-by: Witold Dziurdz <wdziurdz@habana.ai>
1 parent 632b89a commit 15a3f58
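For context, a minimal standalone sketch of the scenario this commit fixes, adapted from the new test added below (an exception raised inside a torch.autocast region under torch.compile). The crash itself predates the fix; on a patched build this runs cleanly:

import torch

class Optimizer(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # Raising inside the autocast region sends Dynamo down the
        # exception path of the context manager's exit handling.
        raise NotImplementedError("Not implemented")

    @staticmethod
    def backward(ctx, grad_out):
        return grad_out

@torch.compile
def f(x):
    try:
        with torch.autocast(device_type="cpu", dtype=None):
            Optimizer.apply(x)
    except NotImplementedError:
        return x + 1

# Before this commit, compiling f could crash with a NoneType error while
# Dynamo processed the exception (per the commit title); after it, the
# exception propagates normally and the except branch runs.
print(f(torch.ones(3)))  # tensor([2., 2., 2.])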

File tree

3 files changed: +32 −9 lines changed

benchmarks/dynamo/pr_time_benchmarks/expected_results.csv

Lines changed: 9 additions & 9 deletions
Each row has the form benchmark_name,metric,expected_value,threshold; the trailing 0.015 appears to be the allowed relative drift before the check fails. Only the expected instruction counts are re-baselined here.

@@ -1,4 +1,4 @@
-add_loop_eager,compile_time_instruction_count,2944000000,0.015
+add_loop_eager,compile_time_instruction_count,2960000000,0.015
@@ -18,15 +18,15 @@ add_loop_inductor_gpu,compile_time_instruction_count,25505620920,0.015
-basic_modules_ListOfLinears_eager,compile_time_instruction_count,999400000,0.015
+basic_modules_ListOfLinears_eager,compile_time_instruction_count,1005000000,0.015
 basic_modules_ListOfLinears_inductor,compile_time_instruction_count,17990000000,0.015
-basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,16130000000,0.015
+basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,16220000000,0.015
@@ -38,11 +38,11 @@ update_hint_regression,compile_time_instruction_count,1608000000,0.02
-float_args,compile_time_instruction_count,441500000,0.015
+float_args,compile_time_instruction_count,439200000,0.015
-sum_floordiv_regression,compile_time_instruction_count,985300000,0.015
+sum_floordiv_regression,compile_time_instruction_count,998400000,0.015
@@ -54,11 +54,11 @@ symint_sum_loop,compile_time_instruction_count,4180000000,0.015
-aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,2042000000,0.015
+aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,2074000000,0.015
-aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,5884000000,0.015
+aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,5944000000,0.015
@@ -70,8 +70,8 @@ aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1856000000,0.015
-aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3770000000,0.015
+aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3795000000,0.015
-aotdispatcher_training_subclass_cpu,compile_time_instruction_count,10200000000,0.015
+aotdispatcher_training_subclass_cpu,compile_time_instruction_count,10280000000,0.015

test/dynamo/test_exceptions.py

Lines changed: 22 additions & 0 deletions
@@ -128,6 +128,28 @@ def fn(x):
         res = opt_fn(x)
         self.assertEqual(ref, res)
 
+    def test_autocast_with_exception(self):
+        class Optimizer(torch.autograd.Function):
+            @staticmethod
+            def forward(ctx, x):
+                raise NotImplementedError("Not implemented")
+
+            @staticmethod
+            def backward(ctx, grad_out):
+                return grad_out
+
+        @torch.compile
+        def f(x: torch.Tensor):
+            try:
+                with torch.autocast(device_type="cpu", dtype=None):
+                    Optimizer.apply(x)
+            except NotImplementedError:
+                return x + 1
+
+        inp = torch.ones(3)
+        out = f(inp)
+        self.assertTrue(torch.equal(out, inp + 1))
+
     @make_dynamo_test
     def test_propagate_exception_inside_ctx_manager(self):
         @contextlib.contextmanager
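Why the exit return value matters here: in plain CPython, when an exception escapes a with block, the interpreter checks the truthiness of the context manager's __exit__ result to decide whether the exception is suppressed; returning None (falsy) lets it propagate. A minimal sketch of those semantics, independent of PyTorch:

class Noisy:
    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Returning None (falsy) tells the interpreter NOT to suppress
        # the in-flight exception, so it keeps propagating to the caller.
        return None

try:
    with Noisy():
        raise NotImplementedError("Not implemented")
except NotImplementedError:
    print("exception propagated, as expected")

Dynamo mirrors this protocol while tracing, so the traced exit handler must hand back a value Dynamo can inspect on the exception path; per the commit title, that value is now modeled as ConstantVariable(None) rather than a bare Python None.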

torch/_dynamo/variables/ctx_manager.py

Lines changed: 1 addition & 0 deletions
@@ -888,6 +888,7 @@ def exit(self, tx: "InstructionTranslator", *args):
         tx.output.create_node(
             "call_function", torch.amp._exit_autocast, (self.proxy,), {}
         )
+        return variables.ConstantVariable.create(None)
 
     def enter(self, tx):
         ctx = torch.amp._enter_autocast(*self.target_values)
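A toy illustration (not Dynamo's actual classes; VarTracker, is_truthy, and the exit_handler_* names are hypothetical) of the wrapper invariant the one-line fix restores: a tracing interpreter that represents every traced value as a tracker object breaks if one code path hands back a raw None, because downstream code immediately calls tracker methods on the result:

class VarTracker:
    # Stand-in for Dynamo's VariableTracker: every traced value is wrapped.
    def __init__(self, value):
        self.value = value

    def is_truthy(self):
        # The exception path asks the exit handler's result whether the
        # exception should be suppressed, so the result must be a tracker.
        return bool(self.value)

def exit_handler_buggy():
    pass  # implicitly returns raw None -> caller's .is_truthy() crashes

def exit_handler_fixed():
    return VarTracker(None)  # analogue of ConstantVariable.create(None)

result = exit_handler_fixed()
assert result.is_truthy() is False  # falsy: exception keeps propagating

# exit_handler_buggy().is_truthy() would instead raise:
# AttributeError: 'NoneType' object has no attribute 'is_truthy'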

0 commit comments