8000 [Dynamo][Trace PyDispatcher] Support calling id function over class (… · pytorch/pytorch@511d0dd · GitHub
[go: up one dir, main page]

Skip to content

Commit 511d0dd

Browse files
yanboliangpytorchmergebot
authored andcommitted
[Dynamo][Trace PyDispatcher] Support calling id function over class (#146269)
Pull Request resolved: #146269 Approved by: https://github.com/anijain2305
1 parent 02fd486 commit 511d0dd

File tree

2 files changed

+33
-2
lines changed

2 files changed

+33
-2
lines changed

test/dynamo/test_misc.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4723,6 +4723,35 @@ def fn_has_breaks(x):
47234723
opt_fn(x)
47244724
self.assertEqual(cnts.frame_count, 2)
47254725

4726+
def test_id_guarded_class(self):
4727+
class MyClass1:
4728+
pass
4729+
4730+
class MyClass2:
4731+
pass
4732+
4733+
def fn(x, y):
4734+
return x + id(y) // 100000
4735+
4736+
cnts = torch._dynamo.testing.CompileCounter()
4737+
compiled_fn = torch.compile(backend=cnts, fullgraph=True)(fn)
4738+
x = torch.randn(3)
4739+
y = MyClass1
4740+
self.assertEqual(fn(x, y), compiled_fn(x, y))
4741+
self.assertEqual(cnts.frame_count, 1)
4742+
4743+
# No recompile if still pass in the original class (MyClass1)
4744+
x = torch.randn(3)
4745+
y = MyClass1
4746+
self.assertEqual(fn(x, y), compiled_fn(x, y))
4747+
self.assertEqual(cnts.frame_count, 1)
4748+
4749+
# Have to recompile if pass in new class (MyClass2)
4750+
x = torch.randn(3)
4751+
y = MyClass2
4752+
self.assertEqual(fn(x, y), compiled_fn(x, y))
4753+
self.assertEqual(cnts.frame_count, 2)
4754+
47264755
def test_id_guarded_object(self):
47274756
class UDO:
47284757
@torch.compile(backend="eager")

torch/_dynamo/variables/builtin.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1992,9 +1992,11 @@ def call_id(self, tx: "InstructionTranslator", *args):
19921992
mod = tx.output.get_submodule(nn_mod_variable.module_key)
19931993
return variables.ConstantVariable.create(id(mod))
19941994
elif len(args) == 1 and isinstance(
1995-
args[0], variables.UserDefinedObjectVariable
1995+
args[0],
1996+
(variables.UserDefinedClassVariable, variables.UserDefinedObjectVariable),
19961997
):
1997-
install_guard(args[0].source.make_guard(GuardBuilder.ID_MATCH))
1998+
if args[0].source:
1999+
install_guard(args[0].source.make_guard(GuardBuilder.ID_MATCH))
19982000
constant_result = id(args[0].value)
19992001
return variables.ConstantVariable.create(constant_result)
20002002
elif len(args) == 1 and isinstance(args[0], TensorVariable):

0 commit comments

Comments
 (0)
0