Using int(shape) in export would result in silent specialization #138853


Closed
henrylhtsang opened this issue Oct 24, 2024 · 5 comments
Assignees
pianpwk
Labels
oncall: export, oncall: pt2, triaged (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

@henrylhtsang (Contributor) commented Oct 24, 2024

🐛 Describe the bug

Hi team, just reporting this problem. I can bypass it by replacing int with math.trunc (a sketch of the bypass is included after the error output below).

repro:

import torch
import torch.nn.functional as F

class M(torch.nn.Module):
    def forward(self, x):
        # int() on a symbolic shape silently specializes the dimension
        ori_size = (
            int(x.shape[-2] / 1),
            int(x.shape[-1] / 1),
        )
        x = F.interpolate(x, size=ori_size, mode="bilinear")
        return x

input1 = (torch.rand(1, 3, 28, 28, device="cuda"),)
input2 = (torch.rand(1, 3, 56, 56, device="cuda"),)
inputs = [input1, input2]
model = M().cuda()

_ = model(*input1)

# mark the last two dimensions as dynamic
dynamic_shapes = {
    "x": {2: torch.export.Dim.DYNAMIC, 3: torch.export.Dim.DYNAMIC},
}
ep = torch.export.export(model, input1, dynamic_shapes=dynamic_shapes, strict=False)
path = torch._inductor.aot_compile(ep.module(), input1)
aot_model = torch._export.aot_load(path, device="cuda")
for input in inputs:
    torch.testing.assert_close(aot_model(*input), model(*input))

error:

torch/testing/_comparison.py", line 1530, in assert_close
    raise error_metas[0].to_error(msg)
AssertionError: The values for attribute 'shape' do not match: torch.Size([1, 3, 28, 28]) != torch.Size([1, 3, 56, 56]).
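
For reference, a sketch of the math.trunc bypass mentioned above (same repro, only the forward changes):

import math

class M(torch.nn.Module):
    def forward(self, x):
        # math.trunc goes through __trunc__ rather than __int__, so the size
        # stays symbolic under non-strict export instead of specializing to
        # the first traced shape (28)
        ori_size = (
            math.trunc(x.shape[-2] / 1),
            math.trunc(x.shape[-1] / 1),
        )
        x = F.interpolate(x, size=ori_size, mode="bilinear")
        return x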

Versions

trunk

cc @ezyang @chauhang @penguinwu @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @angelayi @suo @ydwu4

@ezyang (Contributor) commented Oct 27, 2024

In non-strict export this is unavoidable: Python's rules for the __int__ overload do not permit a non-int return value, which means specialization. If you use strict export, this is handled automatically.
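
For context, the Python-level constraint can be seen without torch at all; this is a minimal, torch-free illustration of why int() cannot hand back a symbolic value:

class NotQuiteInt:
    def __int__(self):
        # returning anything other than a real int is rejected by CPython,
        # which is why int(symbolic_value) must produce a concrete int
        return 3.5

int(NotQuiteInt())  # TypeError: __int__ returned non-int (type float)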

@henrylhtsang (Contributor, Author) commented:

> In non-strict export this is unavoidable: Python's rules for the __int__ overload do not permit a non-int return value, which means specialization. If you use strict export, this is handled automatically.

Is there a way to fail loudly or print a warning?

@ezyang (Contributor) commented Oct 29, 2024

@angelayi this is a funny opposite case to draft export where the size guards would have told you about it but they all got ignored :P

@desertfire added the triaged label (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module) on Oct 29, 2024
@pianpwk self-assigned this on Apr 22, 2025
@pianpwk (Contributor) commented Apr 22, 2025

One view of this is that symbolic_shapes is unable to realize that the guard emitted, Eq(TruncToInt(IntTrueDiv(s24, 1)), 28), is only true for a single value of s24 (28), and so this dimension should specialize and raise an error with Dim.DYNAMIC.

Alternatively, if we remove the division op, export raises a specialization error, recognizing that we've specialized.

On the other hand, if we change the denominator to something > 1, the guard is true for more than one value and we technically haven't specialized the symbol, but the problem remains that int() silently introduced this equality guard.

I think two things could be improved:

  • try to emit a warning in non-strict export when int() is used on a symbolic value, and maybe recommend sym_int (see the sketch below)
  • try to improve symbolic shapes reasoning so it recognizes this particular specialization and raises an error
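
For illustration, a minimal sketch of what the first suggestion would look like on the user side, assuming torch.sym_int acts as a symbolic-aware int() here (truncating toward zero while keeping the value symbolic rather than forcing a concrete one):

import torch
import torch.nn.functional as F

class M(torch.nn.Module):
    def forward(self, x):
        ori_size = (
            # torch.sym_int truncates like int(), but keeps the result symbolic
            # when the input is a SymFloat, so no equality guard is introduced
            torch.sym_int(x.shape[-2] / 1),
            torch.sym_int(x.shape[-1] / 1),
        )
        return F.interpolate(x, size=ori_size, mode="bilinear")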

@henrylhtsang (Contributor, Author) commented:

@pianpwk Thanks. I think it makes sense if this cannot be "fixed", but making debugging easier would already help.

btw noob question: does sym_int work similarly to math.trunc in a case like this?
