[ONNX] Add support for torch.cond/HOP in onnx exporter #137428
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/137428
Note: Links to docs will display an error until the docs builds have been completed.
❗ 1 Active SEV: there is 1 currently active SEV. If your PR is affected, please view it below.
❌ 3 New Failures, 1 Unrelated Failure as of commit 9f081b8 with merge base 72943ba.
NEW FAILURES: the following jobs have failed:
FLAKY: the following job failed but was likely due to flakiness present on trunk:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Function vs subgraph: Could you share more details on how the FX subgraphs are represented differently from ONNX subgraphs? This can inform what the best representation is.
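For illustration, a minimal sketch of the two ONNX-side representations under discussion, built with plain onnx.helper rather than the exporter's internals (the `Neg` body and all names here are made up): an If branch can carry its body inline as a GraphProto attribute, or it can be a thin graph that only calls a model-local FunctionProto, which is the shape the exporter produces under the `pkg.torch.__subgraph__` domain in the dump further down.

```python
import onnx
from onnx import TensorProto, helper

# Option A: the branch body is inlined as the If node's graph attribute.
# Branch graphs take no inputs; they capture values from the outer scope.
then_inlined = helper.make_graph(
    nodes=[helper.make_node("Neg", ["x"], ["y"])],
    name="then_inlined",
    inputs=[],
    outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [2])],
)

# Option B: the branch body lives in a reusable FunctionProto, and the branch
# graph only contains a call to it.
then_fn = helper.make_function(
    "pkg.torch.__subgraph__",       # domain
    "true_graph_0",                 # function name
    ["x"],                          # inputs
    ["y"],                          # outputs
    [helper.make_node("Neg", ["x"], ["y"])],
    [helper.make_opsetid("", 18)],  # opset imports
)
then_calling = helper.make_graph(
    nodes=[
        helper.make_node("true_graph_0", ["x"], ["y"], domain="pkg.torch.__subgraph__")
    ],
    name="then_calling_function",
    inputs=[],
    outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [2])],
)
```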
Thanks!
- I think we should isolate logic that handles a particular operator from the `_core` module. Any logic there should be generic HOP handling; cond-specific logic can live in a separate location that can later be integrated into torchlib (when it is migrated over).
- In general, we should avoid any proto object manipulation in the `exporter` directory, as it introduces unnecessary overhead and creates inconsistencies in how we manipulate the ONNX graph.
@xadupre and I talked, and we have some ideas to simplify some of the function calls. I will play with it a little and update here if I make any progress.
fx_graph = module.graph

graph_like: ir.Graph | ir.Function
if name == "":
Do you think torch.export would change this assumption? I am a bit concerned.
Do you mean the name? I think the root module always has an empty name, but I should double-check.
Asked in Slack.
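For what it's worth, a quick sketch of checking that assumption (illustrative only; the exact branch submodule names may vary): the root GraphModule is yielded with an empty name by `named_modules()`, while the cond branches show up as named submodules.

```python
import torch
from torch.export import export


class CondModel(torch.nn.Module):
    def forward(self, x):
        def true_fn(x):
            return x + 1.0

        def false_fn(x):
            return x - 1.0

        return torch.cond(x.sum() > 0, true_fn, false_fn, [x])


ep = export(CondModel(), (torch.tensor([1.0, -2.0]),))
# The root GraphModule has the empty name; the cond branches appear as
# submodules (e.g. "true_graph_0", "false_graph_0").
for name, _ in ep.graph_module.named_modules():
    print(repr(name))
```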
<
ir_version=9,
opset_imports={'pkg.onnxscript.torch_lib.common': 1, '': 18, 'pkg.torch.__subgraph__': 1, 'pkg.onnxscript.torch_lib': 1},
producer_name='pytorch',
producer_version='2.6.0a0+git3cd6dd5',
domain=None,
model_version=None,
>
graph(
name=main_graph,
inputs=(
%"x"<INT64,[2]>
),
outputs=(
%"getitem"<FLOAT,[2]>
),
initializers=(
%"weight"<FLOAT,[1]>,
%"submodule.weight"<FLOAT,[1]>
),
) {
0 | # node_ReduceSum_0
%"sum_1"<INT64,[]> ⬅️ ::ReduceSum(%"x") {keepdims=False, noop_with_empty_axes=0}
1 | # node_Constant_1
%"val_0"<?,?> ⬅️ ::Constant() {value=Tensor<INT64,[]>(array(0), name=None)}
2 | # node_Greater_2
%"gt"<BOOL,[]> ⬅️ ::Greater(%"sum_1", %"val_0")
3 | # node_If_3
%"getitem"<FLOAT,[2]> ⬅️ ::If(%"gt") {then_branch=
graph(
name=true_graph_0,
inputs=(
),
outputs=(
%"getitem_true_graph_0"<?,?>
),
) {
0 | # node_true_graph_0_0
%"getitem_true_graph_0"<?,?> ⬅️ pkg.torch.__subgraph__::true_graph_0(%"weight", %"x", %"submodule.weight")
return %"getitem_true_graph_0"<?,?>
}, else_branch=
graph(
name=false_graph_0,
inputs=(
),
outputs=(
%"sub_false_graph_0"<?,?>
),
) {
0 | # node_false_graph_0_0
%"sub_false_graph_0"<?,?> ⬅️ pkg.torch.__subgraph__::false_graph_0(%"weight", %"x", %"submodule.weight")
return %"sub_false_graph_0"<?,?>
}}
return %"getitem"<FLOAT,[2]>
}
<
opset_imports={'': 18},
>
def pkg.torch.__subgraph__::false_graph_0(
inputs=(
%"p_weight"<FLOAT,[1]>,
%"x"<INT64,[2]>,
%"p_submodule_weight"<FLOAT,[1]>
),
outputs=(
%"sub"<FLOAT,[2]>
),
) {
0 | # node_Cast_0
%"convert_element_type_default"<FLOAT,[2]> ⬅️ ::Cast(%"x") {to=FLOAT}
1 | # node_Sub_1
%"sub"<FLOAT,[2]> ⬅️ ::Sub(%"convert_element_type_default", %"p_weight")
return %"sub"<FLOAT,[2]>
}
<
opset_imports={'pkg.onnxscript.torch_lib': 1},
>
def pkg.torch.__subgraph__::true_graph_0__false_graph_0(
inputs=(
%"p_submodule_weight"<FLOAT,[1]>,
%"sub"<FLOAT,[2]>
),
outputs=(
%"div"<FLOAT,[2]>
),
) {
0 | # node_aten_div_0
%"div"<FLOAT,[2]> ⬅️ pkg.onnxscript.torch_lib::aten_div(%"sub", %"p_submodule_weight")
return %"div"<FLOAT,[2]>
}
<
opset_imports={'': 18},
>
def pkg.onnxscript.torch_lib::aten_div(
inputs=(
%"self"<?,?>,
%"other"<?,?>
),
outputs=(
%"return_val"<?,?>
),
) {
0 | # n0
%"return_val"<?,?> ⬅️ ::Div(%"self", %"other")
return %"return_val"<?,?>
}
<
opset_imports={'': 18},
>
def pkg.torch.__subgraph__::true_graph_0__true_graph_0(
inputs=(
%"p_submodule_weight"<FLOAT,[1]>,
%"sub"<FLOAT,[2]>
),
outputs=(
%"mul"<FLOAT,[2]>
),
) {
0 | # node_Mul_0
%"mul"<FLOAT,[2]> ⬅️ ::Mul(%"sub", %"p_submodule_weight")
return %"mul"<FLOAT,[2]>
}
<
opset_imports={'': 18, 'pkg.torch.__subgraph__': 1},
>
def pkg.torch.__subgraph__::true_graph_0(
inputs=(
%"p_weight"<FLOAT,[1]>,
%"x"<INT64,[2]>,
%"p_submodule_weight"<FLOAT,[1]>
),
outputs=(
%"getitem"<FLOAT,[2]>
),
) {
0 | # node_Cast_0
%"convert_element_type_default"<FLOAT,[2]> ⬅️ ::Cast(%"x") {to=FLOAT}
1 | # node_Sub_1
%"sub"<FLOAT,[2]> ⬅️ ::Sub(%"convert_element_type_default", %"p_weight")
2 | # node_ReduceSum_2
%"sum_1"<FLOAT,[]> ⬅️ ::ReduceSum(%"sub") {keepdims=False, noop_with_empty_axes=0}
3 | # node_Constant_3
%"val_0"<?,?> ⬅️ ::Constant() {value=Tensor<INT64,[]>(array(0), name=None)}
4 | # node_Cast_4
%"scalar_tensor_default"<FLOAT,[]> ⬅️ ::Cast(%"val_0") {to=FLOAT}
5 | # node_LessOrEqual_5
%"le"<BOOL,[]> ⬅️ ::LessOrEqual(%"sum_1", %"scalar_tensor_default")
6 | # node_If_6
%"getitem"<FLOAT,[2]> ⬅️ ::If(%"le") {then_branch=
graph(
name=true_graph_0__true_graph_0,
inputs=(
),
outputs=(
%"mul_true_graph_0__true_graph_0"<?,?>
),
) {
0 | # node_true_graph_0__true_graph_0_0
%"mul_true_graph_0__true_graph_0"<?,?> ⬅️ pkg.torch.__subgraph__::true_graph_0__true_graph_0(%"p_submodule_weight", %"sub")
return %"mul_true_graph_0__true_graph_0"<?,?>
}, else_branch=
graph(
name=true_graph_0__false_graph_0,
inputs=(
),
outputs=(
%"div_true_graph_0__false_graph_0"<?,?>
),
) {
0 | # node_true_graph_0__false_graph_0_0
%"div_true_graph_0__false_graph_0"<?,?> ⬅️ pkg.torch.__subgraph__::true_graph_0__false_graph_0(%"p_submodule_weight", %"sub")
return %"div_true_graph_0__false_graph_0"<?,?>
}}
return %"getitem"<FLOAT,[2]>
}
<
opset_imports={'': 18},
>
def pkg.onnxscript.torch_lib.common::Rank(
inputs=(
%"input"<?,?>
),
outputs=(
%"return_val"<?,?>
),
) {
0 | # n0
%"tmp"<?,?> ⬅️ ::Shape(%"input")
1 | # n1
%"return_val"<?,?> ⬅️ ::Size(%"tmp")
return %"return_val"<?,?>
}
<
opset_imports={'': 18},
>
def pkg.onnxscript.torch_lib.common::IsScalar(
inputs=(
%"input"<?,?>
),
outputs=(
%"return_val"<?,?>
),
) {
0 | # n0
%"tmp"<?,?> ⬅️ ::Shape(%"input")
1 | # n1
%"tmp_0"<?,?> ⬅️ ::Size(%"tmp")
2 | # n2
%"tmp_1"<?,?> ⬅️ ::Constant() {value_int=0}
3 | # n3
%"return_val"<?,?> ⬅️ ::Equal(%"tmp_0", %"tmp_1")
return %"return_val"<?,?>
}
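For context, here is a rough reconstruction of the kind of model that would produce an export like the one above. This is a hypothetical sketch inferred from the printed graph, not the PR's actual test case: an outer torch.cond on `x.sum() > 0` whose true branch runs a submodule that contains a second torch.cond.

```python
import torch


class SubmoduleWithCond(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.tensor([100.0]))

    def forward(self, y):
        def mul_fn(y):
            # Roughly corresponds to pkg.torch.__subgraph__::true_graph_0__true_graph_0
            return y * self.weight

        def div_fn(y):
            # Roughly corresponds to pkg.torch.__subgraph__::true_graph_0__false_graph_0
            return y / self.weight

        return torch.cond(y.sum() <= 0, mul_fn, div_fn, [y])


class NestedCondModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.tensor([42.0]))
        self.submodule = SubmoduleWithCond()

    def forward(self, x):
        def true_fn(x):
            # Roughly corresponds to pkg.torch.__subgraph__::true_graph_0
            return self.submodule(x.to(torch.float32) - self.weight)

        def false_fn(x):
            # Roughly corresponds to pkg.torch.__subgraph__::false_graph_0
            return x.to(torch.float32) - self.weight

        return torch.cond(x.sum() > 0, true_fn, false_fn, [x])
```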
@pytorchbot merge -i
Will address new comments in a follow-up PR.
Merge started. Your change will be merged while ignoring the following 4 checks: pull / linux-focal-py3_9-clang9-xla / test (xla, 1, 1, lf.linux.12xlarge), pull / linux-jammy-py3-clang12-executorch / test (executorch, 1, 1, lf.linux.2xlarge), pull / linux-jammy-py3.9-gcc11 / test (docs_test, 1, 1, lf.linux.2xlarge), pull / linux-jammy-py3.9-gcc11 / test (backwards_compat, 1, 1, lf.linux.2xlarge). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
This PR implements the framework for supporting HOPs (higher-order operators) in the ONNX exporter. Refer to #140995 for the design.

- Implement support for torch.cond.
- Refactor `_add_nodes` into `_translate_fx_graph` to handle nested subgraphs. To support building subgraphs as functions using the same logic, new handlers for `placeholder` and `output` nodes are added to register inputs and outputs on the ONNX function.
- Functions are created under the `pkg.torch.__subgraph__` domain.
- Update the type promotion pass to run on nested subgraphs.
- Implement torch.cond in `_torchlib/ops/hop.py` and update the registry to discover these ops.
- Improve `opset_import` handling robustness with the `add_opset_imports` IR pass. To achieve this, an opset version is added to every node.

Fixes #139503 ([ONNX] Add opset version to individual nodes when building the graph)
Fixes #117655
Fixes #123972
Ref #93743
Closes #140995

Pull Request resolved: pytorch#137428
Approved by: https://github.com/justinchuby
Co-authored-by: Justin Chu <justinchuby@users.noreply.github.com>
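As a usage note, a minimal sketch of exercising the new path end to end (the model and file name are illustrative, not taken from the PR's tests): with `dynamo=True`, torch.cond lowers to an ONNX If node whose branches call functions in the `pkg.torch.__subgraph__` domain.

```python
import torch


class CondModel(torch.nn.Module):
    def forward(self, x):
        def true_fn(x):
            return torch.cos(x)

        def false_fn(x):
            return torch.sin(x)

        return torch.cond(x.sum() > 0, true_fn, false_fn, [x])


# The dynamo-based exporter returns an ONNXProgram that can be saved to disk.
onnx_program = torch.onnx.export(CondModel(), (torch.tensor([1.0, -2.0]),), dynamo=True)
onnx_program.save("cond.onnx")
```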