[dynamo] torch._dynamo crashes on `self.value.__module__` inside SkipFunctionVariable.call_function() (PyTorch 2.7, works 2.6) · Issue #152316 · pytorch/pytorch · GitHub
Open
wdziurdz opened this issue Apr 28, 2025 · 5 comments
Labels: high priority, module: dynamo, module: regression, needs reproduction, oncall: pt2, triaged
Comments

@wdziurdz (Contributor) commented Apr 28, 2025

🐛 Describe the bug

After upgrading from PyTorch 2.6 to 2.7, we started hitting a crash in Dynamo. The crash happens because PyTorch does not check whether the object has a `__module__` attribute before reading it:

```
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1754, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1765, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:   File "/root/model_garden/PyTorch/examples/gpu_migration/nlp/bert/modeling.py", line 859, in forward
[rank1]:     tmp = (attention_mask == i+1).type(torch.float32).unsqueeze(-1)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/habana_frameworks/torch/gpu_migration/torch/_tensor.py", line 206, in type
[rank1]:     log_args = locals() if G_LOGGER.module_severity <= G_LOGGER.INFO else None
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 1432, in __call__
[rank1]:     return self._torchdynamo_orig_callable(
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 1213, in __call__
[rank1]:     result = self._inner_convert(
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 598, in __call__
[rank1]:     return _compile(
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 1110, in _compile
[rank1]:     raise InternalTorchDynamoError(
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 1059, in _compile
[rank1]:     guarded_code = compile_inner(code, one_graph, hooks, transform)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/_utils_internal.py", line 97, in wrapper_function
[rank1]:     return function(*args, **kwargs)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 761, in compile_inner
[rank1]:     return _compile_inner(code, one_graph, hooks, transform)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 797, in _compile_inner
[rank1]:     out_code = transform_code_object(code, transform)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/bytecode_transformation.py", line 1422, in transform_code_object
[rank1]:     transformations(instructions, code_options)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 257, in _fn
[rank1]:     return fn(*args, **kwargs)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 715, in transform
[rank1]:     tracer.run()
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 3500, in run
[rank1]:     super().run()
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 1337, in run
[rank1]:     while self.step():
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 1246, in step
[rank1]:     self.dispatch_table[inst.opcode](self, inst)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 819, in wrapper
[rank1]:     return inner_fn(self, inst)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 2266, in CALL_FUNCTION_EX
[rank1]:     self.call_function(fn, argsvars.items, kwargsvars)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 1170, in call_function
[rank1]:     self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/functions.py", line 926, in call_function
[rank1]:     return super().call_function(tx, args, kwargs)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/functions.py", line 404, in call_function
[rank1]:     return super().call_function(tx, args, kwargs)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/functions.py", line 185, in call_function
[rank1]:     return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 1187, in inline_user_function_return
[rank1]:     return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 3726, in inline_call
[rank1]:     return tracer.inline_call_()
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 3905, in inline_call_
[rank1]:     self.run()
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 1337, in run
[rank1]:     while self.step():
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 1246, in step
[rank1]:     self.dispatch_table[inst.opcode](self, inst)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 819, in wrapper
[rank1]:     return inner_fn(self, inst)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 2266, in CALL_FUNCTION_EX
[rank1]:     self.call_function(fn, argsvars.items, kwargsvars)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 1170, in call_function
[rank1]:     self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/functions.py", line 1224, in call_function
[rank1]:     if self.value.__module__ in known_python_builtin_modules:
[rank1]: torch._dynamo.exc.InternalTorchDynamoError: AttributeError: 'method_descriptor' object has no attribute '__module__'
```

The following snippet also shows that not every object has a `__module__` attribute: it raises an `AttributeError` because the method descriptor `torch.Tensor.type` lacks that attribute:

```python
import torch
tmp = torch.Tensor.type
print(tmp.__module__)
```
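The same behavior can be checked without PyTorch. This minimal sketch assumes `str.join` is the same kind of object as `torch.Tensor.type` (a C-level `method_descriptor`); the builtin function `len` is included for contrast, since builtin functions do carry a `__module__`.

```python
# str.join is a C-level method descriptor, the same kind of object as
# torch.Tensor.type in the report; len is a builtin function.
for obj in (str.join, len):
    kind = type(obj).__name__
    try:
        print(kind, "->", obj.__module__)
    except AttributeError as exc:
        # Method descriptors raise here instead of returning a module name.
        print(kind, "-> AttributeError:", exc)

# A guarded lookup never raises; it falls back to None instead.
print(getattr(str.join, "__module__", None))
```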

The problem is in `SkipFunctionVariable.call_function()`, where the code unconditionally accesses `self.value.__module__`.
Many built-in C descriptors (e.g. `method_descriptor`) do not define that attribute, so the lookup itself raises an `AttributeError`. The relevant source code is below:

```python
class SkipFunctionVariable(VariableTracker):
    ...
    def call_function(
        self,
        tx: "InstructionTranslator",
        args: "list[VariableTracker]",
        kwargs: "dict[str, VariableTracker]",
    ) -> "VariableTracker":
        ...
            except TypeError:
                known_python_builtin_modules = {"_abc", "_warnings"}
                if self.value.__module__ in known_python_builtin_modules:
                    explanation = (
                        f"Dynamo does not know how to trace the Python builtin "
                        f"`{self.value.__module__}.{qualname}`."
                    )
                    hints = [
                        "If you are attempting to call a logging function (e.g. `_warnings.warn`), "
                        "you can try adding it to `torch._dynamo.config.reorderable_logging_functions`.",
                        "Please file an issue on GitHub "
                        "so the PyTorch team can add support for it. ",
                    ]
                elif (
                    self.value.__module__ is not None
                    and self.value.__module__.startswith("optree")
                ):
                    explanation = f"Dynamo cannot trace optree C/C++ function {self.value.__module__}.{qualname}."
                    hints = [
                        " Consider using torch.utils._pytree - "
                        "https://github.com/pytorch/pytorch/blob/main/torch/utils/_pytree.py"
                    ]
                    # also warn on it because most users won't see the graph break message
                    torch._dynamo.utils.warn_once(explanation + "\n" + "\n".join(hints))
                else:
                    explanation = (
                        f"Dynamo does not know how to trace the builtin `{self.value.__module__}.{qualname}.` "
                        f"This function is either a Python builtin (e.g. _warnings.warn) "
                        f"or a third-party C/C++ Python extension (perhaps created with pybind)."
                    )
                    hints = [
                        "If it is a Python builtin, please file an issue on GitHub "
                        "so the PyTorch team can add support for it and see the next case for a workaround.",
                        "If it is a third-party C/C++ Python extension, please "
                        "either wrap it into a PyTorch-understood custom operator "
                        "(see https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html "
                        "for more details) or, if it is traceable, use "
                        "`torch.compiler.allow_in_graph`.",
                    ]
                    # also warn on it because most users won't see the graph break message
                    torch._dynamo.utils.warn_once(explanation + "\n" + "\n".join(hints))
            if qualname == "allow_in_graph":
                explanation = (
                    "Found an allow_in_graph decorator to a function which "
                    "is created inside the parent function that is getting "
                    "compiled. This is not supported for now."
                )
                hints = []
            reason = self.reason if self.reason else "<missing reason>"
            unimplemented_v2(
                gb_type="Attempted to call function marked as skipped",
                context=f"module: {self.value.__module__}, qualname: {qualname}, skip reason: {reason}",
                explanation=explanation,
                hints=hints,
            )
```

Suggested fix:

```diff
- if self.value.__module__ in known_python_builtin_modules:
+ module = getattr(self.value, "__module__", None)
+ if module in known_python_builtin_modules:
```

All subsequent uses of self.value.__module__ in this block should be replaced by module.

Versions

Collecting environment information...
PyTorch version: 2.7.0
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.5 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 3.22.1
Libc version: glibc-2.35

Python version: 3.10.12 (main, Feb 4 2025, 14:57:36) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-5.15.0-136-generic-x86_64-with-glibc2.35
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 46 bits physical, 57 bits virtual
Byte Order: Little Endian
CPU(s): 160
On-line CPU(s) list: 0-159
Vendor ID: GenuineIntel
Model name: Intel(R) Xeon(R) Platinum 8380 CPU @ 2.30GHz
CPU family: 6
Model: 106
Thread(s) per core: 2
Core(s) per socket: 40
Socket(s): 2
Stepping: 6
CPU max MHz: 3400.0000
CPU min MHz: 800.0000
BogoMIPS: 4600.00
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 invpcid_single intel_ppin ssbd mba ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb intel_pt avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local split_lock_detect wbnoinvd dtherm ida arat pln pts hwp hwp_act_window hwp_epp hwp_pkg_req avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg tme avx512_vpopcntdq la57 rdpid fsrm md_clear pconfig flush_l1d arch_capabilities
Virtualization: VT-x
L1d cache: 3.8 MiB (80 instances)
L1i cache: 2.5 MiB (80 instances)
L2 cache: 100 MiB (80 instances)
L3 cache: 120 MiB (2 instances)
NUMA node(s): 2
NUMA node0 CPU(s): 0-39,80-119
NUMA node1 CPU(s): 40-79,120-159
Vulnerability Gather data sampling: Mitigation; Microcode
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Mitigation; Clear CPU buffers; SMT vulnerable
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced / Automatic IBRS; IBPB conditional; RSB filling; PBRSB-eIBRS SW sequence; BHI SW loop, KVM SW loop
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] pytorch-lightning==2.5.1
[pip3] torch==2.7.0
[pip3] torch_tb_profiler==0.4.0
[pip3] torchaudio==2.7.0a0
[pip3] torchdata==0.11.0
[pip3] torchmetrics==1.7.0
[pip3] torchtext==0.18.0a0
[pip3] torchvision==0.22.0
[conda] Could not collect

cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @amjames

@zou3519 (Contributor) commented Apr 28, 2025

@wdziurdz thank you for the issue report. Do you have a script we could run that reproduces the problem?

@wdziurdz (Contributor, issue author)

@zou3519 It's more complicated; I can't reproduce it with a simple example. I noticed that `qualname` was fixed in a similar way, and I've submitted a fix here: #152320

@atalman added this to the 2.7.1 milestone Apr 28, 2025
@malfet added the module: regression and needs reproduction labels Apr 28, 2025
@malfet (Contributor) commented Apr 28, 2025

IMO we should modify our triage rules to see high-pri oncall:pt2 issues

@jerryzh168 added the triaged label Apr 28, 2025
@malfet added the oncall: pt2, high priority, needs reproduction, and module: regression labels and removed the high priority, needs reproduction, triaged, and module: regression labels Apr 28, 2025
@zou3519 added the triaged and triage review labels and removed the triage review label May 5, 2025
@zou3519 (Contributor) commented May 6, 2025

Animesh to figure out cherry-pick to 2.7.1

@StrongerXi (Contributor)

Looks like #151277 fixed this; we just need to cherry-pick it into 2.7.1?

7 participants