[ONNX] Add support for torch.cond/HOP in onnx exporter by xadupre · Pull Request #137428 · pytorch/pytorch · GitHub

[ONNX] Add support for torch.cond/HOP in onnx exporter #137428


Closed
wants to merge 47 commits

Conversation

Collaborator
@xadupre xadupre commented Oct 7, 2024

This PR implements the framework for supporting higher-order operators (HOPs) in the ONNX exporter. Refer to #140995 for the design.

  • Implement support for torch.cond
  • Refactor _add_nodes into _translate_fx_graph to handle nested subgraphs. To support building subgraphs as functions using the same logic, new handlers for placeholder and output nodes are added to register inputs and outputs on the ONNX function.
  • Functions are created under the pkg.torch.__subgraph__ domain.
  • Update the type promotion pass to run on nested subgraphs.
  • Implement torch.cond in _torchlib/ops/hop.py and update the registry to discover these ops.
  • Improve opset_import handling robustness with the add_opset_imports IR pass. To achieve this, the opset version is added to all nodes. Fixes [ONNX] Add opset version to individual nodes when building the graph #139503

Fixes #117655 Fixes #123972 Ref #93743 Closes #140995
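
As a rough illustration of what this enables (the model, shapes, and values below are made up for this description, not taken from the PR's tests), a torch.cond model can now go through the dynamo-based exporter:

```python
import torch


class CondModel(torch.nn.Module):
    def forward(self, x):
        def true_fn(x):
            return x.sin()

        def false_fn(x):
            return x.cos()

        # torch.cond is a higher-order op; the exporter lowers it to an ONNX If
        # node whose branches call functions in the pkg.torch.__subgraph__ domain.
        return torch.cond(x.sum() > 0, true_fn, false_fn, (x,))


onnx_program = torch.onnx.export(CondModel(), (torch.randn(3),), dynamo=True)
```

The branch subgraphs end up as ONNX functions rather than inlined nodes, which keeps the main graph readable and lets the same translation logic handle nested HOPs.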

pytorch-bot bot commented Oct 7, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/137428

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

❌ 3 New Failures, 1 Unrelated Failure

As of commit 9f081b8 with merge base 72943ba:

NEW FAILURES - The following jobs have failed:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the release notes: onnx torch.onnx related changes that should show up in the release notes label Oct 7, 2024
@xadupre xadupre changed the title [ONNX] Add support for torch.cond in onnx exporter [ONNX] [WIP] Add support for torch.cond in onnx exporter Oct 7, 2024
@justinchuby justinchuby self-assigned this Oct 7, 2024
@titaiwangms titaiwangms added topic: bug fixes topic category module: onnx Related to torch.onnx labels Oct 7, 2024
@justinchuby
Collaborator

Function vs subgraph: Could you share more details on how the fx subgraphs are represented differently from ONNX subgraphs? This can inform us on what the best representation is.
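
(For context, a hedged sketch of the fx side; the module here is made up. In the ExportedProgram, torch.cond shows up as a call to torch.ops.higher_order.cond whose branches are nested fx.GraphModule attributes, whereas ONNX nests branches as graph attributes of an If node, or calls functions as in this PR.)

```python
import torch


class M(torch.nn.Module):
    def forward(self, x):
        def true_fn(x):
            return x + 1

        def false_fn(x):
            return x - 1

        return torch.cond(x.sum() > 0, true_fn, false_fn, (x,))


ep = torch.export.export(M(), (torch.randn(3),))
# The printed graph contains (roughly) a node like
#   torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, (x,))
# where true_graph_0 / false_graph_0 are nested GraphModules on ep.graph_module.
print(ep.graph_module.graph)
```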

Collaborator
@justinchuby justinchuby left a comment


Thanks!

  1. I think we should isolate the logic for handling a particular operator from the _core module. Any logic there should be generic for handling HOPs, and cond-specific logic can live in a separate location that can later be integrated into torchlib (when it is migrated over).
  2. In general, we should avoid any proto object manipulation in the exporter directory, as it introduces unnecessary overhead and creates inconsistencies in how we manipulate the ONNX graph.

@xadupre
Collaborator Author
xadupre commented Nov 4, 2024

Thanks!

  1. I think we should isolate the logic for handling a particular operator from the _core module. Any logic there should be generic for handling HOPs, and cond-specific logic can live in a separate location that can later be integrated into torchlib (when it is migrated over).
  2. In general, we should avoid any proto object manipulation in the exporter directory, as it introduces unnecessary overhead and creates inconsistencies in how we manipulate the ONNX graph.
  1. Let's do it in two steps: first this one, then HOPs.
  2. That would be better. I'll give it a try. My first goal was to check that it works.

@xadupre xadupre changed the title [ONNX] [WIP] Add support for torch.cond in onnx exporter [ONNX] Add support for torch.cond in onnx exporter Nov 5, 2024
@justinchuby
Collaborator

@xadupre and I talked and we have some ideas to simplify some of the function calls. I will play with it a little and update here if I make any progress

fx_graph = module.graph

graph_like: ir.Graph | ir.Function
if name == "":
Collaborator


Do you think torch.export would change this assumption? I am a bit concerned.

Collaborator


Do you mean the name? I think the root module always has an empty name, but I should double-check.

Collaborator


Asked in Slack

@justinchuby justinchuby added this to the 2.6.0 milestone Nov 20, 2024
@justinchuby
Collaborator
<
    ir_version=9,
    opset_imports={'pkg.onnxscript.torch_lib.common': 1, '': 18, 'pkg.torch.__subgraph__': 1, 'pkg.onnxscript.torch_lib': 1},
    producer_name='pytorch',
    producer_version='2.6.0a0+git3cd6dd5',
    domain=None,
    model_version=None,
>
graph(
    name=main_graph,
    inputs=(
        %"x"<INT64,[2]>
    ),
    outputs=(
        %"getitem"<FLOAT,[2]>
    ),
    initializers=(
        %"weight"<FLOAT,[1]>,
        %"submodule.weight"<FLOAT,[1]>
    ),
) {
    0 |  # node_ReduceSum_0
         %"sum_1"<INT64,[]> ⬅️ ::ReduceSum(%"x") {keepdims=False, noop_with_empty_axes=0}
    1 |  # node_Constant_1
         %"val_0"<?,?> ⬅️ ::Constant() {value=Tensor<INT64,[]>(array(0), name=None)}
    2 |  # node_Greater_2
         %"gt"<BOOL,[]> ⬅️ ::Greater(%"sum_1", %"val_0")
    3 |  # node_If_3
         %"getitem"<FLOAT,[2]> ⬅️ ::If(%"gt") {then_branch=
             graph(
                 name=true_graph_0,
                 inputs=(

                 ),
                 outputs=(
                     %"getitem_true_graph_0"<?,?>
                 ),
             ) {
                 0 |  # node_true_graph_0_0
                      %"getitem_true_graph_0"<?,?> ⬅️ pkg.torch.__subgraph__::true_graph_0(%"weight", %"x", %"submodule.weight")
                 return %"getitem_true_graph_0"<?,?>
             }, else_branch=
             graph(
                 name=false_graph_0,
                 inputs=(

                 ),
                 outputs=(
                     %"sub_false_graph_0"<?,?>
                 ),
             ) {
                 0 |  # node_false_graph_0_0
                      %"sub_false_graph_0"<?,?> ⬅️ pkg.torch.__subgraph__::false_graph_0(%"weight", %"x", %"submodule.weight")
                 return %"sub_false_graph_0"<?,?>
             }}
    return %"getitem"<FLOAT,[2]>
}

<
    opset_imports={'': 18},
>
def pkg.torch.__subgraph__::false_graph_0(
    inputs=(
        %"p_weight"<FLOAT,[1]>,
        %"x"<INT64,[2]>,
        %"p_submodule_weight"<FLOAT,[1]>
    ),
    outputs=(
        %"sub"<FLOAT,[2]>
    ),
) {
    0 |  # node_Cast_0
         %"convert_element_type_default"<FLOAT,[2]> ⬅️ ::Cast(%"x") {to=FLOAT}
    1 |  # node_Sub_1
         %"sub"<FLOAT,[2]> ⬅️ ::Sub(%"convert_element_type_default", %"p_weight")
    return %"sub"<FLOAT,[2]>
}

<
    opset_imports={'pkg.onnxscript.torch_lib': 1},
>
def pkg.torch.__subgraph__::true_graph_0__false_graph_0(
    inputs=(
        %"p_submodule_weight"<FLOAT,[1]>,
        %"sub"<FLOAT,[2]>
    ),
    outputs=(
        %"div"<FLOAT,[2]>
    ),
) {
    0 |  # node_aten_div_0
         %"div"<FLOAT,[2]> ⬅️ pkg.onnxscript.torch_lib::aten_div(%"sub", %"p_submodule_weight")
    return %"div"<FLOAT,[2]>
}

<
    opset_imports={'': 18},
>
def pkg.onnxscript.torch_lib::aten_div(
    inputs=(
        %"self"<?,?>,
        %"other"<?,?>
    ),
    outputs=(
        %"return_val"<?,?>
    ),
) {
    0 |  # n0
         %"return_val"<?,?> ⬅️ ::Div(%"self", %"other")
    return %"return_val"<?,?>
}

<
    opset_imports={'': 18},
>
def pkg.torch.__subgraph__::true_graph_0__true_graph_0(
    inputs=(
        %"p_submodule_weight"<FLOAT,[1]>,
        %"sub"<FLOAT,[2]>
    ),
    outputs=(
        %"mul"<FLOAT,[2]>
    ),
) {
    0 |  # node_Mul_0
         %"mul"<FLOAT,[2]> ⬅️ ::Mul(%"sub", %"p_submodule_weight")
    return %"mul"<FLOAT,[2]>
}

<
    opset_imports={'': 18, 'pkg.torch.__subgraph__': 1},
>
def pkg.torch.__subgraph__::true_graph_0(
    inputs=(
        %"p_weight"<FLOAT,[1]>,
        %"x"<INT64,[2]>,
        %"p_submodule_weight"<FLOAT,[1]>
    ),
    outputs=(
        %"getitem"<FLOAT,[2]>
    ),
) {
    0 |  # node_Cast_0
         %"convert_element_type_default"<FLOAT,[2]> ⬅️ ::Cast(%"x") {to=FLOAT}
    1 |  # node_Sub_1
         %"sub"<FLOAT,[2]> ⬅️ ::Sub(%"convert_element_type_default", %"p_weight")
    2 |  # node_ReduceSum_2
         %"sum_1"<FLOAT,[]> ⬅️ ::ReduceSum(%"sub") {keepdims=False, noop_with_empty_axes=0}
    3 |  # node_Constant_3
         %"val_0"<?,?> ⬅️ ::Constant() {value=Tensor<INT64,[]>(array(0), name=None)}
    4 |  # node_Cast_4
         %"scalar_tensor_default"<FLOAT,[]> ⬅️ ::Cast(%"val_0") {to=FLOAT}
    5 |  # node_LessOrEqual_5
         %"le"<BOOL,[]> ⬅️ ::LessOrEqual(%"sum_1", %"scalar_tensor_default")
    6 |  # node_If_6
         %"getitem"<FLOAT,[2]> ⬅️ ::If(%"le") {then_branch=
             graph(
                 name=true_graph_0__true_graph_0,
                 inputs=(

                 ),
                 outputs=(
                     %"mul_true_graph_0__true_graph_0"<?,?>
                 ),
             ) {
                 0 |  # node_true_graph_0__true_graph_0_0
                      %"mul_true_graph_0__true_graph_0"<?,?> ⬅️ pkg.torch.__subgraph__::true_graph_0__true_graph_0(%"p_submodule_weight", %"sub")
                 return %"mul_true_graph_0__true_graph_0"<?,?>
             }, else_branch=
             graph(
                 name=true_graph_0__false_graph_0,
                 inputs=(

                 ),
                 outputs=(
                     %"div_true_graph_0__false_graph_0"<?,?>
                 ),
             ) {
                 0 |  # node_true_graph_0__false_graph_0_0
                      %"div_true_graph_0__false_graph_0"<?,?> ⬅️ pkg.torch.__subgraph__::true_graph_0__false_graph_0(%"p_submodule_weight", %"sub")
                 return %"div_true_graph_0__false_graph_0"<?,?>
             }}
    return %"getitem"<FLOAT,[2]>
}

<
    opset_imports={'': 18},
>
def pkg.onnxscript.torch_lib.common::Rank(
    inputs=(
        %"input"<?,?>
    ),
    outputs=(
        %"return_val"<?,?>
    ),
) {
    0 |  # n0
         %"tmp"<?,?> ⬅️ ::Shape(%"input")
    1 |  # n1
         %"return_val"<?,?> ⬅️ ::Size(%"tmp")
    return %"return_val"<?,?>
}

<
    opset_imports={'': 18},
>
def pkg.onnxscript.torch_lib.common::IsScalar(
    inputs=(
        %"input"<?,?>
    ),
    outputs=(
        %"return_val"<?,?>
    ),
) {
    0 |  # n0
         %"tmp"<?,?> ⬅️ ::Shape(%"input")
    1 |  # n1
         %"tmp_0"<?,?> ⬅️ ::Size(%"tmp")
    2 |  # n2
         %"tmp_1"<?,?> ⬅️ ::Constant() {value_int=0}
    3 |  # n3
         %"return_val"<?,?> ⬅️ ::Equal(%"tmp_0", %"tmp_1")
    return %"return_val"<?,?>
}
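
For reference, a rough reconstruction (the class names and parameter values are assumptions inferred from the IR above, not the PR's actual test code) of a nested torch.cond module that would produce output of this shape:

```python
import torch


class Submodule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.tensor([2.0]))

    def forward(self, x):
        # Inner cond: becomes true_graph_0__true_graph_0 / true_graph_0__false_graph_0
        def true_fn(x):
            return x * self.weight

        def false_fn(x):
            return x / self.weight

        return torch.cond(x.sum() <= 0, true_fn, false_fn, (x,))


class NestedCondModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.tensor([1.0]))
        self.submodule = Submodule()

    def forward(self, x):
        # Outer cond: becomes true_graph_0 / false_graph_0 called from the main graph's If
        def true_fn(x):
            return self.submodule(x.to(torch.float32) - self.weight)

        def false_fn(x):
            return x.to(torch.float32) - self.weight

        return torch.cond(x.sum() > 0, true_fn, false_fn, (x,))


onnx_program = torch.onnx.export(
    NestedCondModel(), (torch.tensor([1, 2]),), dynamo=True
)
```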

@justinchuby
Collaborator

@pytorchbot merge -i

@justinchuby
Collaborator

Will address new comments in a follow-up PR

@pytorchmergebot
Collaborator

Merge started

Your change will be merged while ignoring the following 4 checks: pull / linux-focal-py3_9-clang9-xla / test (xla, 1, 1, lf.linux.12xlarge), pull / linux-jammy-py3-clang12-executorch / test (executorch, 1, 1, lf.linux.2xlarge), pull / linux-jammy-py3.9-gcc11 / test (docs_test, 1, 1, lf.linux.2xlarge), pull / linux-jammy-py3.9-gcc11 / test (backwards_compat, 1, 1, lf.linux.2xlarge)

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

youssef62 pushed a commit to youssef62/pytorch that referenced this pull request Nov 23, 2024
This PR implements the framework for supporting HOP in the ONNX exporter. Refer to pytorch#140995 for the design.

- Implement support for torch.cond
- Refactor `_add_nodes` into `_translate_fx_graph` to handle nested subgraphs. To support building subgraphs as functions using the same logic, new handlers for `placeholder` and `output` nodes are added to register inputs and outputs on the onnx function.
- Functions are created under the domain of `pkg.torch.__subgraph__`
- Updated the type promotion pass to run on nested subgraphs.
- Implement torch.cond in `_torchlib/ops/hop.py`. Updated the registry to discover these ops.
- Improve opset_import handling robustness with `add_opset_imports` IR pass. To achieve this, we added opset version to all Nodes. Fixes pytorch#139503

Fixes pytorch#117655 Fixes pytorch#123972 Fixes pytorch#93743 Closes pytorch#140995

Pull Request resolved: pytorch#137428
Approved by: https://github.com/justinchuby

Co-authored-by: Justin Chu <justinchuby@users.noreply.github.com>
Ryo-not-rio pushed a commit to Ryo-not-rio/pytorch that referenced this pull request Dec 2, 2024
pobin6 pushed a commit to pobin6/pytorch that referenced this pull request Dec 5, 2024
Labels
ciflow/trunk Trigger trunk jobs on your pull request Merged module: onnx Related to torch.onnx open source release notes: onnx torch.onnx related changes that should show up in the release notes topic: new features topic category triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module
Projects
None yet
7 participants