[dynamo] fix prim lowering validation logic for dynamic shape args by jon-chuang · Pull Request #111208 · pytorch/pytorch · GitHub

[dynamo] fix prim lowering validation logic for dynamic shape args #111208

Closed
Changes from 1 commit
fix
jon-chuang committed Oct 13, 2023
commit d7e204674dd03bff134f1a58a1b2bbffd10e9681
27 changes: 27 additions & 0 deletions test/dynamo/test_repros.py
@@ -3516,6 +3516,33 @@ def fn(x):
         x = torch.rand(4)
         self.assertTrue(same(fn(x), opt_fn(x)))
 
+    def test_add_sub_alpha_out(self):
+        inp = torch.randn(2, 3, 4)
+        other = 1
+        alpha = 2
+        for op in [torch.add, torch.sub]:
+            out = torch.zeros(2, 3, 4)
+            compile_out = torch.zeros(2, 3, 4)
+            op(inp, other, alpha=alpha, out=out)
+            compiled_fn = torch.compile(op, dynamic=True)
+            compiled_fn(inp, other, alpha=alpha, out=compile_out)
+            self.assertTrue(same(out, compile_out))
+
+    def test_addr_alpha_beta_out(self):
+        inp = torch.randn(2, 3)
+        vec1 = torch.randn(2)
+        vec2 = torch.randn(3)
+        alpha = 2
+        beta = 5
+
+        out = torch.zeros(2, 3)
+        compile_out = torch.zeros(2, 3)
+
+        torch.addr(inp, vec1, vec2, alpha=alpha, beta=beta, out=out)
+        compiled_fn = torch.compile(torch.addr, dynamic=True)
+        compiled_fn(inp, vec1, vec2, alpha=alpha, beta=beta, out=compile_out)
+        self.assertTrue(same(out, compile_out))
+
 
 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests
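For readers reproducing this outside the test suite, here is a minimal standalone sketch of the scenario the new tests exercise, assuming a PyTorch build with torch.compile available; with dynamic=True, scalar arguments such as alpha can reach the prim validation as symbolic values (e.g. torch.SymInt) rather than plain Python ints. The variable names are illustrative, not part of any API.

    import torch

    # Eager reference result for an out= variant of add with alpha.
    inp = torch.randn(2, 3, 4)
    out = torch.zeros(2, 3, 4)
    compile_out = torch.zeros(2, 3, 4)
    torch.add(inp, 1, alpha=2, out=out)

    # Compile the same op with dynamic shapes; before this fix the prim
    # type-promotion checks could reject the traced scalar arguments.
    compiled_add = torch.compile(torch.add, dynamic=True)
    compiled_add(inp, 1, alpha=2, out=compile_out)

    print(torch.allclose(out, compile_out))  # expected: True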
34 changes: 23 additions & 11 deletions torch/_prims_common/__init__.py
@@ -95,6 +95,17 @@ def same_shape(a: ShapeType, b: ShapeType, *, allow_rhs_unbacked=False) -> bool:
     return True
 
 
+def _maybe_get_pytype(t):
+    if t is torch.SymFloat:
+        return float
+    elif t is torch.SymInt:
+        return int
+    elif t is torch.SymBool:
+        return bool
+    else:
+        return t
+
+
 # TODO: look at using torch.testing.assert_close instead with an option
 # to just compare metadata
 def compare_tensor_meta(
@@ -1003,9 +1014,12 @@ def get_higher_type(a: type, b: type) -> type:
 
     The types are ordered bool -> int -> float -> complex.
     """
+    a, b = _maybe_get_pytype(a), _maybe_get_pytype(b)
     # Type checking
-    assert a in _ordered_types
-    assert b in _ordered_types
+    if a not in _ordered_types or b not in _ordered_types:
+        raise RuntimeError(
+            f"Expected builtin numeric types, found {type(a)}, {type(b)}"
+        )
 
     if a is b:
         return a
@@ -1104,17 +1118,15 @@ def is_weakly_lesser_type(a: type, b: type) -> bool:
 
     The comparison is determined by the following type ordering: bool, int, float, complex.
     """
-    ordered_types = (
-        bool,
-        int,
-        float,
-        complex,
-    )
-
-    assert a in ordered_types
-    assert b in ordered_types
+    a, b = _maybe_get_pytype(a), _maybe_get_pytype(b)
+
+    if a not in _ordered_types or b not in _ordered_types:
+        raise RuntimeError(
+            f"Expected builtin numeric types, found {type(a)}, {type(b)}"
+        )
 
-    for typ in ordered_types:
+    for typ in _ordered_types:
         if a == typ:
             return True
         if b == typ:

Author's review comment on the removed ordered_types tuple: duplicate global defined above
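As a quick standalone illustration of the change (a sketch restating the hunk above, not the library module itself), the new helper maps the symbolic scalar types introduced by dynamic shapes back to the builtin Python types that the bool -> int -> float -> complex ordering checks understand:

    import torch

    _ordered_types = (bool, int, float, complex)

    def _maybe_get_pytype(t):
        # Symbolic scalar types (seen under dynamic shapes) normalize to
        # the builtin types the promotion ordering is defined over.
        if t is torch.SymFloat:
            return float
        elif t is torch.SymInt:
            return int
        elif t is torch.SymBool:
            return bool
        return t

    # A SymInt argument now normalizes to int and passes the membership
    # check instead of failing the old assertion.
    a, b = _maybe_get_pytype(torch.SymInt), _maybe_get_pytype(float)
    print(a in _ordered_types and b in _ordered_types)  # True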