[inductor] [cuda] [fake tensor] `torch.ones(x.size(0))` becomes a fake tensor for `torch.diagonal_scatter` · Issue #151670 · pytorch/pytorch · GitHub
[inductor] [cuda] [fake tensor] torch.ones(x.size(0)) becomes a fake tensor for torch.diagonal_scatter #151670
@shaoyuyoung

🐛 Describe the bug

symptom: when compiled with the inductor backend, torch.diagonal_scatter fails with a FakeTensor device-propagation error, while eager mode succeeds.
device: CUDA only
repro:

import os

# Set before importing torch so the flag is visible when Dynamo's config is loaded.
os.environ['TORCHDYNAMO_VERBOSE'] = '1'

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch._inductor import config

config.fallback_random = True
torch.set_grad_enabled(False)


class Model(nn.Module):

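    def __init__(self):
        super().__init__()

    def forward(self, x):
        # No device is passed here, so y is allocated on the default device (CPU),
        # while x is on CUDA in the failing run.
        y = torch.ones(x.size(0))
        x = torch.diagonal_scatter(x, y)
        return x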


model = Model()

x = torch.rand(1, 2)

inputs = [x]


def run_test(model, inputs, device, backend):
    torch.manual_seed(0)
    model = model.to(device)
    inputs = [x.to(device) for x in inputs]
    if backend != "eager":
        model = torch.compile(model, backend=backend)
    try:
        output = model(*inputs)
        print(f"succeed on {backend}")
    except Exception as e:
        print(e)


run_test(model, inputs, 'cuda', 'eager')
run_test(model, inputs, 'cuda', 'inductor')

Error logs

succeed on eager
Dynamo failed to run FX node with fake tensors: call_function <built-in method diagonal_scatter of type object at 0x7f4ca6e5a2c0>(*(FakeTensor(..., device='cuda:0', size=(1, 2)), FakeTensor(..., size=(1,))), **{}): got RuntimeError('Unhandled FakeTensor Device Propagation for aten.diagonal_scatter.default, found two different devices cuda:0, cpu')
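
The error above names two devices, cuda:0 for the input and cpu for the second operand, which is consistent with `torch.ones(x.size(0))` being allocated on the default device. As a possible workaround sketch (an assumption on my part, not a fix confirmed in this issue), the source tensor can be created on the same device as `x`. The function name `forward_same_device` is hypothetical, and the snippet falls back to CPU when CUDA is unavailable.

```python
import torch


def forward_same_device(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical variant of Model.forward: allocate y on x's device and dtype
    # so both operands of diagonal_scatter live on the same device.
    y = torch.ones(x.size(0), device=x.device, dtype=x.dtype)
    return torch.diagonal_scatter(x, y)


# Use CUDA when present; otherwise run on CPU so the sketch still executes.
device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.rand(1, 2, device=device)
out = forward_same_device(x)
print(out.device, tuple(out.shape))
```

Only the eager path is exercised here. Wrapping the function with `torch.compile(forward_same_device, backend="inductor")` would be the natural next check, but whether that avoids the failure on the reported nightly has not been verified in this sketch, and it does not address the underlying mismatch between eager and compiled behavior.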

Versions

nightly 20250414

cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @chauhang @penguinwu @eellison @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @amjames @bdhirsh

Metadata

Labels

high priority
module: dynamo
module: fakeTensor
module: pt2-dispatcher: PT2 dispatcher-related issues (e.g., aotdispatch, functionalization, faketensor, custom-op,
oncall: pt2
triaged: This issue has been looked at a team member, and triaged and prioritized into an appropriate module
