### 🐛 Describe the bug
```python
import torch


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding = torch.nn.Embedding(num_embeddings=10, embedding_dim=8)
        # Assumed registrations: forward references self.buffer and self.param,
        # which are not defined anywhere else in the snippet.
        self.register_buffer("buffer", torch.randn(8))
        self.param = torch.nn.Parameter(torch.randn(8))

    def forward(self, x):
        token_ids = torch.randint(0, 10, (4,), device=x.device)
        embedded = self.embedding(token_ids).sum()
        return self.buffer.sum() + self.param.sum() + x.sum() + embedded


class BarModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.mod = Model()

    def forward(self, x):
        if "cuda" in str(x.device):
            mod = self.mod.to(x.device)
            return mod(x)
        else:
            return x.sum()


class BarBar(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.mod = BarModel()

    def forward(self, x):
        with torch.no_grad(), torch.amp.autocast(device_type="cuda"):
            y = self.mod(x)
        return y


with torch.no_grad():
    ep = torch.export.export(
        BarBar(), (), {"x": torch.randn(4, 4, 4, device="cuda")}, strict=False
    ).module()

print(ep.graph)
print(ep(x=torch.randn(4, 4, 4, device="cuda")))
```
This throws:

```
UnboundLocalError: cannot access local variable 'y' where it is not associated with a value
```
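For what it's worth, the `UnboundLocalError` itself looks like a masking artifact: if the real exception gets swallowed somewhere in the context-manager handling during tracing, `y` is never assigned before `return y`. A minimal, torch-free sketch of that plain Python behavior (assuming that is roughly what happens here):

```python
# A context manager whose __exit__ suppresses the body's exception.
class Swallow:
    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        return True  # suppresses any exception raised in the body


def f():
    with Swallow():
        y = 1 / 0  # raises, so y is never assigned; the error is suppressed
    return y       # UnboundLocalError: cannot access local variable 'y' ...


f()
```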
If I comment out the `torch.amp.autocast(device_type="cuda")` context, I see the real exception:

```
RuntimeError: Unhandled FakeTensor Device Propagation for aten.embedding.default, found two different devices cpu, cuda:0
```
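For reference, here is the variant I mean, with the autocast context removed so only `no_grad` remains (the `BarBarNoAutocast` name is just for this sketch; it reuses `BarModel` from the repro above):

```python
class BarBarNoAutocast(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.mod = BarModel()

    def forward(self, x):
        # Same as BarBar.forward, but without torch.amp.autocast
        with torch.no_grad():
            y = self.mod(x)
        return y


with torch.no_grad():
    torch.export.export(
        BarBarNoAutocast(), (), {"x": torch.randn(4, 4, 4, device="cuda")}, strict=False
    )
# RuntimeError: Unhandled FakeTensor Device Propagation for aten.embedding.default,
# found two different devices cpu, cuda:0
```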
### Versions

main
cc @mcarilli @ptrblck @leslie-fang-intel @jgong5 @chauhang @penguinwu @avikchaudhuri @gmagogsfm @zhxchen17 @angelayi @suo @ydwu4