Open
Description

The following code does not work:
import torch
from torch.fx.experimental.symbolic_shapes import sym_or

# Reproducer (reported as not working): branch directly on sym_or() of a
# size-based comparison and a data-dependent comparison on y.item().
# capture_scalar_outputs lets Dynamo trace .item() as a symbolic scalar rather
# than graph-breaking on it; with fullgraph=True a graph break is an error.
torch._dynamo.config.capture_scalar_outputs = True


@torch.compile(dynamic=True, fullgraph=True)
def func(x, y):
    if sym_or(x.size()[0] > 3, y.item() > 200):
        return x * 100
    else:
        return x * 200


func(torch.tensor([1, 2, 3, 4, 5]), torch.tensor([1]))
cc @chauhang @penguinwu @ezyang @bobrenjc93 @pianpwk
Notes:
- This works:
torch._dynamo.config.capture_scalar_outputs = True
@torch.compile(dynamic=True, fullgraph=True)
def func(x,y):
d = x.size()[0]>3 or y.item()>200
if d:
return x*100
else:
return x*200
func(torch.tensor([1,2,3,4,5]), torch.tensor([1]))
- This also does not work:
import torch
from torch.fx.experimental.symbolic_shapes import sym_or

# Variant reported as also not working: the sym_or() result is bound to a
# local first, then branched on. It is otherwise identical to the first repro.
# capture_scalar_outputs lets Dynamo trace .item() as a symbolic scalar rather
# than graph-breaking on it; with fullgraph=True a graph break is an error.
torch._dynamo.config.capture_scalar_outputs = True


@torch.compile(dynamic=True, fullgraph=True)
def func(x, y):
    d = sym_or(x.size()[0] > 3, y.item() > 200)
    if d:
        return x * 100
    else:
        return x * 200


func(torch.tensor([1, 2, 3, 4, 5]), torch.tensor([1]))