8000 [Dynamo][Trace PyDispatcher] Support calling id function over class by yanboliang · Pull Request #146269 · pytorch/pytorch · GitHub
[go: up one dir, main page]

Skip to content

[Dynamo][Trace PyDispatcher] Support calling id function over class #146269

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions test/dynamo/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -4723,6 +4723,35 @@ def fn_has_breaks(x):
opt_fn(x)
self.assertEqual(cnts.frame_count, 2)

def test_id_guarded_class(self):
class MyClass1:
pass

class MyClass2:
pass

def fn(x, y):
return x + id(y) // 100000

cnts = torch._dynamo.testing.CompileCounter()
compiled_fn = torch.compile(backend=cnts, fullgraph=True)(fn)
x = torch.randn(3)
y = MyClass1
self.assertEqual(fn(x, y), compiled_fn(x, y))
self.assertEqual(cnts.frame_count, 1)

# No recompile if still pass in the original class (MyClass1)
x = torch.randn(3)
y = MyClass1
self.assertEqual(fn(x, y), compiled_fn(x, y))
self.assertEqual(cnts.frame_count, 1)

# Have to recompile if pass in new class (MyClass2)
x = torch.randn(3)
y = MyClass2
self.assertEqual(fn(x, y), compiled_fn(x, y))
self.assertEqual(cnts.frame_count, 2)

def test_id_guarded_object(self):
class UDO:
@torch.compile(backend="eager")
Expand Down
6 changes: 4 additions & 2 deletions torch/_dynamo/variables/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -1989,9 +1989,11 @@ def call_id(self, tx: "InstructionTranslator", *args):
mod = tx.output.get_submodule(nn_mod_variable.module_key)
return variables.ConstantVariable.create(id(mod))
elif len(args) == 1 and isinstance(
args[0], variables.UserDefinedObjectVariable
args[0],
(variables.UserDefinedClassVariable, variables.UserDefinedObjectVariable),
):
install_guard(args[0].source.make_guard(GuardBuilder.ID_MATCH))
if args[0].source:
install_guard(args[0].source.make_guard(GuardBuilder.ID_MATCH))
constant_result = id(args[0].value)
return variables.ConstantVariable.create(constant_result)
elif len(args) == 1 and isinstance(args[0], TensorVariable):
Expand Down
Loading
0