[ONNX] Inline prim::PythonOp for Autograd Function Export (#74765) (#… · pytorch/pytorch@f11048c · GitHub

Commit f11048c

[ONNX] Inline prim::PythonOp for Autograd Function Export (#74765) (#74765)
Summary: Add a flag (inline_autograd) to enable inline export of models consisting of autograd functions. Currently, this flag should only be used in TrainingMode.EVAL and not for training.

An example: if a model containing an ``autograd.Function`` is as follows

```
class AutogradFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, i):
        result = i.exp()
        result = result.log()
        ctx.save_for_backward(result)
        return result
```

then the model is exported as

```
graph(%0 : Float):
  %1 : Float = ^AutogradFunc(%0)
  return (%1)
```

If inline_autograd is set to True, this is exported instead as

```
graph(%0 : Float):
  %1 : Float = onnx::Exp(%0)
  %2 : Float = onnx::Log(%1)
  return (%2)
```

If one of the ops within the autograd function is not supported, that particular node is exported as is, mirroring ONNX_FALLTHROUGH mode.

Fixes: #61813

Pull Request resolved: #74765
Approved by: https://github.com/BowenBao, https://github.com/malfet

Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/95d873855e6b5a7b44e102d3aec81d6db3215c0f

Original Phabricator Test Plan: Imported from GitHub, without a `Test Plan:` line.

Reviewed By: george-qi, kit1980

Differential Revision: D37738323

fbshipit-source-id: 03ff75a809403b134c2a545952706cbeac8d0065
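For context, here is a minimal export sketch (not part of this commit; the `Model` class and the output file name are illustrative) built around the `AutogradFunc` from the summary, under the eval-mode export this change targets:

```
import torch


class AutogradFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, i):
        result = i.exp()
        result = result.log()
        ctx.save_for_backward(result)
        return result


class Model(torch.nn.Module):
    def forward(self, x):
        return AutogradFunc.apply(x)


# With the default operator_export_type (ONNX), the prim::PythonOp produced by
# AutogradFunc is inlined into onnx::Exp followed by onnx::Log, as described above.
torch.onnx.export(
    Model(),
    (torch.ones(1),),
    "autograd_func.onnx",
    training=torch.onnx.TrainingMode.EVAL,  # inlining is intended for eval export
    opset_version=12,
)
```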
1 parent c27b395 commit f11048c

File tree

12 files changed, +381 −4 lines


build_variables.bzl

Lines changed: 1 addition & 0 deletions
@@ -896,6 +896,7 @@ libtorch_python_core_sources = [
     "torch/csrc/jit/passes/onnx/function_extraction.cpp",
     "torch/csrc/jit/passes/onnx/onnx_log.cpp",
     "torch/csrc/jit/python/pybind_utils.cpp",
+    "torch/csrc/jit/passes/onnx/pattern_conversion/autograd_function_process.cpp",
     "torch/csrc/jit/passes/onnx/pattern_conversion/common.cpp",
     "torch/csrc/jit/passes/onnx/pattern_conversion/pattern_encapsulation.cpp",
     "torch/csrc/jit/passes/onnx/pattern_conversion/pattern_conversion.cpp",

docs/source/onnx.rst

Lines changed: 27 additions & 1 deletion
@@ -410,7 +410,7 @@ See the ``symbolic_opset*.py`` files for more examples.
 torch.autograd.Functions
 ^^^^^^^^^^^^^^^^^^^^^^^^
 
-If the operator is a sub-class of :class:`torch.autograd.Function`, there are two ways
+If the operator is a sub-class of :class:`torch.autograd.Function`, there are three ways
 to export it.
 
 Static Symbolic Method
@@ -488,6 +488,32 @@ The example below shows how you can access ``requires_grad`` via the ``Node`` object
     from torch.onnx import register_custom_op_symbolic
     register_custom_op_symbolic("prim::PythonOp", symbolic_python_op, 1)
 
+Inline Autograd Function
+~~~~~~~~~~~~~~~~~~~~~~~~
+In cases where a static symbolic method is not provided for the autograd.Function subclass,
+or where no custom symbolic function is registered for prim::PythonOp,
+torch.onnx.export tries to inline the graph that corresponds to that autograd.Function, so that
+the function is broken down into the individual operators that were used within it.
+The export should succeed as long as these individual operators are supported. For example::
+
+    class MyLogExp(torch.autograd.Function):
+        @staticmethod
+        def forward(ctx, input: torch.Tensor) -> torch.Tensor:
+            ctx.save_for_backward(input)
+            h = input.exp()
+            return h.log().log()
+
+There is no static symbolic method present for this model, yet it is exported as follows::
+
+    graph(%input : Float(1, strides=[1], requires_grad=0, device=cpu)):
+      %1 : float = onnx::Exp[](%input)
+      %2 : float = onnx::Log[](%1)
+      %3 : float = onnx::Log[](%2)
+      return (%3)
+
+To avoid inlining of autograd.Functions, the model should be exported with
+operator_export_type set to ONNX_FALLTHROUGH or ONNX_ATEN_FALLBACK mode.
+
 Custom operators
 ^^^^^^^^^^^^^^^^
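As a rough companion to the note above about avoiding inlining, the sketch below follows the pattern used in the new test/onnx/test_autograd_funs.py (the `Wrapper` module is illustrative, not part of this commit): building the TorchScript graph with `ONNX_FALLTHROUGH` keeps the autograd.Function as a single `prim::PythonOp` node instead of inlining it.

```
import torch
from torch.onnx import OperatorExportTypes
from torch.onnx.utils import _model_to_graph


class MyLogExp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input: torch.Tensor) -> torch.Tensor:
        ctx.save_for_backward(input)
        h = input.exp()
        return h.log().log()


class Wrapper(torch.nn.Module):
    def forward(self, x):
        return MyLogExp.apply(x)


# With ONNX_FALLTHROUGH the autograd.Function stays a single prim::PythonOp
# node rather than being inlined into Exp/Log/Log.
graph, _, _ = _model_to_graph(
    Wrapper(),
    (torch.ones(1),),
    operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH,
)
print(graph)  # expect a prim::PythonOp node in the printed graph
```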

test/onnx/test_autograd_funs.py

Lines changed: 181 additions & 0 deletions
@@ -0,0 +1,181 @@
# Owner(s): ["module: onnx"]

import unittest

import torch

from onnx_test_common import run_model_test
from torch.onnx import OperatorExportTypes
from torch.onnx._globals import GLOBALS
from torch.onnx.utils import _model_to_graph


class TestAutogradFuns(unittest.TestCase):
    opset_version = GLOBALS.export_onnx_opset_version
    keep_initializers_as_inputs = False
    onnx_shape_inference = True

    def test_single_output(self):
        class SingleOut(torch.autograd.Function):
            @staticmethod
            def forward(ctx, i):
                result = i.exp()
                result = result.log()
                ctx.save_for_backward(result)
                return result

            @staticmethod
            def backward(ctx, grad_output):
                (result,) = ctx.saved_tensors
                return grad_output * result

        class Caller(torch.nn.Module):
            def forward(self, input):
                result = input + 5
                return SingleOut.apply(result) + 3

        model = Caller()
        input = torch.ones(1)
        run_model_test(self, model, input_args=(input,))

    def test_multi_output(self):
        class MultiOut(torch.autograd.Function):
            @staticmethod
            def forward(ctx, i):
                result_exp = i.exp()
                result_log = result_exp.log()
                ctx.save_for_backward(result_exp, result_log)
                return result_exp, result_log

            @staticmethod
            def backward(ctx, grad_output):
                (result,) = ctx.saved_tensors
                return grad_output * result

        class Caller(torch.nn.Module):
            def forward(self, input):
                return MultiOut.apply(input)

        model = Caller()
        input = torch.ones(1, 5)
        run_model_test(self, model, input_args=(input,))

    def test_partial_output(self):
        class PartialOut(torch.autograd.Function):
            @staticmethod
            def forward(ctx, input):
                ctx.save_for_backward(input)
                values, indices = torch.topk(input, 3)
                return values

        class Caller(torch.nn.Module):
            def forward(self, input):
                return PartialOut.apply(input)

        model = Caller()
        input = torch.ones(1, 5)
        run_model_test(self, model, input_args=(input,))

    def test_nested_autograd(self):
        class Child(torch.autograd.Function):
            @staticmethod
            def forward(ctx, i):
                result = i.log()
                result_log = result.log()
                ctx.save_for_backward(result_log)
                return result_log

            @staticmethod
            def backward(ctx, grad_output):
                (result,) = ctx.saved_tensors
                return grad_output * result

        class Parent(torch.autograd.Function):
            @staticmethod
            def forward(ctx, i):
                result_exp = i.exp()
                result_log = Child.apply(result_exp)
                ctx.save_for_backward(result_exp, result_log)
                return result_exp, result_log

            @staticmethod
            def backward(ctx, grad_output):
                (result,) = ctx.saved_tensors
                return grad_output * result

        class Caller(torch.nn.Module):
            def forward(self, input):
                return Parent.apply(input)

        model = Caller()
        input = torch.ones(1, 5)
        run_model_test(self, model, input_args=(input,))

    # Run export in ONNX_FALLTHROUGH mode as torch.erf() is not supported
    def test_aten_unsupported(self):
        class Erf(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                erf_out = torch.special.erf(x)
                ctx.save_for_backward(erf_out)
                return erf_out

            @staticmethod
            def backward(ctx, grad_output):
                result = ctx.saved_tensors
                return torch.special.erfinv(result), None

        class Caller(torch.nn.Module):
            def forward(self, input):
                return Erf.apply(input)

        model = Caller()
        input = torch.ones(1, 5)

        # Test ONNX_FALLTHROUGH_MODE
        graph, _, _ = _model_to_graph(
            model,
            (input,),
            operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH,
        )
        iter = graph.nodes()
        self.assertEqual(next(iter).kind(), "prim::PythonOp")

        # Test ATEN_FALLBACK_MODE
        graph, _, _ = _model_to_graph(
            model,
            (input,),
            operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK,
        )
        iter = graph.nodes()
        self.assertEqual(next(iter).kind(), "prim::PythonOp")

    def test_inline_and_symbolic(self):
        class Exp(torch.autograd.Function):
            @staticmethod
            def forward(ctx, i):
                ctx.save_for_backward(input)
                return i.exp()

            @staticmethod
            def symbolic(g, input):
                return g.op("Exp", input)

        class LogLog(torch.autograd.Function):
            @staticmethod
            def forward(ctx, i):
                ctx.save_for_backward(input)
                return i.log().log()

        class Caller(torch.nn.Module):
            def forward(self, input):
                exp_result = Exp.apply(input)
                return LogLog.apply(exp_result)

        model = Caller()
        input = torch.ones(1)
        run_model_test(self, model, input_args=(input,))


if __name__ == "__main__":
    unittest.main()
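`run_model_test` comes from the shared `onnx_test_common` helpers used by these tests; broadly, it exports the model and checks the ONNX Runtime output against eager execution. The following is only a stand-alone sketch of that kind of check for single-output models (assuming `onnxruntime` and `numpy` are installed; `check_export_matches_eager` is a hypothetical helper, not the actual implementation):

```
import io

import numpy as np
import onnxruntime as ort
import torch


def check_export_matches_eager(model, args, rtol=1e-3, atol=1e-6):
    # Export to an in-memory buffer; any autograd.Function inside the model
    # is expected to be inlined into plain ONNX ops by the default export path.
    buffer = io.BytesIO()
    torch.onnx.export(model, args, buffer)

    session = ort.InferenceSession(
        buffer.getvalue(), providers=["CPUExecutionProvider"]
    )
    ort_inputs = {
        inp.name: arg.detach().cpu().numpy()
        for inp, arg in zip(session.get_inputs(), args)
    }
    ort_out = session.run(None, ort_inputs)[0]

    # Compare ONNX Runtime output against eager execution (single output assumed).
    eager_out = model(*args)
    np.testing.assert_allclose(
        eager_out.detach().cpu().numpy(), ort_out, rtol=rtol, atol=atol
    )
```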

test/onnx/test_utility_funs.py

Lines changed: 5 additions & 1 deletion
@@ -1231,7 +1231,11 @@ def forward(self, input):
         batch = torch.FloatTensor(1, 3)
 
         graph, _, _ = self._model_to_graph(
-            model, batch, input_names=["batch"], dynamic_axes={"batch": [0, 1]}
+            model,
+            batch,
+            operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH,
+            input_names=["batch"],
+            dynamic_axes={"batch": [0, 1]},
         )
         iter = graph.nodes()
         autograd1 = next(iter)

torch/_C/__init__.pyi.in

Lines changed: 1 addition & 0 deletions
@@ -335,6 +335,7 @@ def _jit_pass_onnx_remove_inplace_ops_for_onnx(graph: Graph, module: Module) ->
 def _jit_pass_remove_inplace_ops(graph: Graph) -> None: ...
 def _jit_pass_canonicalize_graph_fuser_ops(graph: Graph) -> None: ...
 def _jit_pass_peephole(graph: Graph, disable_shape_peepholes: _bool = False) -> None: ...
+def _jit_pass_onnx_autograd_function_process(graph: Graph) -> None: ...
 def _jit_pass_fuse_addmm(graph: Graph) -> None: ...
 def _jit_pass_onnx_preprocess(graph: Graph) -> None: ...
 def _jit_pass_prepare_division_for_onnx(graph: Graph) -> None: ...

torch/csrc/autograd/python_function.cpp

Lines changed: 55 additions & 0 deletions
@@ -587,6 +587,52 @@ std::pair<UnpackedInput, InputFlags> unpack_input(PyObject* args) {
   return std::make_pair(std::move(unpacked), std::move(flags));
 }
 
+// Given a prim::PythonOp node, _append_subgraph creates a subgraph such that:
+// (1) it has the same inputs as the prim::PythonOp node;
+// (2) the intermediate nodes used in the PythonOp are cloned and stored in
+// the subgraph; (3) trace_outputs stores the Value* objects before a new
+// trace value is assigned by the prim::PythonOp node, and helps to eventually
+// route the outputs of the subgraph correctly. This newly created subgraph is
+// then added to the prim::PythonOp node as a subgraph attribute.
+static void _append_subgraph(
+    torch::jit::Node* node,
+    torch::jit::Graph* graph,
+    std::vector<torch::jit::Value*> trace_outputs,
+    bool unpack_output) {
+  node->g_(
+      torch::jit::attr::Subgraph,
+      std::make_shared<torch::jit::Graph>(graph->current_scope()));
+  auto subgraph = node->g(torch::jit::attr::Subgraph);
+
+  std::unordered_map<Value*, Value*> value_map;
+  auto value_map_func = [&](Value* v) { return value_map.at(v); };
+  for (size_t i = 0; i < node->inputs().size(); ++i) {
+    auto subgraph_input = subgraph->addInput();
+    subgraph_input->copyMetadata(node->inputs().at(i));
+    value_map[node->inputs().at(i)] = subgraph_input;
+  }
+  // Find the node's position in the graph; all subsequent nodes are added to
+  // the subgraph.
+  auto it = std::find(graph->nodes().begin(), graph->nodes().end(), node);
+  // Skip the TupleUnpack node if one was created.
+  if (!unpack_output) {
+    it++;
+  }
+  for (it++; it != graph->nodes().end(); ++it) {
+    torch::jit::Node* node = *it;
+    auto* clone_node =
+        subgraph->insertNode(subgraph->createClone(node, value_map_func));
+    for (size_t i = 0; i < node->outputs().size(); ++i) {
+      value_map[node->outputs()[i]] = clone_node->outputs()[i];
+      auto trace_it = std::find(
+          trace_outputs.begin(), trace_outputs.end(), node->outputs()[i]);
+      if (trace_it != trace_outputs.end()) {
+        subgraph->registerOutput(clone_node->outputs()[i]);
+      }
+    }
+  }
+}
+
 static torch::jit::Node* _trace_pre_record(
     PyObject* op_obj,
     PyObject* input_objects,
@@ -651,17 +697,26 @@ static void _trace_post_record(
     auto unpacked = graph->createTupleUnpack(node->output())->insertAfter(node);
     node = unpacked;
   }
+
+  std::vector<torch::jit::Value*> trace_outputs;
   for (const auto i : c10::irange(num_outputs)) {
     PyObject* obj = PyTuple_GET_ITEM(output_objects, i);
     if (THPVariable_Check(obj)) {
       Value* value = node->outputs()[i];
       const auto& tensor = THPVariable_Unpack(obj);
       if (tensor.defined()) {
         value->inferTypeFrom(tensor);
+        trace_outputs.push_back(jit::tracer::getValueTrace(tensor));
         jit::tracer::setValueTrace(tensor, value);
       }
     }
   }
+  py::bool_ is_in_onnx_export =
+      py::module::import("torch.onnx.__init__").attr("is_in_onnx_export");
+  if (py::cast<bool>(is_in_onnx_export)) {
+    _append_subgraph(old_node, graph, trace_outputs, unpack_output);
+  }
+
   // If TupleUnpack operator is created, we copy its output type back
   // to the original tuple type.
   if (!unpack_output) {
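For a Python-level view of the node this code decorates, here is a small sketch (illustrative class names, not part of this commit): tracing a module that calls an autograd.Function yields a `prim::PythonOp` node, printed with a `^` prefix; per the `is_in_onnx_export` check above, the Subgraph attribute itself is only attached on the ONNX export path.

```
import torch


class Exp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, i):
        ctx.save_for_backward(i)
        return i.exp()

    @staticmethod
    def backward(ctx, grad_output):
        (i,) = ctx.saved_tensors
        return grad_output * i.exp()


class M(torch.nn.Module):
    def forward(self, x):
        return Exp.apply(x)


traced = torch.jit.trace(M(), torch.ones(2))
# The printed graph contains a prim::PythonOp node rendered as ^Exp(...);
# with this change its attributes (including Subgraph, when present) are
# printed as well.
print(traced.graph)
```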

torch/csrc/jit/ir/ir.cpp

Lines changed: 1 addition & 0 deletions
@@ -318,6 +318,7 @@ std::ostream& Node::print(
   if (kind() == prim::PythonOp) {
     auto* pyOp = static_cast<const ::torch::jit::PythonOp*>(this);
     out << "^" << pyOp->name();
+    printAttributes(out, /*ignore_subgraph=*/false);
     pyOp->writeScalars(out);
   } else if (hasAttribute(attr::Subgraph) && groups) {
     out << kind().toQualString() << "_" << groups->size();

torch/csrc/jit/passes/onnx.cpp

Lines changed: 36 additions & 2 deletions
@@ -365,6 +365,21 @@ void NodeToONNX(
     }
   };
 
+  // Inline the prim::PythonOp sub-block nodes and append them to the onnx graph
+  auto inlineAutograd = [&](Node* PythonOpNode) {
+    for (auto subblock : PythonOpNode->blocks()) {
+      for (const auto i : c10::irange(PythonOpNode->inputs().size())) {
+        env[subblock->inputs()[i]] = env[PythonOpNode->inputs()[i]];
+      }
+      for (auto* node : subblock->nodes()) {
+        NodeToONNX(node, new_block, operator_export_type, env);
+      }
+      for (const auto i : c10::irange(PythonOpNode->outputs().size())) {
+        env[PythonOpNode->outputs()[i]] = env[subblock->outputs()[i]];
+      }
+    }
+  };
+
   // Cast output of symbolic() python implementation
   auto processSymbolicOutput = [&](const std::string& op_name,
                                    Node* n,
@@ -441,11 +456,30 @@
         "PythonOp", "prim", opset_version);
     if (!py::hasattr(pyobj, "symbolic") &&
         (!PyObject_IsTrue(is_registered_op.ptr()))) {
-      // Simply clone the node, unless either
+      // Inline the subgraph within the prim::PythonOp unless
+      // either of these conditions is satisfied:
       // 1. The torch.autograd.Function class of this node object has `symbolic`
       // method defined.
       // 2. Custom export symbolic is registered for prim::PythonOp.
-      cloneNode(op);
+      if (operator_export_type == ::torch::onnx::OperatorExportTypes::ONNX) {
+        try {
+          inlineAutograd(op);
+        } catch (const std::exception& ex) {
+          TORCH_WARN(
+              "Unable to inline PythonOp: ",
+              op->name(),
+              " due to the following exception\n",
+              ex.what(),
+              "prim::PythonOp will be exported as is and without being inlined\n",
+              "Try exporting with the following alternatives: \n",
+              "1) Set operator_export_type to ONNX_FALLTHROUGH mode\n",
+              "2) Register a symbolic method for the prim::PythonOp ",
+              op->name());
+          cloneNode(op);
+        }
+      } else {
+        cloneNode(op);
+      }
       return;
     }
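One way (again, not part of this commit; the `LogLog` and `M` classes are illustrative) to observe the effect of this pass end to end is to export a model containing an autograd.Function and list the op types in the resulting ONNX graph; with the default ONNX export type, the `prim::PythonOp` should have been replaced by its constituent operators.

```
import io

import onnx
import torch


class LogLog(torch.autograd.Function):
    @staticmethod
    def forward(ctx, i):
        ctx.save_for_backward(i)
        return i.log().log()

    @staticmethod
    def backward(ctx, grad_output):
        (i,) = ctx.saved_tensors
        # d/dx log(log(x)) = 1 / (x * log(x))
        return grad_output / (i * i.log())


class M(torch.nn.Module):
    def forward(self, x):
        return LogLog.apply(x)


buffer = io.BytesIO()
torch.onnx.export(M(), (torch.rand(2) + 2,), buffer)
model = onnx.load_model_from_string(buffer.getvalue())
# Expect ["Log", "Log"] rather than a single PythonOp-derived node.
print([node.op_type for node in model.graph.node])
```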

0 commit comments
