[dynamo] `aot_eager` can't process `try...except` when meeting `AttributeError` · Issue #153605 · pytorch/pytorch · GitHub

Open
shaoyuyoung opened this issue May 15, 2025 · 0 comments
Labels: `module: dynamo`, `oncall: pt2`, `triaged`

shaoyuyoung (Contributor) commented May 15, 2025

🐛 Describe the bug

Symptom: I know this is an unusual usage pattern. With `return xxx`, which raises a `NameError`, Dynamo correctly enters the `except` block. With `return F.xxx`, which raises an `AttributeError`, Dynamo throws the error instead of entering the `except` block, so the model compiled with the `aot_eager` backend fails while eager mode succeeds.
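The eager-mode behavior this report relies on can be checked in plain Python, without torch. This is a minimal sketch under one assumption: `fake_F` is a hypothetical empty module object standing in for `torch.nn.functional`, and `raised_type`, `forward_attr` and `forward_name` are illustrative helper names, not part of the repro. It shows that the two variants raise different exception types and that `except Exception` catches both, which is why eager mode falls back to `return x` in either case.

```python
import types

# Hypothetical stand-in for torch.nn.functional: an empty module object,
# so looking up a missing attribute on it raises AttributeError.
fake_F = types.ModuleType("fake_functional")


def raised_type(thunk):
    """Return the name of the exception thunk() raises, or None if it returns."""
    try:
        thunk()
    except Exception as e:
        return type(e).__name__
    return None


def forward_attr(x):
    try:
        return fake_F.xxx  # missing module attribute -> AttributeError
    except Exception:
        return x  # fallback branch, as in the eager run of the repro


def forward_name(x):
    try:
        return xxx  # undefined global name -> NameError
    except Exception:
        return x


print(raised_type(lambda: fake_F.xxx))  # AttributeError
print(raised_type(lambda: xxx))         # NameError
print(forward_attr(1.0), forward_name(1.0))  # 1.0 1.0
# Both are Exception subclasses, so `except Exception` catches either one.
print(issubclass(AttributeError, Exception), issubclass(NameError, Exception))  # True True
```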

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch._inductor import config

config.fallback_random = True
torch.set_grad_enabled(False)


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        try:
            return F.xxx  # with `return xxx` (a NameError) instead, aot_eager succeeds
        except Exception as e:
            return x


model = Model()


x = torch.randn(1)

inputs = [x]


def run_test(model, inputs, backend):
    torch.manual_seed(0)
    if backend != "eager":
        model = torch.compile(model, backend=backend)
    try:
        output = model(*inputs)
        print(f"succeed on {backend}")
    except Exception as e:
        print(e)


run_test(model, inputs, 'eager')
run_test(model, inputs, 'aot_eager')
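`run_test` above only prints each result, so the eager/compiled mismatch has to be spotted by eye. A variant that records each backend's outcome makes the mismatch assertable. This is a sketch: `collect_outcomes` and `backends_agree` are hypothetical helpers, and the demo uses plain-Python stand-ins (`eager_like`, `failing_like`) that mimic the reported outcomes so it runs without torch.

```python
from typing import Any, Callable, Dict, Sequence


def collect_outcomes(fns: Dict[str, Callable[..., Any]], inputs: Sequence[Any]) -> Dict[str, str]:
    """Call each function on the same inputs; record "ok" or the exception's type name."""
    outcomes: Dict[str, str] = {}
    for name, fn in fns.items():
        try:
            fn(*inputs)
            outcomes[name] = "ok"
        except Exception as e:
            outcomes[name] = type(e).__name__
    return outcomes


def backends_agree(outcomes: Dict[str, str]) -> bool:
    """True when every entry ended the same way (all "ok", or all the same error)."""
    return len(set(outcomes.values())) <= 1


# Hypothetical plain-Python stand-ins so the sketch runs without torch.
def eager_like(x):
    return x


def failing_like(x):
    raise AttributeError("module has no attribute 'xxx'")


outcomes = collect_outcomes({"eager": eager_like, "aot_eager": failing_like}, [1.0])
print(outcomes)                  # {'eager': 'ok', 'aot_eager': 'AttributeError'}
print(backends_agree(outcomes))  # False
```

With torch installed, the same helpers could be pointed at the repro, e.g. `collect_outcomes({"eager": model, "aot_eager": torch.compile(model, backend="aot_eager")}, inputs)`. Because `torch.compile` compiles lazily on the first call, an error raised during compilation surfaces inside `fn(*inputs)` and is recorded like any other exception.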

Error logs

eager:

succeed on eager

aot_eager (with `return F.xxx`):

AttributeError: module 'torch.nn.functional' has no attribute 'xxx'

aot_eager (with `return xxx`):

succeed on aot_eager
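Dynamo works by interpreting CPython bytecode, so it may help to see that the two variants fail at different instructions: `xxx` is a global-name lookup, while `F.xxx` is an attribute lookup on a module object. The sketch uses only the stdlib `dis` module; `F` here is a hypothetical empty module standing in for `torch.nn.functional`, and `load_ops` is an illustrative helper. Opcode names vary across Python versions, and this does not by itself show where in Dynamo the two paths diverge.

```python
import dis
import types

# Hypothetical stand-in module named F; not the real torch.nn.functional.
F = types.ModuleType("F")


def via_attr(x):
    try:
        return F.xxx  # attribute lookup on a module object
    except Exception:
        return x


def via_name(x):
    try:
        return xxx  # global-name lookup
    except Exception:
        return x


def load_ops(fn, names):
    """(opname, argval) pairs for instructions whose argument is one of `names`."""
    return [
        (ins.opname, ins.argval)
        for ins in dis.get_instructions(fn)
        if isinstance(ins.argval, str) and ins.argval in names
    ]


print(load_ops(via_attr, {"F", "xxx"}))
print(load_ops(via_name, {"F", "xxx"}))
```

On recent CPython versions the `F.xxx` variant typically lists a `LOAD_GLOBAL` for `F` followed by a `LOAD_ATTR` for `xxx`, while the `xxx` variant lists a single `LOAD_GLOBAL`.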

Versions

nightly20250515

cc @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @amjames

williamwen42 self-assigned this May 15, 2025
williamwen42 added the `triaged` label May 15, 2025
williamwen42 added a commit that referenced this issue May 16, 2025
…ors"


Fixes #153605

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames

[ghstack-poisoned]