@@ -4473,9 +4473,9 @@ def test_mixed_types(self):
44734473
44744474
44754475class TestOpProfiles (TestCase ):
4476- def get_sample_op_profile (self ) -> dict [str , set [OpProfile ]]:
4476 + def get_sample_op_profile (self , opname ) -> dict [str , set [OpProfile ]]:
44774477 return {
4478- "mylib.foo.default" : {
4478+ opname : {
44794479 OpProfile (
44804480 args_profile = (
44814481 TensorMetadata (
@@ -4508,46 +4508,46 @@ def test_fake_registration(self):
45084508 t1 = fm .from_tensor (torch .ones (3 , 3 ))
45094509 t2 = fm .from_tensor (torch .ones (3 , 3 ))
45104510
4511- op_profiles = self .get_sample_op_profile ()
4511+ op_profiles = self .get_sample_op_profile ("mylib.foo2.default" )
45124512
45134513 with torch .library ._scoped_library ("mylib" , "FRAGMENT" ) as lib :
45144514 torch .library .define (
4515- "mylib::foo " ,
4515+ "mylib::foo2 " ,
45164516 "(Tensor a, Tensor b) -> Tensor" ,
45174517 tags = torch .Tag .pt2_compliant_tag ,
45184518 lib = lib ,
45194519 )
45204520
4521- @torch .library .impl ("mylib::foo " , "cpu" , lib = lib )
4521+ @torch .library .impl ("mylib::foo2 " , "cpu" , lib = lib )
45224522 def foo_impl (a , b ):
45234523 return a + b
45244524
45254525 with (
45264526 self .assertRaisesRegex (
45274527 torch ._subclasses .fake_tensor .UnsupportedOperatorException ,
4528- "mylib.foo .default" ,
4528+ "mylib.foo2 .default" ,
45294529 ),
45304530 fm ,
45314531 ):
4532- torch .ops .mylib .foo (t1 , t2 )
4532+ torch .ops .mylib .foo2 (t1 , t2 )
45334533
45344534 with (
45354535 torch ._library .fake_profile .unsafe_generate_fake_kernels (op_profiles ),
45364536 fm ,
45374537 ):
4538- torch .ops .mylib .foo (t1 , t2 )
4538+ torch .ops .mylib .foo2 (t1 , t2 )
45394539
4540- with self .assertRaisesRegex (MissingOpProfile , "mylib::foo " ):
4541- torch .ops .mylib .foo (torch .ones (3 , 3 , 3 ), torch .ones (3 , 3 , 3 ))
4540+ with self .assertRaisesRegex (MissingOpProfile , "mylib::foo2 " ):
4541+ torch .ops .mylib .foo2 (torch .ones (3 , 3 , 3 ), torch .ones (3 , 3 , 3 ))
45424542
45434543 with (
45444544 self .assertRaisesRegex (
45454545 torch ._subclasses .fake_tensor .UnsupportedOperatorException ,
4546- "mylib.foo .default" ,
4546+ "mylib.foo2 .default" ,
45474547 ),
45484548 fm ,
45494549 ):
4550- torch .ops .mylib .foo (t1 , t2 )
4550+ torch .ops .mylib .foo2 (t1 , t2 )
45514551
45524552 def test_duplicate_registration_impl (self ):
45534553 fm = torch ._subclasses .FakeTensorMode (
@@ -4556,33 +4556,33 @@ def test_duplicate_registration_impl(self):
45564556 t1 = fm .from_tensor (torch .ones (3 , 3 ))
45574557 t2 = fm .from_tensor (torch .ones (3 , 3 ))
45584558
4559- op_profiles = self .get_sample_op_profile ()
4559+ op_profiles = self .get_sample_op_profile ("mylib.foo3.default" )
45604560
45614561 with torch .library ._scoped_library ("mylib" , "FRAGMENT" ) as lib :
45624562 torch .library .define (
4563- "mylib::foo " ,
4563+ "mylib::foo3 " ,
45644564 "(Tensor a, Tensor b) -> Tensor" ,
45654565 tags = torch .Tag .pt2_compliant_tag ,
45664566 lib = lib ,
45674567 )
45684568
4569- @torch .library .impl ("mylib::foo " , "cpu" , lib = lib )
4570- def foo_impl (a , b ):
4569+ @torch .library .impl ("mylib::foo3 " , "cpu" , lib = lib )
4570+ def foo3_impl (a , b ):
45714571 return a + b
45724572
4573- @torch .library .register_fake ("mylib::foo " , lib = lib )
4574- def foo_impl_fake (a , b ):
4573+ @torch .library .register_fake ("mylib::foo3 " , lib = lib )
4574+ def foo3_impl_fake (a , b ):
45754575 return (a + b ).to (dtype = torch .bfloat16 )
45764576
45774577 with fm :
4578- self .assertEqual (torch .ops .mylib .foo (t1 , t2 ).dtype , torch .bfloat16 )
4578+ self .assertEqual (torch .ops .mylib .foo3 (t1 , t2 ).dtype , torch .bfloat16 )
45794579
45804580 with torch ._library .fake_profile .unsafe_generate_fake_kernels (op_profiles ):
45814581 with fm :
4582- self .assertEqual (torch .ops .mylib .foo (t1 , t2 ).dtype , torch .float32 )
4582+ self .assertEqual (torch .ops .mylib .foo3 (t1 , t2 ).dtype , torch .float32 )
45834583
45844584 with fm :
4585- self .assertEqual (torch .ops .mylib .foo (t1 , t2 ).dtype , torch .bfloat16 )
4585+ self .assertEqual (torch .ops .mylib .foo3 (t1 , t2 ).dtype , torch .bfloat16 )
45864586
45874587 def test_duplicate_registration_custom_op (self ):
45884588 fm = torch ._subclasses .FakeTensorMode (
@@ -4591,7 +4591,7 @@ def test_duplicate_registration_custom_op(self):
45914591 t1 = fm .from_tensor (torch .ones (3 , 3 ))
45924592 t2 = fm .from_tensor (torch .ones (3 , 3 ))
45934593
4594- op_profiles = self .get_sample_op_profile ()
4594+ op_profiles = self .get_sample_op_profile ("mylib.foo1.default" )
45954595
45964596 @torch .library .custom_op ("mylib::foo1" , mutates_args = ())
45974597 def foo_impl (a : torch .Tensor , b : torch .Tensor ) -> torch .Tensor :
@@ -4604,10 +4604,6 @@ def foo_impl_fake(a, b):
46044604 with fm :
46054605 self .assertEqual (torch .ops .mylib .foo1 (t1 , t2 ).dtype , torch .bfloat16 )
46064606
4607- op_profiles = {
4608- "mylib.foo1.default" : self .get_sample_op_profile ()["mylib.foo.default" ]
4609- }
4610-
46114607 with torch ._library .fake_profile .unsafe_generate_fake_kernels (op_profiles ):
46124608 with fm :
46134609 self .assertEqual (torch .ops .mylib .foo1 (t1 , t2 ).dtype , torch .float32 )
@@ -4616,14 +4612,14 @@ def foo_impl_fake(a, b):
46164612 self .assertEqual (torch .ops .mylib .foo1 (t1 , t2 ).dtype , torch .bfloat16 )
46174613
46184614 def test_yaml (self ):
4619- op_profiles = self .get_sample_op_profile ()
4615+ op_profiles = self .get_sample_op_profile ("mylib.foo.default" )
46204616 yaml_str = generate_yaml_from_profiles (op_profiles )
46214617 loaded = read_profiles_from_yaml (yaml_str )
46224618 self .assertEqual (op_profiles , loaded )
46234619
46244620 @unittest .skipIf (IS_WINDOWS , "Windows not supported for this test" )
46254621 def test_save_to_file (self ):
4626- op_profile = self .get_sample_op_profile ()
4622+ op_profile = self .get_sample_op_profile ("mylib.foo.default" )
46274623
46284624 # Saving with buffer
46294625 buffer = io .BytesIO ()
@@ -4647,7 +4643,7 @@ def test_save_to_file(self):
46474643 self .assertEqual (op_profile , loaded )
46484644
46494645 def test_version (self ):
4650- op_profiles = self .get_sample_op_profile ()
4646+ op_profiles = self .get_sample_op_profile ("mylib.foo.default" )
46514647 yaml_str = generate_yaml_from_profiles (op_profiles )
46524648 loaded = yaml .safe_load (yaml_str )
46534649 loaded ["torch_version" ] = "2.7"
0 commit comments