[ONNX] Inline prim::PythonOp for Autograd Function Export by shubhambhokare1 · Pull Request #74765 · pytorch/pytorch · GitHub
[ONNX] Inline prim::PythonOp for Autograd Function Export #74765

Closed
Add docs
shubhambhokare1 committed Jul 25, 2022
commit 9f3144189f5f7699cc6e99bd1bdd60d535ec3c39
25 changes: 24 additions & 1 deletion docs/source/onnx.rst
@@ -410,7 +410,7 @@ See the ``symbolic_opset*.py`` files for more examples.
torch.autograd.Functions
^^^^^^^^^^^^^^^^^^^^^^^^

- If the operator is a sub-class of :class:`torch.autograd.Function`, there are two ways
+ If the operator is a sub-class of :class:`torch.autograd.Function`, there are three ways
to export it.

Static Symbolic Method
@@ -488,6 +488,29 @@ The example below shows how you can access ``requires_grad`` via the ``Node`` object
from torch.onnx import register_custom_op_symbolic
register_custom_op_symbolic("prim::PythonOp", symbolic_python_op, 1)

Inline Autograd Function
~~~~~~~~~~~~~~~~~~~~~~~~
In cases where a static symbolic method is not provided for the ``autograd.Function``, and no
custom symbolic function is registered for ``prim::PythonOp``, ``torch.onnx.export`` tries to
inline the graph that corresponds to that ``autograd.Function``, breaking the function down into
the individual operators that were used within it. The export should succeed as long as these
individual operators are supported. For example::

class MyLogExp(torch.autograd.Function):
@staticmethod
def forward(ctx, input: torch.Tensor) -> torch.Tensor:
ctx.save_for_backward(input)
h = input.exp()
return h.log().log()

There is no static symbolic method present for this model, yet it is exported as follows::

graph(%input : Float(1, strides=[1], requires_grad=0, device=cpu)):
%1 : float = onnx::Exp[](%input)
%2 : float = onnx::Log[](%1)
%3 : float = onnx::Log[](%2)
return (%3)
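
A minimal export sketch for this example (the ``Caller`` wrapper, the in-memory buffer, and the
node-type check below are illustrative assumptions, not part of this PR's diff)::

    import io

    import onnx
    import torch

    class Caller(torch.nn.Module):
        def forward(self, input: torch.Tensor) -> torch.Tensor:
            return MyLogExp.apply(input)

    f = io.BytesIO()
    # With no symbolic method registered for MyLogExp, its prim::PythonOp node
    # should be inlined into onnx::Exp and onnx::Log nodes during export.
    torch.onnx.export(Caller(), (torch.ones(1),), f)

    onnx_model = onnx.load_model_from_string(f.getvalue())
    print([node.op_type for node in onnx_model.graph.node])  # e.g. ["Exp", "Log", "Log"]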

Custom operators
^^^^^^^^^^^^^^^^

11 changes: 6 additions & 5 deletions test/onnx/test_autograd_funs.py
@@ -7,6 +7,7 @@

import unittest


class TestAutogradFuns(unittest.TestCase):
opset_version = GLOBALS.export_onnx_opset_version
keep_initializers_as_inputs = False
@@ -23,7 +24,7 @@ def forward(ctx, i):

@staticmethod
def backward(ctx, grad_output):
- result, = ctx.saved_tensors
+ (result,) = ctx.saved_tensors
return grad_output * result

class Caller(torch.nn.Module):
@@ -33,7 +34,7 @@ def forward(self, input):

model = Caller()
input = torch.ones(1)
- run_model_test(self, model, input_args=(input,))
+ run_model_test(self, model, input_args=(input,), verbose=True)

def test_multi_output(self):
class MultiOut(torch.autograd.Function):
@@ -46,7 +47,7 @@ def forward(ctx, i):

@staticmethod
def backward(ctx, grad_output):
- result, = ctx.saved_tensors
+ (result,) = ctx.saved_tensors
return grad_output * result

class Caller(torch.nn.Module):
@@ -68,7 +69,7 @@ def forward(ctx, i):

@staticmethod
def backward(ctx, grad_output):
- result, = ctx.saved_tensors
+ (result,) = ctx.saved_tensors
return grad_output * result

class Parent(torch.autograd.Function):
@@ -81,7 +82,7 @@ def forward(ctx, i):

@staticmethod
def backward(ctx, grad_output):
- result, = ctx.saved_tensors
+ (result,) = ctx.saved_tensors
return grad_output * result

class Caller(torch.nn.Module):