Labels: high priority, module: aotdispatch, module: inductor, module: pt2-dispatcher, oncall: pt2, triaged
Description
🐛 Describe the bug
symptom: This is an interesting edge case. When torch.clamp is given float bounds (min=-0.5, max=0.5) and an int64 input, eager mode implicitly promotes the output to float32, but inductor drops this promotion and still produces int64, resulting in silent incorrectness.
device backend: both CPP and Triton
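For reference, the eager behavior follows PyTorch's standard type-promotion rules: combining an integer tensor with a Python float scalar yields the default float dtype. A minimal sketch of the expected promotion (using torch.result_type to query the rule directly):

```python
import torch

# An int64 tensor combined with a Python float scalar promotes to the
# default float dtype (float32), which is why eager clamp returns float32.
x = torch.tensor(1)  # torch.int64
print(torch.result_type(x, 0.5))          # torch.float32
print(torch.clamp(x, min=-0.5, max=0.5))  # tensor(0.5000), dtype float32
```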
```python
import torch
from torch._inductor import config

config.fallback_random = True
torch.set_grad_enabled(False)
torch.manual_seed(0)

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        x = torch.clamp(x, min=-0.5, max=0.5)
        return x

model = Model()
x = torch.tensor(1)
print('input:')
print(x)
print(x.dtype)
inputs = [x]

def run_test(model, inputs, device, backend):
    torch.manual_seed(0)
    model = model.to(device)
    inputs = [x.to(device) for x in inputs]
    if backend != "eager":
        model = torch.compile(model, backend=backend)
    torch.manual_seed(0)
    output = model(*inputs)
    return output

device = 'cpu'
output = run_test(model, inputs, device, 'eager')
c_output = run_test(model, inputs, device, 'aot_eager_decomp_partition')
print("eager output:")
print(output)
print(output.dtype)
print("inductor output:")
print(c_output)
print(c_output.dtype)
```
Error logs
```
input:
tensor(1)
torch.int64
eager output:
tensor(0.5000)
torch.float32
inductor output:
tensor(0)
torch.int64
```
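Until the promotion is handled correctly in the compiled path, one workaround (not from the issue itself, and `clamp_f32` is a hypothetical helper name) is to make the cast explicit before clamping, so eager and compiled paths necessarily agree on the output dtype:

```python
import torch

def clamp_f32(t):
    # Cast explicitly so a compiled graph cannot keep the int64 dtype.
    return torch.clamp(t.to(torch.float32), min=-0.5, max=0.5)

x = torch.tensor(1)  # int64 input
print(clamp_f32(x), clamp_f32(x).dtype)  # tensor(0.5000) torch.float32
```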
Versions
nightly 20250414
cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @muchulee8 @amjames @aakhundov @bdhirsh