8000 Revert "Make torch_geometric models compatible with export (#123403)" by chunyuan-w · Pull Request #128377 · pytorch/pytorch · GitHub
[go: up one dir, main page]

Skip to content

Revert "Make torch_geometric models compatible with export (#123403)" #128377

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 4 commits into from
Closed
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Update on "Revert "Make torch_geometric models compatible with export (
…#123403)""


This reverts commit d78991a.

This PR reverts #123403 to fix the performance regression as discussed in #127513 (comment).

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang

[ghstack-poisoned]
  • Loading branch information
chunyuan-w committed Jun 12, 2024
commit 640cf24e9e25a281125ee1ad8fb68e9c9e91b9c7
22 changes: 22 additions & 0 deletions benchmarks/dynamo/torchbench.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,20 @@
torch.backends.cuda.matmul.allow_tf32 = True


def _reassign_parameters(model):
# torch_geometric models register parameter as tensors due to
# https://github.com/pyg-team/pytorch_geometric/blob/master/torch_geometric/nn/dense/linear.py#L158-L168
# Since it is unusual thing to do, we just reassign them to parameters
def state_dict_hook(module, destination, prefix, local_metadata):
for name, param in module.named_parameters():
if isinstance(destination[name], torch.Tensor) and not isinstance(
destination[name], torch.nn.Parameter
):
destination[name] = torch.nn.Parameter(destination[name])

model._register_state_dict_hook(state_dict_hook)


def setup_torchbench_cwd():
original_dir = abspath(os.getcwd())

Expand Down Expand Up @@ -300,6 +314,14 @@ def load_model(
extra_args=extra_args,
)
model, example_inputs = benchmark.get_module()
if model_name in [
"basic_gnn_edgecnn",
"basic_gnn_gcn",
"basic_gnn_sage",
"basic_gnn_gin",
]:
_reassign_parameters(model)

# Models that must be in train mode while training
if is_training and (
not use_eval_mode or model_name in self._config["only_training"]
Expand Down
0