Improve reasoning for size oblivious equations involving min or max() · Issue #125914 · pytorch/pytorch · GitHub

@ezyang

Description

🐛 Describe the bug

Internal xref: https://fb.workplace.com/groups/6829516587176185/posts/7172183699576137/

This program fails to compile:

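import torch
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious

# Needed so that .tolist() yields unbacked SymInts (u0, u1) instead of a graph break.
torch._dynamo.config.capture_scalar_outputs = True

@torch.compile(fullgraph=True, backend="eager")
def cf(x):
    u0, u1 = x.tolist()
    torch._check_is_size(u0)
    torch._check_is_size(u1)
    torch._check(u0 + u1 == 20)
    # Size-oblivious: u0, u1 >= 2, so max(1, u0 + u1) should reduce to u0 + u1.
    if guard_size_oblivious(torch.sym_max(1, u0 + u1) == 20):
        return torch.tensor(True)
    else:
        return torch.tensor(False)

def test_symmax():
    assert cf(torch.tensor([10, 10])).item()

test_symmax()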

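We should be able to make the inference here. u0 and u1 are size-like, so in a size-oblivious guard we assume they are >= 2. That gives u0 + u1 >= 4 > 1, so the Max should evaporate (sym_max(1, u0 + u1) simplifies to u0 + u1), leaving u0 + u1 == 20, which the program has already asserted with torch._check. However, we are currently unable to perform this simplification.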
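The inference argued above is a bounds (value-range) argument. The sketch below illustrates it with plain Python interval arithmetic. It is not PyTorch's implementation (the symbolic-shapes machinery is sympy-based), and the names `Range`, `size_like` and `try_simplify_max` are hypothetical. It assumes a size-like unbacked integer is known to be >= 0 and that size-oblivious reasoning raises that lower bound to 2, as described in the issue.

```python
import math
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class Range:
    """Closed integer interval [lower, upper]; upper may be math.inf."""
    lower: float
    upper: float

    def __add__(self, other: "Range") -> "Range":
        # Interval addition: the bounds add component-wise.
        return Range(self.lower + other.lower, self.upper + other.upper)


def size_like(size_oblivious: bool) -> Range:
    # Hypothetical helper: a size-like unbacked int is known to be >= 0;
    # size-oblivious reasoning additionally assumes it is >= 2
    # (i.e. never the special sizes 0 or 1).
    return Range(2 if size_oblivious else 0, math.inf)


def try_simplify_max(c: int, expr: Range) -> Optional[str]:
    """Decide max(c, expr) from bounds alone.

    Returns "expr" if expr >= c always holds, "const" if expr <= c always
    holds, and None if the bounds cannot decide.
    """
    if expr.lower >= c:
        return "expr"
    if expr.upper <= c:
        return "const"
    return None


for oblivious in (False, True):
    u0, u1 = size_like(oblivious), size_like(oblivious)
    print(f"size_oblivious={oblivious}: max(1, u0 + u1) -> {try_simplify_max(1, u0 + u1)}")
```

With only the >= 0 bound the Max cannot be decided (u0 = u1 = 0 would give max(1, 0) = 1), so the sketch prints `None`. With the size-oblivious >= 2 bound, the lower bound of u0 + u1 is 4, so max(1, u0 + u1) collapses to u0 + u1 and the sketch prints `expr`.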

cc @gchanan @zou3519 @kadeng @msaroufim @chauhang @penguinwu @bobrenjc93 @bdhirsh @anijain2305 @lezcano

Versions

main