8000 Fix flaky test in test_custom_ops (#152484) · pytorch/pytorch@0d0058d · GitHub
[go: up one dir, main page]

Skip to content

Commit 0d0058d

Browse files
angelayipytorchmergebot
authored andcommitted
Fix flaky test in test_custom_ops (#152484)
Hopefully fixes #151301, #151281 by making the ops have different names Pull Request resolved: #152484 Approved by: https://github.com/zou3519
1 parent 80af98c commit 0d0058d

File tree

1 file changed

+25
-29
lines changed

1 file changed

+25
-29
lines changed

test/test_custom_ops.py

Lines changed: 25 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -4473,9 +4473,9 @@ def test_mixed_types(self):
44734473

44744474

44754475
class TestOpProfiles(TestCase):
4476-
def get_sample_op_profile(self) -> dict[str, set[OpProfile]]:
4476+
def get_sample_op_profile(self, opname) -> dict[str, set[OpProfile]]:
44774477
return {
4478-
"mylib.foo.default": {
4478+
opname: {
44794479
OpProfile(
44804480
args_profile=(
44814481
TensorMetadata(
@@ -4508,46 +4508,46 @@ def test_fake_registration(self):
45084508
t1 = fm.from_tensor(torch.ones(3, 3))
45094509
t2 = fm.from_tensor(torch.ones(3, 3))
45104510

4511-
op_profiles = self.get_sample_op_profile()
4511+
op_profiles = self.get_sample_op_profile("mylib.foo2.default")
45124512

45134513
with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
45144514
torch.library.define(
4515-
"mylib::foo",
4515+
"mylib::foo2",
45164516
"(Tensor a, Tensor b) -> Tensor",
45174517
tags=torch.Tag.pt2_compliant_tag,
45184518
lib=lib,
45194519
)
45204520

4521-
@torch.library.impl("mylib::foo", "cpu", lib=lib)
4521+
@torch.library.impl("mylib::foo2", "cpu", lib=lib)
45224522
def foo_impl(a, b):
45234523
return a + b
45244524

45254525
with (
45264526
self.assertRaisesRegex(
45274527
torch._subclasses.fake_tensor.UnsupportedOperatorException,
4528-
"mylib.foo.default",
4528+
"mylib.foo2.default",
45294529
),
45304530
fm,
45314531
):
4532-
torch.ops.mylib.foo(t1, t2)
4532+
torch.ops.mylib.foo2(t1, t2)
45334533

45344534
with (
45354535
torch._library.fake_profile.unsafe_generate_fake_kernels(op_profiles),
45364536
fm,
45374537
):
4538-
torch.ops.mylib.foo(t1, t2)
4538+
torch.ops.mylib.foo2(t1, t2)
45394539

4540-
with self.assertRaisesRegex(MissingOpProfile, "mylib::foo"):
4541-
torch.ops.mylib.foo(torch.ones(3, 3, 3), torch.ones(3, 3, 3))
4540+
with self.assertRaisesRegex(MissingOpProfile, "mylib::foo2"):
4541+
torch.ops.mylib.foo2(torch.ones(3, 3, 3), torch.ones(3, 3, 3))
45424542

45434543
with (
45444544
self.assertRaisesRegex(
45454545
torch._subclasses.fake_tensor.UnsupportedOperatorException,
4546-
"mylib.foo.default",
4546+
"mylib.foo2.default",
45474547
),
45484548
fm,
45494549
):
4550-
torch.ops.mylib.foo(t1, t2)
4550+
torch.ops.mylib.foo2(t1, t2)
45514551

45524552
def test_duplicate_registration_impl(self):
45534553
fm = torch._subclasses.FakeTensorMode(
@@ -4556,33 +4556,33 @@ def test_duplicate_registration_impl(self):
45564556
t1 = fm.from_tensor(torch.ones(3, 3))
45574557
t2 = fm.from_tensor(torch.ones(3, 3))
45584558

4559-
op_profiles = self.get_sample_op_profile()
4559+
op_profiles = self.get_sample_op_profile("mylib.foo3.default")
45604560

45614561
with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
45624562
torch.library.define(
4563-
"mylib::foo",
4563+
"mylib::foo3",
45644564
"(Tensor a, Tensor b) -> Tensor",
45654565
tags=torch.Tag.pt2_compliant_tag,
45664566
lib=lib,
45674567
)
45684568

4569-
@torch.library.impl("mylib::foo", "cpu", lib=lib)
4570-
def foo_impl(a, b):
4569+
@torch.library.impl("mylib::foo3", "cpu", lib=lib)
4570+
def foo3_impl(a, b):
45714571
return a + b
45724572

4573-
@torch.library.register_fake("mylib::foo", lib=lib)
4574-
def foo_impl_fake(a, b):
4573+
@torch.library.register_fake("mylib::foo3", lib=lib)
4574+
def foo3_impl_fake(a, b):
45754575
return (a + b).to(dtype=torch.bfloat16)
45764576

45774577
with fm:
4578-
self.assertEqual(torch.ops.mylib.foo(t1, t2).dtype, torch.bfloat16)
4578+
self.assertEqual(torch.ops.mylib.foo3(t1, t2).dtype, torch.bfloat16)
45794579

45804580
with torch._library.fake_profile.unsafe_generate_fake_kernels(op_profiles):
45814581
with fm:
4582-
self.assertEqual(torch.ops.mylib.foo(t1, t2).dtype, torch.float32)
4582+
self.assertEqual(torch.ops.mylib.foo3(t1, t2).dtype, torch.float32)
45834583

45844584
with fm:
4585-
self.assertEqual(torch.ops.mylib.foo(t1, t2).dtype, torch.bfloat16)
4585+
self.assertEqual(torch.ops.mylib.foo3(t1, t2).dtype, torch.bfloat16)
45864586

45874587
def test_duplicate_registration_custom_op(self):
45884588
fm = torch._subclasses.FakeTensorMode(
@@ -4591,7 +4591,7 @@ def test_duplicate_registration_custom_op(self):
45914591
t1 = fm.from_tensor(torch.ones(3, 3))
45924592
t2 = fm.from_tensor(torch.ones(3, 3))
45934593

4594-
op_profiles = self.get_sample_op_profile()
4594+
op_profiles = self.get_sample_op_profile("mylib.foo1.default")
45954595

45964596
@torch.library.custom_op("mylib::foo1", mutates_args=())
45974597
def foo_impl(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
@@ -4604,10 +4604,6 @@ def foo_impl_fake(a, b):
46044604
with fm:
46054605
self.assertEqual(torch.ops.mylib.foo1(t1, t2).dtype, torch.bfloat16)
46064606

4607-
op_profiles = {
4608-
"mylib.foo1.default": self.get_sample_op_profile()["mylib.foo.default"]
4609-
}
4610-
46114607
with torch._library.fake_profile.unsafe_generate_fake_kernels(op_profiles):
46124608
with fm:
46134609
self.assertEqual(torch.ops.mylib.foo1(t1, t2).dtype, torch.float32)
@@ -4616,14 +4612,14 @@ def foo_impl_fake(a, b):
46164612
self.assertEqual(torch.ops.mylib.foo1(t1, t2).dtype, torch.bfloat16)
46174613

46184614
def test_yaml(self):
4619-
op_profiles = self.get_sample_op_profile()
4615+
op_profiles = self.get_sample_op_profile("mylib.foo.default")
46204616
yaml_str = generate_yaml_from_profiles(op_profiles)
46214617
loaded = read_profiles_from_yaml(yaml_str)
46224618
self.assertEqual(op_profiles, loaded)
46234619

46244620
@unittest.skipIf(IS_WINDOWS, "Windows not supported for this test")
46254621
def test_save_to_file(self):
4626-
op_profile = self.get_sample_op_profile()
4622+
op_profile = self.get_sample_op_profile("mylib.foo.default")
46274623

46284624
# Saving with buffer
46294625
buffer = io.BytesIO()
@@ -4647,7 +4643,7 @@ def test_save_to_file(self):
46474643
self.assertEqual(op_profile, loaded)
46484644

46494645
def test_version(self):
4650-
op_profiles = self.get_sample_op_profile()
4646+
op_profiles = self.get_sample_op_profile("mylib.foo.default")
46514647
yaml_str = generate_yaml_from_profiles(op_profiles)
46524648
loaded = yaml.safe_load(yaml_str)
46534649
loaded["torch_version"] = "2.7"

0 commit comments

Comments
 (0)
0