8000 Export + autocast is eating the exception · Issue #153202 · pytorch/pytorch · GitHub

Open
tugsbayasgalan opened this issue May 8, 2025 · 0 comments

🐛 Describe the bug

import torch


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding = torch.nn.Embedding(num_embeddings=10, embedding_dim=8)
        # Assumed definitions: forward() references self.buffer and self.param,
        # which the original snippet never defines; register them here so the
        # repro runs and reaches the reported FakeTensor error.
        self.register_buffer("buffer", torch.randn(8))
        self.param = torch.nn.Parameter(torch.randn(8))

    def forward(self, x):
        token_ids = torch.randint(0, 10, (4,), device=x.device)
        embedded = self.embedding(token_ids).sum()
        return self.buffer.sum() + self.param.sum() + x.sum() + embedded


class BarModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.mod = Model()

    def forward(self, x):
        if "cuda" in str(x.device):
            mod = self.mod.to(x.device)
            return mod(x)
        else:
            return x.sum()


class BarBar(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.mod = BarModel()

    def forward(self, x):
        with torch.no_grad(), torch.amp.autocast(device_type="cuda"):
            y = self.mod(x)
        return y


with torch.no_grad():
    ep = torch.export.export(
        BarBar(), (), {"x": torch.randn(4, 4, 4, device="cuda")}, strict=False
    ).module()

    print(ep.graph)
    print(ep(x=torch.randn(4, 4, 4, device="cuda")))

This throws:

UnboundLocalError: cannot access local variable 'y' where it is not associated with a value

If I comment out torch.amp.autocast(device_type="cuda"), I see the real exception:

RuntimeError: Unhandled FakeTensor Device Propagation for aten.embedding.default, found two different devices cpu, cuda:0
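The `UnboundLocalError` is what you get whenever a context manager suppresses the exception raised in its body: the assignment never happens, yet control falls through to the `return`. A minimal sketch of that mechanism (this is a plain-Python illustration with a hypothetical `Swallow` manager, not a claim about what export or autocast do internally):

```python
class Swallow:
    """Hypothetical context manager that eats any exception in its body."""

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        # Returning a truthy value from __exit__ suppresses the exception.
        return True


def forward():
    with Swallow():
        y = 1 / 0  # the "real" exception is raised here and swallowed
    return y  # y was never bound, so this raises UnboundLocalError


try:
    forward()
except UnboundLocalError as e:
    print(e)
```

If something in the export + autocast path is suppressing (or replacing) the original `RuntimeError` the same way, the real FakeTensor device error gets masked by the unrelated `UnboundLocalError` on `y`.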

Versions

main

cc @mcarilli @ptrblck @leslie-fang-intel @jgong5 @chauhang @penguinwu @avikchaudhuri @gmagogsfm @zhxchen17 @angelayi @suo @ydwu4
