r'''
**This feature is under a Beta release and its API may change.**

FX is a toolkit for developers to transform ``nn.Module``
instances. FX consists of three main components: a **symbolic tracer**,
an **intermediate representation**, and **Python code generation**. A
demonstration of these components in action:

::

    import torch
    # Simple module for demonstration
    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.param = torch.nn.Parameter(torch.rand(3, 4))
            self.linear = torch.nn.Linear(4, 5)

        def forward(self, x):
            return self.linear(x + self.param).clamp(min=0.0, max=1.0)

    module = MyModule()

    from torch.fx import symbolic_trace
    # Symbolic tracing frontend - captures the semantics of the module
    symbolic_traced : torch.fx.GraphModule = symbolic_trace(module)

    # High-level intermediate representation (IR) - Graph representation
    print(symbolic_traced.graph)
    """
    graph(x):
        %param : [#users=1] = self.param
        %add_1 : [#users=1] = call_function[target=<built-in function add>](args = (%x, %param), kwargs = {})
        %linear_1 : [#users=1] = call_module[target=linear](args = (%add_1,), kwargs = {})
        %clamp_1 : [#users=1] = call_method[target=clamp](args = (%linear_1,), kwargs = {min: 0.0, max: 1.0})
        return clamp_1
    """

    # Code generation - valid Python code
    print(symbolic_traced.code)
    """
    def forward(self, x):
        param = self.param
        add_1 = x + param; x = param = None
        linear_1 = self.linear(add_1); add_1 = None
        clamp_1 = linear_1.clamp(min = 0.0, max = 1.0); linear_1 = None
        return clamp_1
    """

The **symbolic tracer** performs "abstract interpretation" of the Python
code. It feeds fake values, called Proxies, through the code. Operations
on these Proxies are recorded. More information about symbolic tracing
can be found in the
`symbolic_trace <https://pytorch.org/docs/master/fx.html#torch.fx.symbolic_trace>`__
and `Tracer <https://pytorch.org/docs/master/fx.html#torch.fx.Tracer>`__
documentation.
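
As a minimal sketch of this idea (``double_relu`` below is a made-up
example function, not part of FX; ``symbolic_trace`` also accepts plain
callables, not just ``nn.Module`` instances):

::

    import torch
    import torch.fx

    def double_relu(x):
        return torch.relu(x) * 2.0

    # `x` is replaced by a Proxy during tracing; the relu and mul
    # operations performed on it are recorded as Nodes in a Graph
    traced = torch.fx.symbolic_trace(double_relu)
    print(traced.graph)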

The **intermediate representation** is the container for the operations
that were recorded during symbolic tracing. It consists of a list of
Nodes that represent function inputs, callsites (to functions, methods,
or ``nn.Module`` instances), and return values. More information about
the IR can be found in the documentation for
`Graph <https://pytorch.org/docs/master/fx.html#torch.fx.Graph>`__. The
IR is the format on which transformations are applied.
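
For instance, reusing ``symbolic_traced`` from the demonstration above
(a sketch, not an exhaustive tour of the Node API), each Node's opcode,
target, and arguments can be inspected directly:

::

    # Nodes are stored in execution order: placeholders (inputs),
    # call sites, then the output
    for node in symbolic_traced.graph.nodes:
        print(node.op, node.target, node.args)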

**Python code generation** is what makes FX a Python-to-Python (or
Module-to-Module) transformation toolkit. For each Graph IR, we can
create valid Python code matching the Graph's semantics. This
functionality is wrapped up in
`GraphModule <https://pytorch.org/docs/master/fx.html#torch.fx.GraphModule>`__,
which is an ``nn.Module`` instance that holds a ``Graph`` as well as a
``forward`` method generated from the Graph.
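
As a sketch of how this is typically used (reusing ``module`` and
``symbolic_traced`` from the demonstration above), a Graph can be
repackaged as a runnable Module at any time:

::

    import torch.fx

    # Wrap the (possibly edited) Graph back up into an nn.Module;
    # GraphModule generates a `forward` method from the Graph
    gm = torch.fx.GraphModule(module, symbolic_traced.graph)
    print(gm.code)              # the generated Python source
    out = gm(torch.rand(3, 4))  # runs like any other nn.Module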

Taken together, this pipeline of components (symbolic tracing →
intermediate representation → transforms → Python code generation)
constitutes the Python-to-Python transformation pipeline of FX.
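
As an end-to-end sketch of that pipeline (the add-to-multiply rewrite is
an arbitrary illustrative transform, not something FX prescribes):

::

    import operator
    import torch
    import torch.fx

    def transform(m: torch.nn.Module) -> torch.fx.GraphModule:
        # 1. Symbolic tracing: capture a Graph from the Module
        graph = torch.fx.Tracer().trace(m)
        # 2. Transform: rewrite every `add` call site into a `mul`
        for node in graph.nodes:
            if node.op == 'call_function' and node.target == operator.add:
                node.target = operator.mul
        graph.lint()  # check that the Graph is still well-formed
        # 3. Code generation: package the Graph as a runnable Module
        return torch.fx.GraphModule(m, graph)

    transformed = transform(MyModule())
    print(transformed.code)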
'''

from .graph_module import GraphModule