Export + autocast is eating the exception · Issue #153202 · pytorch/pytorch · GitHub
Open
@tugsbayasgalan

Description

🐛 Describe the bug

import torch


class Model(torch.nn.Module):
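    def __init__(self):
        super().__init__()
        # forward() uses self.buffer and self.param, which the original snippet
        # never defined; these definitions are assumed placeholders.
        self.register_buffer("buffer", torch.ones(4))
        self.param = torch.nn.Parameter(torch.ones(4))
        self.embedding = torch.nn.Embedding(num_embeddings=10, embedding_dim=8)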

    def forward(self, x):
        token_ids = torch.randint(0, 10, (4,), device=x.device)
        embedded = self.embedding(token_ids).sum()
        return self.buffer.sum() + self.param.sum() + x.sum() + embedded


class BarModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.mod = Model()

    def forward(self, x):
        if "cuda" in str(x.device):
            mod = self.mod.to(x.device)
            return mod(x)
        else:
            return x.sum()


class BarBar(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.mod = BarModel()

    def forward(self, x):
        with torch.no_grad(), torch.amp.autocast(device_type="cuda"):
            y = self.mod(x)
        return y


with torch.no_grad():
    ep = torch.export.export(
        BarBar(), (), {"x": torch.randn(4, 4, 4, device="cuda")}, strict=False
    ).module()

    print(ep.graph)
    print(ep(x=torch.randn(4, 4, 4, device="cuda")))

The torch.export.export call throws:

UnboundLocalError: cannot access local variable 'y' where it is not associated with a value
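The UnboundLocalError is a symptom rather than the cause. Python raises it at `return y` only when the assignment `y = self.mod(x)` never completed, and execution can reach that `return` after a failed assignment only if a context manager's `__exit__` returned a truthy value, which tells Python to discard the in-flight exception. Since the issue reports that removing autocast surfaces the real error, the autocast context presumably does this during export. A minimal plain-Python sketch of the pattern follows; `SwallowingContext` is a hypothetical stand-in and is not PyTorch's autocast implementation.

```python
class SwallowingContext:
    """Hypothetical context manager whose __exit__ suppresses exceptions."""

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Returning a truthy value tells Python the exception was handled,
        # so it is silently discarded and execution continues after the block.
        return True


def forward():
    with SwallowingContext():
        y = 1 / 0  # raises ZeroDivisionError, so y is never bound
    return y  # reached only because the error was swallowed


try:
    forward()
except UnboundLocalError as err:
    # The real ZeroDivisionError is lost; only this secondary error remains.
    print(type(err).__name__, err)
```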

If I remove torch.amp.autocast(device_type="cuda") from the with statement, I see the real exception:

RuntimeError: Unhandled FakeTensor Device Propagation for aten.embedding.default, found two different devices cpu, cuda:0
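A regression check for this class of bug can be sketched as a helper that raises a private sentinel exception inside a context manager and reports whether it escapes. `propagates_exceptions` is a hypothetical helper, not a PyTorch API. It uses only the standard library: contextlib.nullcontext stands in for a well-behaved manager, and contextlib.suppress deliberately swallows, mirroring the reported symptom. In the reported case the check would presumably have to run while torch.export is tracing, which this sketch does not do.

```python
import contextlib
from typing import Callable


class _Sentinel(Exception):
    """Private exception type used only to probe a context manager."""


def propagates_exceptions(
    cm_factory: Callable[[], contextlib.AbstractContextManager],
) -> bool:
    """Return True if an exception raised inside the context escapes it."""
    try:
        with cm_factory():
            raise _Sentinel("raised inside the with body")
    except _Sentinel:
        return True  # the exception propagated, as callers expect
    return False  # __exit__ suppressed the exception


# A well-behaved context manager lets the exception through.
print(propagates_exceptions(contextlib.nullcontext))  # True
# contextlib.suppress swallows it, like the behavior reported in this issue.
print(propagates_exceptions(lambda: contextlib.suppress(Exception)))  # False
```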

Versions

main

cc @mcarilli @ptrblck @leslie-fang-intel @jgong5 @chauhang @penguinwu @avikchaudhuri @gmagogsfm @zhxchen17 @angelayi @suo @ydwu4
