8000 Update on "Revert "Make torch_geometric models compatible with export… · pytorch/pytorch@640cf24 · GitHub
[go: up one dir, main page]

Skip to content

Commit 640cf24

Browse files
committed
Update on "Revert "Make torch_geometric models compatible with export (#123403)""
This reverts commit d78991a. This PR reverts #123403 to fix the performance regression as discussed in #127513 (comment). cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang [ghstack-poisoned]
1 parent 84d529d commit 640cf24

File tree

1 file changed

+22
-0
lines changed

1 file changed

+22
-0
lines changed

benchmarks/dynamo/torchbench.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,20 @@
2626
torch.backends.cuda.matmul.allow_tf32 = True
2727

2828

29+
def _reassign_parameters(model):
30+
# torch_geometric models register parameter as tensors due to
31+
# https://github.com/pyg-team/pytorch_geometric/blob/master/torch_geometric/nn/dense/linear.py#L158-L168
32+
# Since it is unusual thing to do, we just reassign them to parameters
33+
def state_dict_hook(module, destination, prefix, local_metadata):
34+
for name, param in module.named_parameters():
35+
if isinstance(destination[name], torch.Tensor) and not isinstance(
36+
destination[name], torch.nn.Parameter
37+
):
38+
destination[name] = torch.nn.Parameter(destination[name])
39+
40+
model._register_state_dict_hook(state_dict_hook)
41+
42+
2943
def setup_torchbench_cwd():
3044
original_dir = abspath(os.getcwd())
3145

@@ -300,6 +314,14 @@ def load_model(
300314
extra_args=extra_args,
301315
)
302316
model, example_inputs = benchmark.get_module()
317+
if model_name in [
318+
"basic_gnn_edgecnn",
319+
"basic_gnn_gcn",
320+
"basic_gnn_sage",
321+
"basic_gnn_gin",
322+
]:
323+
_reassign_parameters(model)
324+
303325
# Models that must be in train mode while training
304326
if is_training and (
305327
not use_eval_mode or model_name in self._config["only_training"]

0 commit comments

Comments
 (0)
0