SIGSEGV due to insufficient return value checking for PyFrame_GetLocals #148273
Correction: After staring at the assembly a bit more, the problem seems to be that the lookup of "self" in a Python locals dictionary returns NULL, i.e. the code seems to assume that self is never NULL?
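At the Python level, the failure mode described above corresponds to reading `self` from a frame's locals without first checking that it is there. The following is a minimal sketch, not PyTorch's code: `lookup_self` is a hypothetical helper that treats a missing `self` (for example in a plain function frame) as "no receiver" instead of assuming it always exists.

```python
import sys
from types import FrameType
from typing import Any, Optional


def lookup_self(frame: FrameType) -> Optional[Any]:
    """Hypothetical helper: return the frame's `self` local, or None if absent.

    Only `in` and `[]` are used, so this works whether f_locals is a plain
    dict (before Python 3.13) or the mapping proxy returned from 3.13 on.
    """
    locals_ = frame.f_locals
    if "self" not in locals_:
        return None  # e.g. a plain function frame with no `self` local
    return locals_["self"]


class Model:
    def method(self):
        # Inspect this method's own frame, where `self` is a local.
        return lookup_self(sys._getframe())


def plain_function():
    # No `self` local here, so the lookup must not assume one exists.
    return lookup_self(sys._getframe())


model = Model()
print(model.method() is model)  # True
print(plain_function())         # None
```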
@thomasdullien Are you able to provide a runnable snippet to reproduce this error, or give more insight into the Python code that was running when it happened?
Will re-add high priority once a repro is available!
@thomasdullien would you be able to provide a stack trace by any chance?
Same issue here: I also get a SIGSEGV when profiling my PyTorch code. I suspect these lines are the cause: pytorch/torch/csrc/autograd/profiler_python.cpp Lines 885 to 886 in f7798d8
I'm running PyTorch 2.7.0 with Python 3.13.3.
I can reproduce the issue with this simple profiling script:

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torch.profiler import tensorboard_trace_handler

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Dataset
x = torch.linspace(0, 2 * 3.1415926, steps=1000).unsqueeze(1)  # [1000, 1]
y = torch.sin(x)
dataset = torch.utils.data.TensorDataset(x, y)
trainloader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)

# Model
model = nn.Linear(1, 1).to(device)

# Loss and optimizer
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

with torch.profiler.profile(
    schedule=torch.profiler.schedule(
        wait=2,
        warmup=2,
        active=6,
        repeat=1),
    on_trace_ready=tensorboard_trace_handler('./log/profiler'),
    with_stack=True
) as profiler:
    for step, data in enumerate(trainloader):
        inputs, labels = data[0].to(device), data[1].to(device)
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        profiler.step()
        print(f"step:{step}, loss:{loss.item():.4f}")
        if step >= 15:
            break

print("Profiling completed. Run 'tensorboard --logdir=./log/profiler' to view.")
```

With Python 3.12.10 this script runs fine; with Python 3.13.3 it crashes with SIGSEGV.
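The 3.12 vs 3.13 difference reported above lines up with one documented CPython change, offered here only as a possibly relevant lead that this thread does not confirm: since PEP 667 (Python 3.13), `frame.f_locals` and the C API function `PyFrame_GetLocals` return a write-through `FrameLocalsProxy` for function frames instead of a `dict`. The sketch below only reports which type the running interpreter hands back; `f_locals_type_name` is a hypothetical helper.

```python
import sys


def f_locals_type_name() -> str:
    """Hypothetical helper: name of the type f_locals has in a function frame."""
    frame = sys._getframe()  # this function's own (optimized) frame
    try:
        return type(frame.f_locals).__name__
    finally:
        del frame  # drop the frame reference promptly


if __name__ == "__main__":
    # On CPython: 'dict' before 3.13, 'FrameLocalsProxy' from 3.13 on.
    print(sys.version_info[:2], f_locals_type_name())
```

Native code that treats the result as a `dict` would need to handle it as a generic mapping on 3.13; whether that applies to `profiler_python.cpp` is not established here.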
🐛 Describe the bug
I'm getting a SIGSEGV when running some Torch code locally. It appears to be a null pointer dereference caused by insufficient return value checking of PyFrame_GetLocals: in more recent Python versions it can, in theory, return NULL, but all the code calling it assumes it returns a valid pointer and dereferences it unconditionally.
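The remedy implied here is to check every call that can come back empty before using its result. As a Python-level illustration only (the affected code is C++ against the CPython C API), the hypothetical helper `caller_locals_snapshot` below checks each step that can legitimately yield nothing: `inspect.currentframe()` is documented to return `None` on interpreters without Python stack frame support, and `f_back` is `None` for the outermost frame.

```python
import inspect
from typing import Any, Dict, Optional


def caller_locals_snapshot() -> Optional[Dict[str, Any]]:
    """Hypothetical helper: copy of the caller's locals, or None if unavailable.

    Each step that may produce "nothing" is checked before use, mirroring
    a NULL check on the result of PyFrame_GetLocals in C.
    """
    frame = inspect.currentframe()
    if frame is None:
        return None  # interpreter without Python stack frame support
    try:
        caller = frame.f_back
        if caller is None:
            return None  # no caller frame to inspect
        # dict(...) takes a plain snapshot, whatever mapping type f_locals is.
        return dict(caller.f_locals)
    finally:
        del frame  # avoid a reference cycle, as the inspect docs recommend


def train_step():
    step = 3
    return caller_locals_snapshot()


print(train_step())  # {'step': 3}
```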
Below is the GDB trace:
Versions
cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @albanD