8000 [ONNX] Add support for torch.cond/HOP in onnx exporter (#137428) · pytorch/pytorch@0a4bcbf · GitHub
[go: up one dir, main page]

Skip to content

Commit 0a4bcbf

Browse files
xadupre authored and justinchuby committed
[ONNX] Add support for torch.cond/HOP in onnx exporter (#137428)
This PR implements the framework for supporting HOP in the ONNX exporter. Refer to #140995 for the design. - Implement support for torch.cond - Refactor `_add_nodes` into `_translate_fx_graph` to handle nested subgraphs. To support building subgraphs as functions using the same logic, new handlers for `placeholder` and `output` nodes are added to register inputs and outputs on the onnx function. - Functions are created under the domain of `pkg.torch.__subgraph__` - Updated the type promotion pass to run on nested subgraphs. - Implement torch.cond in `_torchlib/ops/hop.py`. Updated the registry to discover these ops. - Improve opset_import handling robustness with `add_opset_imports` IR pass. To achieve this, we added opset version to all Nodes. Fixes #139503 Fixes #117655 Fixes #123972 Fixes #93743 Closes #140995 Pull Request resolved: #137428 Approved by: https://github.com/justinchuby Co-authored-by: Justin Chu <justinchuby@users.noreply.github.com>
1 parent e0482fd commit 0a4bcbf

File tree

11 files changed

+562
-58
lines changed

11 files changed

+562
-58
lines changed

test/onnx/exporter/test_small_models_e2e.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,75 @@ def forward(self, x):
5858
onnx_testing.assert_onnx_program(onnx_program)
5959
self.assertNotIn("Cast", [node.op_type for node in onnx_program.model.graph])
6060

61+
def test_onnx_export_control_flow(self):
62+
class CondModel(torch.nn.Module):
63+
def forward(self, x):
64+
def true_fn(x):
65+
return x + 1.0
66+
67+
def false_fn(x):
68+
return x - 42.0
69+
70+
y = torch.cond(x.sum() > 0, true_fn, false_fn, [x])
71+
return y
72+
73+
onnx_program = torch.onnx.export(
74+
CondModel(),
75+
(torch.tensor([1, 2]),),
76+
dynamo=True,
77+
fallback=False,
78+
)
79+
onnx_model = onnx_program.model
80+
self.assertIn("If", [node.op_type for node in onnx_model.graph])
81+
onnx_testing.assert_onnx_program(onnx_program)
82+
# Test different branches
83+
onnx_testing.assert_onnx_program(onnx_program, args=(torch.tensor([-1, -2]),))
84+
85+
def test_onnx_export_nested_control_flow_and_nested_weights(self):
86+
class Submodule(torch.nn.Module):
87+
def __init__(self):
88+
super().__init__()
89+
# Nested weight
90+
self.weight = torch.nn.Parameter(torch.tensor([100.0]))
91+
92+
def forward(self, x):
93+
def true_fn(x):
94+
return x * self.weight
95+
96+
def false_fn(x):
97+
return x / self.weight
98+
99+
y = torch.cond(x.sum() <= 0, true_fn, false_fn, [x])
100+
return y
101+
102+
class CondModel(torch.nn.Module):
103+
def __init__(self):
104+
super().__init__()
105+
self.submodule = Submodule()
106+
self.weight = torch.nn.Parameter(torch.tensor([42.0]))
107+
108+
def forward(self, x):
109+
def true_fn(x):
110+
return self.submodule(x - self.weight)
111+
112+
def false_fn(x):
113+
return x - self.weight
114+
115+
y = torch.cond(x.sum() > 0, true_fn, false_fn, [x])
116+
return y
117+
118+
onnx_program = torch.onnx.export(
119+
CondModel(),
120+
(torch.tensor([1, 2]),),
121+
dynamo=True,
122+
fallback=False,
123+
)
124+
onnx_testing.assert_onnx_program(onnx_program)
125+
onnx_testing.assert_onnx_program(onnx_program, args=(torch.tensor([0, 0]),))
126+
onnx_testing.assert_onnx_program(onnx_program, args=(torch.tensor([43, 43]),))
127+
128+
# TODO(justinchuby): Test multi-output HOPs
129+
61130

62131
if __name__ == "__main__":
63132
common_utils.run_tests()

torch/onnx/_internal/exporter/_building.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -492,6 +492,7 @@ def _construct_node(
492492
inputs=inputs,
493493
attributes=attributes,
494494
outputs=outputs,
495+
version=signature.opset_version,
495496
)
496497

497498

@@ -503,7 +504,9 @@ def __init__(
503504
):
504505
self.nodes: list[ir.Node] = []
505506
self.opset = opset
506-
self.functions: dict[ir.OperatorIdentifier, onnxscript.OnnxFunction] = {}
507+
self.functions: dict[
508+
ir.OperatorIdentifier, onnxscript.OnnxFunction | ir.Function
509+
] = {}
507510
self.constant_farm = constant_farm
508511

509512
def _call_op(
@@ -644,7 +647,10 @@ def eval_function( # type: ignore[override]
644647
op_signature = function.signature
645648
else:
646649
op_signature = _schemas.OpSignature.from_function(
647-
function, function.function_ir.domain, function.name
650+
function,
651+
function.function_ir.domain,
652+
function.name,
653+
opset_version=function.opset.version,
648654
)
649655

650656
named_inputs, named_attrs = _construct_named_inputs_and_attrs(
@@ -683,7 +689,11 @@ def eval_function( # type: ignore[override]
683689

684690
return function.function(**converted_named_inputs, **named_attrs)
685691

686-
outputs = self._call_op(op_signature, named_inputs, named_attrs)
692+
outputs = self._call_op(
693+
op_signature,
694+
named_inputs,
695+
named_attrs,
696+
)
687697

688698
self.functions[(function.function_ir.domain, function.name, "")] = function
689699
if len(outputs) == 1:

0 commit comments

Comments
 (0)
0