8000 Revert "Make torch_geometric models compatible with export (#123403)" · pytorch/pytorch@67f846b · GitHub
[go: up one dir, main page]

Skip to content

Commit 67f846b

Browse files
committed
Revert "Make torch_geometric models compatible with export (#123403)"
This reverts commit d78991a. ghstack-source-id: 0ca28fb Pull Request resolved: #128377
1 parent 99f5a85 commit 67f846b

File tree

2 files changed

+5
-25
lines changed

2 files changed

+5
-25
lines changed

benchmarks/dynamo/common.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1184,12 +1184,14 @@ def load(cls, model, example_inputs, device):
11841184
else:
11851185
_register_dataclass_output_as_pytree(example_outputs)
11861186

1187-
gm = torch.export._trace._export(
1187+
# TODO(angelayi): change this to predispatch
1188+
# https://github.com/pytorch/pytorch/issues/127513 needs to be fixed before changing
1189+
# to predispatch to avoid performance regressions
1190+
gm = torch.export._trace._export_to_torch_ir(
11881191
model,
11891192
example_args,
11901193
example_kwargs,
1191-
pre_dispatch=True,
1192-
).module()
1194+
)
11931195
with torch.no_grad():
11941196
so_path = torch._inductor.aot_compile(
11951197
gm, example_args, example_kwargs

benchmarks/dynamo/torchbench.py

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -26,20 +26,6 @@
2626
torch.backends.cuda.matmul.allow_tf32 = True
2727

2828

29-
def _reassign_parameters(model):
30-
# torch_geometric models register parameter as tensors due to
31-
# https://github.com/pyg-team/pytorch_geometric/blob/master/torch_geometric/nn/dense/linear.py#L158-L168
32-
# Since it is unusual thing to do, we just reassign them to parameters
33-
def state_dict_hook(module, destination, prefix, local_metadata):
34-
for name, param in module.named_parameters():
35-
if isinstance(destination[name], torch.Tensor) and not isinstance(
36-
destination[name], torch.nn.Parameter
37-
):
38-
destination[name] = torch.nn.Parameter(destination[name])
39-
40-
model._register_state_dict_hook(state_dict_hook)
41-
42-
4329
def setup_torchbench_cwd():
4430
original_dir = abspath(os.getcwd())
4531

@@ -314,14 +300,6 @@ def load_model(
314300
extra_args=extra_args,
315301
)
316302
model, example_inputs = benchmark.get_module()
317-
if model_name in [
318-
"basic_gnn_edgecnn",
319-
"basic_gnn_gcn",
320-
"basic_gnn_sage",
321-
"basic_gnn_gin",
322-
]:
323-
_reassign_parameters(model)
324-
325303
# Models that must be in train mode while training
326304
if is_training and (
327305
not use_eval_mode or model_name in self._config["only_training"]

0 commit comments

Comments
 (0)
0