10000 Make torch_geometric models compatible with export (#123403) · pytorch/pytorch@d78991a · GitHub
[go: up one dir, main page]

Skip to content

Commit d78991a

Browse files
tugsbayasgalanpytorchmergebot
authored andcommitted
Make torch_geometric models compatible with export (#123403)
Pull Request resolved: #123403 Approved by: https://github.com/angelayi
1 parent cbde0f0 commit d78991a

File tree

2 files changed

+25
-3
lines changed

2 files changed

+25
-3
lines changed

benchmarks/dynamo/common.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1135,12 +1135,12 @@ def load(cls, model, example_inputs, device):
11351135
example_outputs = copy.deepcopy(model)(*example_args, **example_kwargs)
11361136
_register_dataclass_output_as_pytree(example_outputs)
11371137

1138-
# TODO(angelayi): change this to predispatch
1139-
gm = torch.export._trace._export_to_torch_ir(
1138+
gm = torch.export._trace._export(
11401139
model,
11411140
example_args,
11421141
example_kwargs,
1143-
)
1142+
pre_dispatch=True,
1143+
).module()
11441144
with torch.no_grad():
11451145
so_path = torch._inductor.aot_compile(
11461146
gm, example_args, example_kwargs

benchmarks/dynamo/torchbench.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,20 @@
2525
torch.backends.cuda.matmul.allow_tf32 = True
2626

2727

28+
def _reassign_parameters(model):
29+
# torch_geometric models register parameter as tensors due to
30+
# https://github.com/pyg-team/pytorch_geometric/blob/master/torch_geometric/nn/dense/linear.py#L158-L168
31+
# Since it is unusual thing to do, we just reassign them to parameters
32+
def state_dict_hook(module, destination, prefix, local_metadata):
33+
for name, param in module.named_parameters():
34+
if isinstance(destination[name], torch.Tensor) and not isinstance(
35+
destination[name], torch.nn.Parameter
36+
):
37+
destination[name] = torch.nn.Parameter(destination[name])
38+
39+
model._register_state_dict_hook(state_dict_hook)
40+
41+
2842
def setup_torchbench_cwd():
2943
original_dir = abspath(os.getcwd())
3044

@@ -265,6 +279,14 @@ def load_model(
265279
extra_args=extra_args,
266280
)
267281
model, example_inputs = benchmark.get_module()
282+
if model_name in [
283+
"basic_gnn_edgecnn",
284+
"basic_gnn_gcn",
285+
"basic_gnn_sage",
286+
"basic_gnn_gin",
287+
]:
288+
_reassign_parameters(model)
289+
268290
# Models that must be in train mode while training
269291
if is_training and (
270292
not use_eval_mode or model_name in self._config["only_training"]

0 commit comments

Comments
 (0)
0