eval should handle (unhinted: (s77 > 3) | (u0 > 200)) when s77 has hint =5 · Issue #153227 · pytorch/pytorch
Open
@laithsakka

Description

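import torch
from torch.fx.experimental.symbolic_shapes import sym_or

torch._dynamo.config.capture_scalar_outputs = True

@torch.compile(dynamic=True, fullgraph=True)
def func(x, y):
    # x.size()[0] is the backed symbol s77 (hint 5); y.item() is the unbacked symbol u0 (no hint).
    if sym_or(x.size()[0] > 3, y.item() > 200):
        return x * 100
    else:
        return x * 200

func(torch.tensor([1, 2, 3, 4, 5]), torch.tensor([1]))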
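
To make the request in the title concrete, here is a minimal pure-Python sketch of the evaluation rule being asked for: an OR over operands that may have no hint should still resolve to True as soon as any hinted operand is True. The helper `or_with_unknown` and the use of `None` for "no hint" are hypothetical names for illustration only, not PyTorch APIs.

```python
from typing import Optional


def or_with_unknown(*operands: Optional[bool]) -> Optional[bool]:
    """Three-valued OR. None stands for an operand without a hint (hypothetical model)."""
    # Any operand known to be True decides the result, whatever the others are.
    if any(op is True for op in operands):
        return True
    # No operand is known True; if at least one is unknown, the result is unknown.
    if any(op is None for op in operands):
        return None
    # Every operand is known False.
    return False


s77_hint = 5              # hint for x.size()[0] in the repro
left = s77_hint > 3       # True, decided by the hint
right = None              # u0 > 200 has no hint

result = or_with_unknown(left, right)
print(result)
```

Under this model the expression from the title evaluates to True, because the hinted left operand already decides the outcome; only when the hinted operand is False does the missing hint for u0 matter.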

cc @chauhang @penguinwu @ezyang @bobrenjc93 @pianpwk

notes:

  1. This works:

import torch

torch._dynamo.config.capture_scalar_outputs = True

@torch.compile(dynamic=True, fullgraph=True)
def func(x, y):
    # Plain Python `or` instead of sym_or.
    d = x.size()[0] > 3 or y.item() > 200
    if d:
        return x * 100
    else:
        return x * 200

func(torch.tensor([1, 2, 3, 4, 5]), torch.tensor([1]))
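  2. This also does not work:

import torch
from torch.fx.experimental.symbolic_shapes import sym_or

torch._dynamo.config.capture_scalar_outputs = True

@torch.compile(dynamic=True, fullgraph=True)
def func(x, y):
    # sym_or result stored in a variable before branching on it.
    d = sym_or(x.size()[0] > 3, y.item() > 200)
    if d:
        return x * 100
    else:
        return x * 200

func(torch.tensor([1, 2, 3, 4, 5]), torch.tensor([1]))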
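
One plausible reading of why note 1 works while the `sym_or` variants do not (an assumption, not confirmed in this issue): Python's `or` short-circuits, so once the left operand is decided True the right operand is never evaluated, whereas a function call such as `sym_or(a, b)` receives both operands. The pure-Python sketch below only illustrates that language-level difference; `probe` and `eager_or` are hypothetical helpers, not PyTorch APIs.

```python
evaluated = []


def probe(name: str, value: bool) -> bool:
    """Record that an operand was evaluated, then return its value."""
    evaluated.append(name)
    return value


def eager_or(a: bool, b: bool) -> bool:
    """Stand-in for a function-style OR: both arguments are computed before the call."""
    return a or b


s77_hint = 5

# Short-circuit `or`: the right operand is skipped once the left one is True.
d_short = probe("s77 > 3", s77_hint > 3) or probe("u0 > 200", False)
short_circuit_trace = list(evaluated)

evaluated.clear()

# Function call: both operands are evaluated before eager_or runs.
d_eager = eager_or(probe("s77 > 3", s77_hint > 3), probe("u0 > 200", False))
eager_trace = list(evaluated)

print(short_circuit_trace)
print(eager_trace)
```

In the short-circuit case only the hinted comparison is ever touched, which would explain why the unhinted `u0 > 200` causes no trouble there.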

    Labels

    module: dynamic shapes, oncall: pt2, triaged