SIGSEGV due to insufficient return value checking for PyFrame_GetLocals · Issue #148273 · pytorch/pytorch · GitHub

Open · thomasdullien opened this issue Mar 1, 2025 · 6 comments
Labels: high priority · module: crash · module: python frontend · needs reproduction · triage review · triaged

Comments

thomasdullien commented Mar 1, 2025

🐛 Describe the bug

I'm getting a SIGSEGV when running some PyTorch code locally. It appears to be a null-pointer dereference caused by insufficient checking of the return value of PyFrame_GetLocals: in recent Python versions it can, in theory, return NULL, but the code calling it assumes it always returns a valid pointer and dereferences it unconditionally.

Below is the GDB trace:

Starting program: /home/thomasdullien/python-env/pytorch/bin/python3 ./experiments2.py
[Thread debugging using libthread_db enabled]
Using host libthread_db library "/lib/x86_64-linux-gnu/libthread_db.so.1".
[New Thread 0x7fffd45ff6c0 (LWP 23487)]
[New Thread 0x7fffd3dfe6c0 (LWP 23488)]
[New Thread 0x7fffd35fd6c0 (LWP 23489)]
[New Thread 0x7fffd2dfc6c0 (LWP 23490)]
[New Thread 0x7fffd25fb6c0 (LWP 23491)]
[New Thread 0x7fffd1dfa6c0 (LWP 23492)]
[New Thread 0x7fffd15f96c0 (LWP 23493)]
[New Thread 0x7fffd0df86c0 (LWP 23494)]
[New Thread 0x7fffd05f76c0 (LWP 23495)]
[New Thread 0x7fffcfdf66c0 (LWP 23496)]
[New Thread 0x7fffcf5f56c0 (LWP 23497)]
[New Thread 0x7fffcedf46c0 (LWP 23498)]
[New Thread 0x7fffce5f36c0 (LWP 23499)]
[New Thread 0x7fffcddf26c0 (LWP 23500)]
[New Thread 0x7fffcd5f16c0 (LWP 23501)]
[New Thread 0x7ffef8d5e6c0 (LWP 23504)]
Quadro P2200
[New Thread 0x7ffef2fff6c0 (LWP 23505)]
[New Thread 0x7ffef27fe6c0 (LWP 23506)]
Epoch 1/5000
[New Thread 0x7ffeddbff6c0 (LWP 23508)]
[New Thread 0x7ffed9fff6c0 (LWP 23509)]
[New Thread 0x7ffed97fe6c0 (LWP 23510)]

Thread 1 "python3" received signal SIGSEGV, Segmentation fault.
0x00007fffc42bf856 in torch::profiler::impl::(anonymous namespace)::PythonTracer::recordPyCall(torch::profiler::impl::(anonymous namespace)::ThreadLocalResults&, _frame*, bool) () from /home/thomasdullien/python-env/pytorch/lib/python3.13/site-packages/torch/lib/libtorch_python.so
(gdb) x/20i $rip-0x20
   0x7fffc42bf836 <_ZN5torch8profiler4impl12_GLOBAL__N_112PythonTracer12recordPyCallERNS2_18ThreadLocalResultsEP6_frameb+1206>:	add    %al,(%rax)
   0x7fffc42bf838 <_ZN5torch8profiler4impl12_GLOBAL__N_112PythonTracer12recordPyCallERNS2_18ThreadLocalResultsEP6_frameb+1208>:	call   0x7fffc3ca30b0 <PyFrame_GetLocals@plt>
   0x7fffc42bf83d <_ZN5torch8profiler4impl12_GLOBAL__N_112PythonTracer12recordPyCallERNS2_18ThreadLocalResultsEP6_frameb+1213>:	mov    %rax,%rdi
   0x7fffc42bf840 <_ZN5torch8profiler4impl12_GLOBAL__N_112PythonTracer12recordPyCallERNS2_18ThreadLocalResultsEP6_frameb+1216>:	lea    0x68005e(%rip),%rsi        # 0x7fffc493f8a5
   0x7fffc42bf847 <_ZN5torch8profiler4impl12_GLOBAL__N_112PythonTracer12recordPyCallERNS2_18ThreadLocalResultsEP6_frameb+1223>:	mov    %rax,0x30(%rsp)
   0x7fffc42bf84c <_ZN5torch8profiler4impl12_GLOBAL__N_112PythonTracer12recordPyCallERNS2_18ThreadLocalResultsEP6_frameb+1228>:	call   0x7fffc3cb0cf0 <PyDict_GetItemString@plt>
   0x7fffc42bf851 <_ZN5torch8profiler4impl12_GLOBAL__N_112PythonTracer12recordPyCallERNS2_18ThreadLocalResultsEP6_frameb+1233>:	mov    %rax,0x38(%rsp)
=> 0x7fffc42bf856 <_ZN5torch8profiler4impl12_GLOBAL__N_112PythonTracer12recordPyCallERNS2_18ThreadLocalResultsEP6_frameb+1238>:	mov    (%rax),%edx
   0x7fffc42bf858 <_ZN5torch8profiler4impl12_GLOBAL__N_112PythonTracer12recordPyCallERNS2_18ThreadLocalResultsEP6_frameb+1240>:	add    $0x1,%edx
   0x7fffc42bf85b <_ZN5torch8profiler4impl12_GLOBAL__N_112PythonTracer12recordPyCallERNS2_18ThreadLocalResultsEP6_frameb+1243>:	je     0x7fffc42bf85f <_ZN5torch8profiler4impl12_GLOBAL__N_112PythonTracer12recordPyCallERNS2_18ThreadLocalResultsEP6_frameb+1247>
   0x7fffc42bf85d <_ZN5torch8profiler4impl12_GLOBAL__N_112PythonTracer12recordPyCallERNS2_18ThreadLocalResultsEP6_frameb+1245>:	mov    %edx,(%rax)
   0x7fffc42bf85f <_ZN5torch8profiler4impl12_GLOBAL__N_112PythonTracer12recordPyCallERNS2_18ThreadLocalResultsEP6_frameb+1247>:	mov    %r13,%rdi
   0x7fffc42bf862 <_ZN5torch8profiler4impl12_GLOBAL__N_112PythonTracer12recordPyCallERNS2_18ThreadLocalResultsEP6_frameb+1250>:	call   0x7fffc3cb1840 <PyFrame_GetBack@plt>
   0x7fffc42bf867 <_ZN5torch8profiler4impl12_GLOBAL__N_112PythonTracer12recordPyCallERNS2_18ThreadLocalResultsEP6_frameb+1255>:	mov    %rax,0x60(%rsp)
   0x7fffc42bf86c <_ZN5torch8profiler4impl12_GLOBAL__N_112PythonTracer12recordPyCallERNS2_18ThreadLocalResultsEP6_frameb+1260>:	mov    %rax,%rsi
   0x7fffc42bf86f <_ZN5torch8profiler4impl12_GLOBAL__N_112PythonTracer12recordPyCallERNS2_18ThreadLocalResultsEP6_frameb+1263>:	test   %rax,%rax
   0x7fffc42bf872 <_ZN5torch8profiler4impl12_GLOBAL__N_112PythonTracer12recordPyCallERNS2_18ThreadLocalResultsEP6_frameb+1266>:	je     0x7fffc42bfbd0 <_ZN5torch8profiler4impl12_GLOBAL__N_112PythonTracer12recordPyCallERNS2_18ThreadLocalResultsEP6_frameb+2128>
   0x7fffc42bf878 <_ZN5torch8profiler4impl12_GLOBAL__N_112PythonTracer12recordPyCallERNS2_18ThreadLocalResultsEP6_frameb+1272>:	mov    0x38(%rsp),%rax
   0x7fffc42bf87d <_ZN5torch8profiler4impl12_GLOBAL__N_112PythonTracer12recordPyCallERNS2_18ThreadLocalResultsEP6_frameb+1277>:	lea    0x98(%rsp),%rdi
   0x7fffc42bf885 <_ZN5torch8profiler4impl12_GLOBAL__N_112PythonTracer12recordPyCallERNS2_18ThreadLocalResultsEP6_frameb+1285>:	mov    %rax,0x90(%rsp)

Versions

Is debug build: False
CUDA used to build PyTorch: 12.4
ROCM used to build PyTorch: N/A

OS: Debian GNU/Linux trixie/sid (x86_64)
GCC version: (Debian 14.2.0-16) 14.2.0
Clang version: 19.1.7 (1+b1)
CMake version: version 3.31.5
Libc version: glibc-2.40

Python version: 3.13.2 (main, Feb  5 2025, 01:23:35) [GCC 14.2.0] (64-bit runtime)
Python platform: Linux-6.12.12-amd64-x86_64-with-glibc2.40
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: Quadro P2200
Nvidia driver version: 535.216.03
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                         x86_64
CPU op-mode(s):                       32-bit, 64-bit
Address sizes:                        39 bits physical, 48 bits virtual
Byte Order:                           Little Endian
CPU(s):                               16
On-line CPU(s) list:                  0-15
Vendor ID:                            GenuineIntel
Model name:                           Intel(R) Xeon(R) E-2286M  CPU @ 2.40GHz
CPU family:                           6
Model:                                158
Thread(s) per core:                   2
Core(s) per socket:                   8
Socket(s):                            1
Stepping:                             13
CPU(s) scaling MHz:                   58%
CPU max MHz:                          5000.0000
CPU min MHz:                          800.0000
BogoMIPS:                             4800.00
Flags:                                fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb ssbd ibrs ibpb stibp ibrs_enhanced tpr_shadow flexpriority ept vpid ept_ad fsgsbase tsc_adjust sgx bmi1 avx2 smep bmi2 erms invpcid mpx rdseed adx smap clflushopt intel_pt xsaveopt xsavec xgetbv1 xsaves dtherm ida arat pln pts hwp hwp_notify hwp_act_window hwp_epp vnmi sgx_lc md_clear flush_l1d arch_capabilities
Virtualization:                       VT-x
L1d cache:                            256 KiB (8 instances)
L1i cache:                            256 KiB (8 instances)
L2 cache:                             2 MiB (8 instances)
L3 cache:                             16 MiB (1 instance)
NUMA node(s):                         1
NUMA node0 CPU(s):                    0-15
Vulnerability Gather data sampling:   Mitigation; Microcode
Vulnerability Itlb multihit:          KVM: Mitigation: Split huge pages
Vulnerability L1tf:                   Not affected
Vulnerability Mds:                    Not affected
Vulnerability Meltdown:               Not affected
Vulnerability Mmio stale data:        Mitigation; Clear CPU buffers; SMT vulnerable
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed:               Mitigation; Enhanced IBRS
Vulnerability Spec rstack overflow:   Not affected
Vulnerability Spec store bypass:      Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1:             Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:             Mitigation; Enhanced / Automatic IBRS; IBPB conditional; RSB filling; PBRSB-eIBRS SW sequence; BHI SW loop, KVM SW loop
Vulnerability Srbds:                  Mitigation; Microcode
Vulnerability Tsx async abort:        Mitigation; TSX disabled

Versions of relevant libraries:
[pip3] numpy==2.2.3
[pip3] nvidia-cublas-cu12==12.4.5.8
[pip3] nvidia-cuda-cupti-cu12==12.4.127
[pip3] nvidia-cuda-nvrtc-cu12==12.4.127
[pip3] nvidia-cuda-runtime-cu12==12.4.127
[pip3] nvidia-cudnn-cu12==9.1.0.70
[pip3] nvidia-cufft-cu12==11.2.1.3
[pip3] nvidia-curand-cu12==10.3.5.147
[pip3] nvidia-cusolver-cu12==11.6.1.9
[pip3] nvidia-cusparse-cu12==12.3.1.170
[pip3] nvidia-cusparselt-cu12==0.6.2
[pip3] nvidia-nccl-cu12==2.21.5
[pip3] nvidia-nvjitlink-cu12==12.4.127
[pip3] nvidia-nvtx-cu12==12.4.127
[pip3] torch==2.6.0
[pip3] triton==3.2.0
[conda] Could not collect

cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @albanD

thomasdullien (Author) commented

Correction: after staring at the assembly a bit more, the problem seems to be that the lookup of "self" in the Python locals dictionary returns NULL, i.e. the code assumes that self is never NULL.

janeyx99 (Contributor) commented Mar 3, 2025

@thomasdullien Are you able to provide a runnable snippet to repro this error or give more insight into the python code that was called to lead to this?

janeyx99 added the high priority, module: crash, and module: python frontend labels Mar 3, 2025
malfet added the needs reproduction label Mar 3, 2025
janeyx99 added the triaged label and removed the high priority and triage review labels Mar 3, 2025
janeyx99 (Contributor) commented Mar 3, 2025

Will re-add high priority once a repro is available!

albanD (Collaborator) commented Apr 9, 2025

@thomasdullien would you be able to provide a stack trace by any chance?

rennsax commented May 14, 2025

Same issue here: I also get a SIGSEGV when profiling my PyTorch code. I suspect these lines cause the problem:

auto self = THPObjectPtr(PyDict_GetItemString(locals, "self"));
Py_INCREF(self.get());

I'm running PyTorch 2.7.0 with Python 3.13.3.
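A minimal pure-Python sketch of the guard those two lines omit: treat a failed lookup of `self` as a possible outcome instead of assuming it succeeded. `lookup_self`, `Module` and `free_function` are hypothetical names for illustration, not PyTorch code. In the C API, the analogous guard would be checking the `PyDict_GetItemString` result for NULL before calling `Py_INCREF`, or using `Py_XINCREF`, which does nothing when passed NULL.

```python
import sys
from types import FrameType
from typing import Any, Optional


def lookup_self(frame: FrameType) -> Optional[Any]:
    """Return the frame's local named "self", or None if there is none.

    Hypothetical helper: mirrors the check the tracer would need, i.e.
    handle a failed lookup instead of dereferencing the result blindly.
    """
    # Mapping-style indexing works whether f_locals is a plain dict
    # or a mapping-like proxy object.
    try:
        return frame.f_locals["self"]
    except KeyError:
        return None


class Module:
    def forward(self) -> Optional[Any]:
        # This method's frame has a local named `self`.
        return lookup_self(sys._getframe())


def free_function() -> Optional[Any]:
    # A plain function frame has no `self`; the lookup must not be trusted.
    return lookup_self(sys._getframe())


m = Module()
print(m.forward() is m)         # True
print(free_function() is None)  # True
```

Indexing with a `KeyError` fallback is used here instead of an `isinstance(..., dict)` check because it behaves the same whether the locals container is a dict or another mapping type.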

rennsax commented May 14, 2025

I can reproduce the issue with this simple profiling script:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from torch.profiler import tensorboard_trace_handler

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Dataset: y = sin(x) sampled on [0, 2*pi]
x = torch.linspace(0, 2 * 3.1415926, steps=1000).unsqueeze(1)  # [1000,1]
y = torch.sin(x)
dataset = TensorDataset(x, y)
trainloader = DataLoader(dataset, batch_size=32, shuffle=True)

# Model
model = nn.Linear(1, 1).to(device)

# Optimizer
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

with torch.profiler.profile(
    schedule=torch.profiler.schedule(
        wait=2,
        warmup=2,
        active=6,
        repeat=1),
    on_trace_ready=tensorboard_trace_handler('./log/profiler'),
    with_stack=True
) as profiler:
    for step, data in enumerate(trainloader):
        inputs, labels = data[0].to(device), data[1].to(device)

        outputs = model(inputs)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        profiler.step()
        print(f"step:{step}, loss:{loss.item():.4f}")

        if step >= 15:
            break

print("Profiling completed. Run 'tensorboard --logdir=./log/profiler' to view.")

With Python 3.12.10 this script runs fine; with Python 3.13.3 it crashes with SIGSEGV.
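One possible explanation for the 3.12/3.13 difference, offered as an assumption rather than something established in this thread: PEP 667 changed `frame.f_locals` (and the C API `PyFrame_GetLocals`) in Python 3.13 so that function frames return a write-through `FrameLocalsProxy` instead of a `dict`. In CPython, `PyDict_GetItemString` returns NULL without setting an exception when its first argument is not a dict, so a dict-only lookup of `"self"` would yield NULL even though `self` exists. The minimal sketch below only checks the Python-level type change; it does not run the profiler, and the function names are hypothetical.

```python
import sys


def locals_type_in_function() -> type:
    """Return the type of f_locals seen from inside an ordinary function frame."""
    return type(sys._getframe().f_locals)


def locals_has_self() -> bool:
    """Check that `self` is visible through f_locals, whatever its container type."""
    self = object()  # a local named `self`, as a method frame would have
    return "self" in sys._getframe().f_locals


t = locals_type_in_function()
# True on Python <= 3.12 (a dict snapshot); False on 3.13+ (a proxy type).
print(issubclass(t, dict))
print(t.__name__)
# The key is present either way; only the container type changed.
print(locals_has_self())  # True
```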
