Description
🐛 Describe the bug
Since commit 511d0dd (present in PT 2.7 but not in 2.6), Dynamo crashes when an exception is raised inside a `torch.autocast` context manager under `torch.compile`, emitting:
```
V0423 13:34:49.292000 1586012 torch/_dynamo/symbolic_convert.py:1239] [0/0] [__trace_bytecode] TRACE RAISE_VARARGS 1 [ExceptionVariable(<class 'NotImplementedError'>)]
V0423 13:34:49.292000 1586012 torch/_dynamo/symbolic_convert.py:3908] [0/0] Observed exception DURING INLING <code object forward at 0x7f9b195392c0, file "src/test.py", line 6> : raised exception ExceptionVariable(<class 'NotImplementedError'>)
V0423 13:34:49.293000 1586012 torch/_dynamo/symbolic_convert.py:1216] [0/0] [__trace_source] TRACE starts_line test.py:20 in f (Repro.test_autocast_with_exception.f)
V0423 13:34:49.293000 1586012 torch/_dynamo/symbolic_convert.py:1216] [0/0] [__trace_source] with torch.autocast(device_type="cpu", dtype=None):
V0423 13:34:49.293000 1586012 torch/_dynamo/symbolic_convert.py:1239] [0/0] [__trace_bytecode] TRACE WITH_EXCEPT_START None [WithExitFunctionVariable(), ConstantVariable(NoneType: None), ConstantVariable(NoneType: None), ConstantVariable(NoneType: None), UnknownVariable(), ExceptionVariable(<class 'NotImplementedError'>), BuiltinVariable(NotImplementedError)]
V0423 13:34:49.293000 1586012 torch/_dynamo/symbolic_convert.py:1239] [0/0] [__trace_bytecode] TRACE POP_JUMP_IF_TRUE 52 [WithExitFunctionVariable(), ConstantVariable(NoneType: None), ConstantVariable(NoneType: None), ConstantVariable(NoneType: None), UnknownVariable(), ExceptionVariable(<class 'NotImplementedError'>), BuiltinVariable(NotImplementedError), None]
I0423 13:34:49.294000 1586012 torch/_dynamo/convert_frame.py:1121] [0/0] run_gc_after_compile: running gc
E
======================================================================
ERROR: test_autocast_with_exception (__main__.Repro)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "src/test.py", line 26, in test_autocast_with_exception
    out = f(inp)
  File "lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 655, in _fn
    return fn(*args, **kwargs)
  File "lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 1432, in __call__
    return self._torchdynamo_orig_callable(
  File "lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 1213, in __call__
    result = self._inner_convert(
  File "lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 598, in __call__
    return _compile(
  File "lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 1110, in _compile
    raise InternalTorchDynamoError(
  File "lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 1059, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "lib/python3.10/site-packages/torch/_utils_internal.py", line 97, in wrapper_function
    return function(*args, **kwargs)
  File "lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 761, in compile_inner
    return _compile_inner(code, one_graph, hooks, transform)
  File "lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 797, in _compile_inner
    out_code = transform_code_object(code, transform)
File "lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1422, in transform_code_object
transformations(instructions, code_options)
File "lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 257, in _fn
return fn(*args, **kwargs)
File "lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 715, in transform
tracer.run()
File "lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 3500, in run
super().run()
File "lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1337, in run
while self.step():
File "lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1246, in step
self.dispatch_table[inst.opcode](self, inst)
File "lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 698, in inner
if value.is_python_constant():
torch._dynamo.exc.InternalTorchDynamoError: AttributeError: 'NoneType' object has no attribute 'is_python_constant'
from user code:
File "src/test.py", line 20, in f
with torch.autocast(device_type="cpu", dtype=None):
This did not happen on 2.6, so it looks like a regression, or at least that commit uncovered an unhandled corner case.

Minimal repro (passes on 2.6, fails on 2.7):
```python
import torch
import unittest


class Boom(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        raise NotImplementedError("boom")

    @staticmethod
    def backward(ctx, grad_out):
        return grad_out


class Repro(unittest.TestCase):
    def test_autocast_with_exception(self):
        @torch.compile
        def f(x: torch.Tensor):
            try:
                with torch.autocast(device_type="cpu", dtype=None):
                    Boom.apply(x)
            except NotImplementedError:
                return x + 1

        inp = torch.ones(3)
        out = f(inp)
        self.assertTrue(torch.equal(out, inp + 1))


if __name__ == "__main__":
    unittest.main()
```
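For reference, the same control flow behaves as the test expects when it is not compiled. A minimal eager-mode sketch (illustrative only, not part of the original test file):

```python
import torch


class Boom(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        raise NotImplementedError("boom")

    @staticmethod
    def backward(ctx, grad_out):
        return grad_out


def f_eager(x: torch.Tensor):
    # Same body as f() in the repro, just without @torch.compile:
    # the NotImplementedError raised inside the autocast block is caught
    # and x + 1 is returned.
    try:
        with torch.autocast(device_type="cpu", dtype=None):
            Boom.apply(x)
    except NotImplementedError:
        return x + 1


print(f_eager(torch.ones(3)))  # tensor([2., 2., 2.])
```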
Root cause
- AutocastModeVariable.exit() returns raw Python None rather than a VariableTracker
- WithExitFunctionVariable.call_function() forwards that return value unchanged
- Dynamo assumes every element on its simulated stack is a VariableTracker, so it calls is_python_constant() on the raw None, which raises the AttributeError (see the minimal illustration below)
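The failing check is easy to see in isolation. A standalone sketch of the pattern, not PyTorch source (`FakeVariableTracker` is a hypothetical stand-in for Dynamo's VariableTracker):

```python
class FakeVariableTracker:
    # Stand-in for torch._dynamo's VariableTracker: every entry on the
    # simulated stack is expected to provide is_python_constant().
    def is_python_constant(self) -> bool:
        return True


value = FakeVariableTracker()
assert value.is_python_constant()  # fine: a real tracker answers the query

try:
    value = None  # what AutocastModeVariable.exit() effectively leaves behind
    value.is_python_constant()
except AttributeError as e:
    print(e)  # 'NoneType' object has no attribute 'is_python_constant'
```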
The relevant Dynamo code (abridged):

```python
class WithExitFunctionVariable(VariableTracker):
    ...

    def call_function(
        self,
        tx: "InstructionTranslator",
        args: "list[VariableTracker]",
        kwargs: "dict[str, VariableTracker]",
    ) -> "VariableTracker":
        assert not kwargs
        return self.ctx.exit(tx, *args)


class AutocastModeVariable(ContextWrappingVariable):
    ...

    def exit(self, tx: "InstructionTranslator", *args):
        self.state.cleanup_assert()
        tx.output.create_node(
            "call_function", torch.amp._exit_autocast, (self.state.proxy,), {}
        )
        # implicitly returns raw Python None
```
Fix: make every ContextWrappingVariable.exit() return a ConstantVariable wrapper, exactly as other context-manager variables already do.

```python
class WithExitFunctionVariable(VariableTracker):
    ...

    def call_function(
        self,
        tx: "InstructionTranslator",
        args: "list[VariableTracker]",
        kwargs: "dict[str, VariableTracker]",
    ) -> "VariableTracker":
        assert not kwargs
        return self.ctx.exit(tx, *args)


class AutocastModeVariable(ContextWrappingVariable):
    ...

    def exit(self, tx: "InstructionTranslator", *args):
        self.state.cleanup_assert()
        tx.output.create_node(
            "call_function", torch.amp._exit_autocast, (self.state.proxy,), {}
        )
        # NEW: wrap the None return value in a VariableTracker
        return variables.ConstantVariable.create(None)
```
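With this change, WITH_EXCEPT_START should see a ConstantVariable(NoneType: None) on the simulated stack instead of a raw None, so the is_python_constant() call in symbolic_convert.py no longer fails and the repro above is expected to pass again.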
Versions
PyTorch version: 2.7.0
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A
OS: Ubuntu 22.04 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: 14.0.5
CMake version: version 3.31.4
Libc version: glibc-2.35
Python version: 3.10.12 (main, Feb 4 2025, 14:57:36) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-5.15.0-134-generic-x86_64-with-glibc2.35
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 43 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 12
On-line CPU(s) list: 0-11
Vendor ID: GenuineIntel
Model name: Intel(R) Xeon(R) Gold 6132 CPU @ 2.60GHz
CPU family: 6
Model: 85
Thread(s) per core: 1
Core(s) per socket: 6
Socket(s): 2
Stepping: 0
BogoMIPS: 5187.81
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon nopl xtopology tsc_reliable nonstop_tsc cpuid tsc_known_freq pni pclmulqdq vmx ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single pti ssbd ibrs ibpb stibp tpr_shadow vnmi ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 invpcid avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xsaves arat pku ospke md_clear flush_l1d arch_capabilities
Virtualization: VT-x
Hypervisor vendor: VMware
Virtualization type: full
L1d cache: 384 KiB (12 instances)
L1i cache: 384 KiB (12 instances)
L2 cache: 12 MiB (12 instances)
L3 cache: 38.5 MiB (2 instances)
NUMA node(s): 1
NUMA node0 CPU(s): 0-11
Vulnerability Gather data sampling: Unknown: Dependent on hypervisor status
Vulnerability Itlb multihit: KVM: Mitigation: VMX disabled
Vulnerability L1tf: Mitigation; PTE Inversion; VMX flush not necessary, SMT disabled
Vulnerability Mds: Mitigation; Clear CPU buffers; SMT Host state unknown
Vulnerability Meltdown: Mitigation; PTI
Vulnerability Mmio stale data: Mitigation; Clear CPU buffers; SMT Host state unknown
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Mitigation; IBRS
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; IBRS; IBPB conditional; STIBP disabled; RSB filling; PBRSB-eIBRS Not affected; BHI SW loop, KVM SW loop
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Versions of relevant libraries:
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.26.4
[pip3] optree==0.14.0
[pip3] torch==2.7.0
[pip3] torch-debug==2.7.0
[pip3] torch_tb_profiler==0.4.0
[pip3] torchvision==0.21.0
[pip3] triton==3.1.0
[conda] Could not collect
cc @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @amjames