[FX] Update overview docstring (#50896) · pytorch/pytorch@5016637 · GitHub
Commit 5016637

James Reed authored and facebook-github-bot committed
[FX] Update overview docstring (#50896)
Summary: Pull Request resolved: #50896
Test Plan: Imported from OSS
Reviewed By: ansley
Differential Revision: D26002067
Pulled By: jamesr66a
fbshipit-source-id: 3b4d4b96017d16739a31f25a306f55b6f96324dc
1 parent eb0fe70 commit 5016637

File tree

1 file changed (+56 −59)


torch/fx/__init__.py

Lines changed: 56 additions & 59 deletions
@@ -2,82 +2,79 @@
 r'''
 **This feature is under a Beta release and its API may change.**
 
-FX is a toolkit for capturing and transforming functional PyTorch programs. It
-consists of GraphModule and a corresponding intermediate representation (IR). When GraphModule is constructed
-with an ``nn.Module`` instance as its argument, GraphModule will trace through the computation of that Module's
-``forward`` method symbolically and record those operations in the FX intermediate representation.
+FX is a toolkit for developers to use to transform ``nn.Module``
+instances. FX consists of three main components: a **symbolic tracer,**
+an **intermediate representation**, and **Python code generation**. A
+demonstration of these components in action:
 
-.. code-block:: python
+::
 
     import torch
-    import torch.fx
-
+    # Simple module for demonstration
     class MyModule(torch.nn.Module):
         def __init__(self):
             super().__init__()
             self.param = torch.nn.Parameter(torch.rand(3, 4))
             self.linear = torch.nn.Linear(4, 5)
 
-        def forward(self, x):
-            return torch.topk(torch.sum(self.linear(x + self.linear.weight).relu(), dim=-1), 3)
-
-    m = MyModule()
-    gm = torch.fx.symbolic_trace(m)
-
-The Intermediate Representation centers around a 5-opcode format::
+        def forward(self, x):
+            return self.linear(x + self.param).clamp(min=0.0, max=1.0)
 
-    print(gm.graph)
+    module = MyModule()
 
-.. code-block:: text
+    from torch.fx import symbolic_trace
+    # Symbolic tracing frontend - captures the semantics of the module
+    symbolic_traced : torch.fx.GraphModule = symbolic_trace(module)
 
+    # High-level intermediate representation (IR) - Graph representation
+    print(symbolic_traced.graph)
+    """
     graph(x):
-        %linear_weight : [#users=1] = self.linear.weight
-        %add_1 : [#users=1] = call_function[target=operator.add](args = (%x, %linear_weight), kwargs = {})
+        %param : [#users=1] = self.param
+        %add_1 : [#users=1] = call_function[target=<built-in function add>](args = (%x, %param), kwargs = {})
         %linear_1 : [#users=1] = call_module[target=linear](args = (%add_1,), kwargs = {})
-        %relu_1 : [#users=1] = call_method[target=relu](args = (%linear_1,), kwargs = {})
-        %sum_1 : [#users=1] = call_function[target=torch.sum](args = (%relu_1,), kwargs = {dim: -1})
-        %topk_1 : [#users=1] = call_function[target=torch.topk](args = (%sum_1, 3), kwargs = {})
-        return topk_1
-
-The Node semantics are as follows:
-
-- ``placeholder`` represents a function input. The ``name`` attribute specifies the name this value will take on.
-  ``target`` is similarly the name of the argument. ``args`` holds either: 1) nothing, or 2) a single argument
-  denoting the default parameter of the function input. ``kwargs`` is don't-care. Placeholders correspond to
-  the function parameters (e.g. ``x``) in the graph printout.
-- ``get_attr`` retrieves a parameter from the module hierarchy. ``name`` is similarly the name the result of the
-  fetch is assigned to. ``target`` is the fully-qualified name of the parameter's position in the module hierarchy.
-  ``args`` and ``kwargs`` are don't-care
-- ``call_function`` applies a free function to some values. ``name`` is similarly the name of the value to assign
-  to. ``target`` is the function to be applied. ``args`` and ``kwargs`` represent the arguments to the function,
-  following the Python calling convention
-- ``call_module`` applies a module in the module hierarchy's ``forward()`` method to given arguments. ``name`` is
-  as previous. ``target`` is the fully-qualified name of the module in the module hierarchy to call.
-  ``args`` and ``kwargs`` represent the arguments to invoke the module on, *including the self argument*.
-- ``call_method`` calls a method on a value. ``name`` is as similar. ``target`` is the string name of the method
-  to apply to the ``self`` argument. ``args`` and ``kwargs`` represent the arguments to invoke the module on,
-  *including the self argument*
-- ``output`` contains the output of the traced function in its ``args[0]`` attribute. This corresponds to the "return" statement
-  in the Graph printout.
+        %clamp_1 : [#users=1] = call_method[target=clamp](args = (%linear_1,), kwargs = {min: 0.0, max: 1.0})
+        return clamp_1
+    """
 
-GraphModule automatically generates Python code for the operations it symbolically observed::
-
-    print(gm.code)
-
-.. code-block:: python
-
-    import torch
+    # Code generation - valid Python code
+    print(symbolic_traced.code)
+    """
     def forward(self, x):
-        linear_weight = self.linear.weight
-        add_1 = x + linear_weight; x = linear_weight = None
+        param = self.param
+        add_1 = x + param; x = param = None
         linear_1 = self.linear(add_1); add_1 = None
-        relu_1 = linear_1.relu(); linear_1 = None
-        sum_1 = torch.sum(relu_1, dim = -1); relu_1 = None
-        topk_1 = torch.topk(sum_1, 3); sum_1 = None
-        return topk_1
-
-Because this code is valid PyTorch code, the resulting ``GraphModule`` can be used in any context another
-``nn.Module`` can be used, including in TorchScript tracing/compilation.
+        clamp_1 = linear_1.clamp(min = 0.0, max = 1.0); linear_1 = None
+        return clamp_1
+    """
+
+The **symbolic tracer** performs “abstract interpretation” of the Python
+code. It feeds fake values, called Proxies, through the code. Operations
+on these Proxies are recorded. More information about symbolic tracing
+can be found in the
+`symbolic\_trace <https://pytorch.org/docs/master/fx.html#torch.fx.symbolic_trace>`__
+and `Tracer <https://pytorch.org/docs/master/fx.html#torch.fx.Tracer>`__
+documentation.
+
+The **intermediate representation** is the container for the operations
+that were recorded during symbolic tracing. It consists of a list of
+Nodes that represent function inputs, callsites (to functions, methods,
+or ``nn.Module`` instances), and return values. More information about
+the IR can be found in the documentation for
+`Graph <https://pytorch.org/docs/master/fx.html#torch.fx.Graph>`__. The
+IR is the format on which transformations are applied.
+
+**Python code generation** is what makes FX a Python-to-Python (or
+Module-to-Module) transformation toolkit. For each Graph IR, we can
+create valid Python code matching the Graph’s semantics. This
+functionality is wrapped up in
+`GraphModule <https://pytorch.org/docs/master/fx.html#torch.fx.GraphModule>`__,
+which is an ``nn.Module`` instance that holds a ``Graph`` as well as a
+``forward`` method generated from the Graph.
+
+Taken together, this pipeline of components (symbolic tracing →
+intermediate representation → transforms → Python code generation)
+constitutes the Python-to-Python transformation pipeline of FX.
 '''
 
 from .graph_module import GraphModule
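
The new docstring explains that the symbolic tracer feeds Proxy values through ``forward`` and records each operation as a Node, while the removed bullet list enumerated the Node opcodes (``placeholder``, ``get_attr``, ``call_function``, ``call_module``, ``call_method``, ``output``). A minimal sketch of inspecting those recorded Nodes on the ``MyModule`` from the diff; the loop and its printed columns are illustrative and not part of the commit:

    import torch
    from torch.fx import symbolic_trace

    # Same toy module the updated docstring traces.
    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.param = torch.nn.Parameter(torch.rand(3, 4))
            self.linear = torch.nn.Linear(4, 5)

        def forward(self, x):
            return self.linear(x + self.param).clamp(min=0.0, max=1.0)

    gm = symbolic_trace(MyModule())

    # Each recorded Node carries an opcode, a target, and the args/kwargs it
    # was called with, the same fields the removed "Node semantics" list described.
    for node in gm.graph.nodes:
        print(f"{node.op:15} {str(node.target):45} {node.args}")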
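
The closing paragraphs describe the full pipeline (symbolic tracing → IR → transforms → code generation) but the docstring stops at printing the graph and code, and the removed text noted that the resulting ``GraphModule`` can be used wherever another ``nn.Module`` can, including TorchScript compilation. A minimal sketch of one round trip through that pipeline; the ``AddModule`` name and the add-to-multiply rewrite are illustrative, not part of the commit:

    import operator

    import torch
    from torch.fx import symbolic_trace

    class AddModule(torch.nn.Module):
        def forward(self, x, y):
            return (x + y).relu()

    gm = symbolic_trace(AddModule())

    # Transform: edit the Graph in place, rewriting every recorded add into a
    # multiply, then regenerate the Python forward() from the modified Graph.
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target is operator.add:
            node.target = operator.mul

    gm.graph.lint()   # check the edited Graph is still well-formed
    gm.recompile()    # code generation: gm.code / gm.forward now reflect the edit
    print(gm.code)

    # The result is still an ordinary nn.Module: callable directly and accepted
    # by TorchScript compilation, as the replaced docstring text pointed out.
    x, y = torch.rand(3, 4), torch.rand(3, 4)
    print(gm(x, y))
    scripted = torch.jit.script(gm)
    print(scripted(x, y))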
