8000 Removed custom compare function from test · pytorch/pytorch@0f9df95 · GitHub
[go: up one dir, main page]

Skip to content

Commit 0f9df95

Browse files
committed
Removed custom compare function from test
1 parent 1790bfd commit 0f9df95

File tree

1 file changed

+17
-10
lines changed

1 file changed

+17
-10
lines changed

test/inductor/test_torchinductor.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8349,9 +8349,23 @@ def fn(a, b):
83498349

83508350
compiled = torch.compile(fn)
83518351

8352-
def compare(a, b):
8353-
out_eager = fn(a, b)
8354-
out_inductor = compiled(a, b)
8352+
a = torch.randn([8, 8])
8353+
b = torch.randn([2, 8])
8354+
8355+
for dtype in (torch.int8, torch.float16, torch.int64, torch.bool):
8356+
out_eager = fn(a.to(dtype), b)
8357+
out_inductor = compiled(a.to(dtype), b)
8358+
self.assertEqual(
8359+
out_inductor.dtype,
8360+
out_eager.dtype,
8361+
f"Expected dtype {out_eager.dtype}, but got {out_inductor.dtype}",
8362+
)
8363+
self.assertTrue(
8364+
torch.allclose(out_inductor, out_eager),
8365+
f"Allclose failed for dtype {a.dtype}",
8366+
)
8367+
out_eager = fn(a, b.to(dtype))
8368+
out_inductor = compiled(a, b.to(dtype))
83558369
self.assertEqual(
83568370
out_inductor.dtype,
83578371
out_eager.dtype,
@@ -8361,13 +8375,6 @@ def compare(a, b):
83618375
torch.allclose(out_inductor, out_eager),
83628376
f"Allclose failed for dtype {a.dtype}",
83638377
)
8364-
8365-
a = torch.randn([8, 8])
8366-
b = torch.randn([2, 8])
8367-
8368-
for dtype in (torch.int8, torch.float16, torch.int64, torch.bool):
8369-
compare(a.to(dtype), b)
8370-
compare(a, b.to(dtype))
83718378

83728379
@with_tf32_off
83738380
def test_slice_scatter_reinplace(self):

0 commit comments

Comments
 (0)
0