Fix flaky test_inductor_multiple_specializations (#159264) · pytorch/pytorch@6de2413

Commit 6de2413

jansel authored and pytorchmergebot committed
Fix flaky test_inductor_multiple_specializations (#159264)
Summary: This test was using do_bench, so it was flaky because performance is non-deterministic.

Test Plan:
buck test 'fbcode//mode/opt' fbcode//caffe2/test/inductor:compile_subprocess -- --exact 'caffe2/test/inductor:compile_subprocess - test_inductor_multiple_specializations_cuda (caffe2.test.inductor.test_compile_subprocess.GPUTests)' --run-disabled

Rollback Plan:

Differential Revision: D79098692

Pull Request resolved: #159264
Approved by: https://github.com/jingsh
1 parent 27ae720 commit 6de2413
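A minimal sketch of the reasoning in the Summary (the function f, the shapes, and the device below are illustrative assumptions, not taken from the test): do_bench measures wall-clock runtime, which varies from run to run, so an inequality between two timings can fail at random; comparing the values a compiled function returns is deterministic.

import torch
from triton.testing import do_bench

@torch.compile
def f(x):
    return x @ x.t()

x = torch.randn(16, 1280, device="cuda", dtype=torch.bfloat16)

# Old pattern: do_bench returns a runtime in milliseconds, which is noisy,
# so asserting one timing >= another is non-deterministic (the flakiness).
t_a = do_bench(lambda: f(x))
t_b = do_bench(lambda: f(x))
# assert t_a >= t_b  # can fail at random

# New pattern: run the compiled function and compare the tensors it returns.
out_a = f(x)
out_b = f(x.clone())
torch.testing.assert_close(out_a, out_b)  # deterministic check on values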

File tree

1 file changed: +4 additions, −8 deletions

test/inductor/test_torchinductor.py

Lines changed: 4 additions & 8 deletions
@@ -10553,8 +10553,6 @@ def f(x):
         not IS_BIG_GPU, "Skipping triton backend only since not big GPU (not enough SM)"
     )
     def test_inductor_multiple_specializations(self):
-        from triton.testing import do_bench
-
         @torch.compile(
             options={
                 "max_autotune": True,
@@ -10569,7 +10567,7 @@ def inductor_matmul(a, b):
         m = 16
         k = 1280
         dynamic_a = torch.randn(m, k, device=GPU_TYPE, dtype=torch.bfloat16)
-        dynamic_specialized_a = torch.randn(m, k, device=GPU_TYPE, dtype=torch.bfloat16)
+        dynamic_specialized_a = dynamic_a.clone()
         b = torch.randn(k, m, device=GPU_TYPE, dtype=torch.bfloat16)
         torch._dynamo.decorators.mark_dynamic(
             dynamic_a,
@@ -10584,12 +10582,10 @@ def inductor_matmul(a, b):
             b,
             1,
         )
-        dynamic = do_bench(lambda: inductor_matmul(dynamic_a, b))
+        dynamic = inductor_matmul(dynamic_a, b)
         torch._dynamo.reset()
-        dynamic_specialized = do_bench(
-            lambda: inductor_matmul(dynamic_specialized_a, b)
-        )
-        self.assertGreaterEqual(dynamic, dynamic_specialized)
+        dynamic_specialized = inductor_matmul(dynamic_specialized_a, b)
+        self.assertEqual(dynamic, dynamic_specialized)
 
     @requires_gpu()
     def test_stride_preservation_with_stride_modifying_fx_pass(self):
