8000 Update · pytorch/pytorch@b98222d · GitHub
[go: up one dir, main page]

Skip to content

Commit b98222d

Browse files
committed
Update
[ghstack-poisoned]
1 parent c655d97 commit b98222d

File tree

2 files changed

+33
-2
lines changed

2 files changed

+33
-2
lines changed

test/dynamo/test_misc.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4723,6 +4723,35 @@ def fn_has_breaks(x):
47234723
opt_fn(x)
47244724
self.assertEqual(cnts.frame_count, 2)
47254725

4726+
def test_id_guarded_class(self):
4727+
class MyClass1:
4728+
pass
4729+
4730+
class MyClass2:
4731+
pass
4732+
4733+
def fn(x, y):
4734+
return x + id(y) // 100000
4735+
4736+
cnts = torch._dynamo.testing.CompileCounter()
4737+
compiled_fn = torch.compile(backend=cnts, fullgraph=True)(fn)
4738+
x = torch.randn(3)
4739+
y = MyClass1
4740+
self.assertEqual(fn(x, y), compiled_fn(x, y))
4741+
self.assertEqual(cnts.frame_count, 1)
4742+
4743+
# No recompile if still pass in the original class (MyClass1)
4744+
x = torch.randn(3)
4745+
y = MyClass1
4746+
self.assertEqual(fn(x, y), compiled_fn(x, y))
4747+
self.assertEqual(cnts.frame_count, 1)
4748+
4749+
# Have to recompile if pass in new class (MyClass2)
4750+
x = torch.randn(3)
4751+
y = MyClass2
4752+
self.assertEqual(fn(x, y), compiled_fn(x, y))
4753+
self.assertEqual(cnts.frame_count, 2)
4754+
47264755
def test_id_guarded_object(self):
47274756
class UDO:
47284757
@torch.compile(backend="eager")

torch/_dynamo/variables/builtin.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1989,9 +1989,11 @@ def call_id(self, tx: "InstructionTranslator", *args):
19891989
mod = tx.output.get_submodule(nn_mod_variable.module_key)
19901990
return variables.ConstantVariable.create(id(mod))
19911991
elif len(args) == 1 and isinstance(
1992-
args[0], variables.UserDefinedObjectVariable
1992+
args[0],
1993+
(variables.UserDefinedClassVariable, variables.UserDefinedObjectVariable),
19931994
):
1994-
install_guard(args[0].source.make_guard(GuardBuilder.ID_MATCH))
1995+
if args[0].source:
1996+
install_guard(args[0].source.make_guard(GuardBuilder.ID_MATCH))
19951997
constant_result = id(args[0].value)
19961998
return variables.ConstantVariable.create(constant_result)
19971999
elif len(args) == 1 and isinstance(args[0], TensorVariable):

0 commit comments

Comments
 (0)
0