[invoke_subgraph] Run missing graph passes recursively #152675
Changes from all commits
57f8556
d786c21
e7fa1e3
3639e98
1ef794a
3c8fe85
f046bd8
c495166
19bbfc4
5f878b2
b4ae64d
df2583f
29555ce
0aa67bc
36d6535
bde719c
814e68f
cbf8ffc
fcdf745
239dd4e
@@ -7,6 +7,7 @@
 from torch.fx._compatibility import compatibility
 from torch.fx.experimental.proxy_tensor import py_sym_types, snapshot_fake
 from torch.fx.node import map_aggregate
+from torch.utils._ordered_set import OrderedSet


 __all__ = ["FakeTensorProp"]

@@ -36,13 +37,41 @@ def __init__(
         self._mode = mode
         mode.epoch += 1
         mode.reset_nt_tensor_id_counter()
+        self.seen_subgraphs: OrderedSet[str] = OrderedSet()

     def run_node(self, n: Node):
         from torch.fx.experimental.symbolic_shapes import (
             compute_unbacked_bindings,
             rebind_unbacked,
         )

+        if (
+            n.op == "call_function"
+            and n.target is torch.ops.higher_order.invoke_subgraph
+            and n.args[1] not in self.seen_subgraphs
+        ):
+            # Prevent redundant fake tensor prop for invoke_subgraphs. Note that
+            # there is also fake tensor caching for the entire subgraph. This
+            # happens the next time we call `run_node` for the same subgraph,
+            # which goes through super().run_node and caches the fake tensor
+            # prop. Therefore, we propagate fake tensors through the subgraphs
+            # twice.
+            assert isinstance(n.args[1], str)
+            assert (
+                isinstance(n.args[0], torch.fx.Node)
+                and n.args[0].op == "get_attr"
+                and isinstance(n.args[0].target, str)
+            )
+            self.seen_subgraphs.add(n.args[1])
+            operands = n.args[2:]
+            example_inputs = []
+            for operand in operands:
+                assert isinstance(operand, torch.fx.Node) and "val" in operand.meta
+                example_inputs.append(operand.meta["val"])
+            return FakeTensorProp(
+                getattr(self.module, n.args[0].target), mode=self._mode
+            ).propagate(*example_inputs)
Review comment: This reminds me: Inductor also has a way of doing incremental fake tensor updates through FakeTensorUpdater.

Reply: I checked that one. Anything that runs during the post_grad passes will recurse into subgraphs, but view_to_reshape runs outside of post_grad. Therefore, we need this PR.

Reply: FakeTensorUpdater doesn't work on HOPs; we should fix that at some point.

         result = super().run_node(n)
         rebind_unbacked(self._mode.shape_env, n, result)
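For context (not part of the diff above), here is a minimal sketch of how FakeTensorProp is typically driven; the toy module `M`, the input shape, and the variable names are made up for illustration. With this change, if the propagated graph contained torch.ops.higher_order.invoke_subgraph nodes, propagation would also recurse into the referenced get_attr subgraph the first time each subgraph identifier is seen.

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.passes.fake_tensor_prop import FakeTensorProp


class M(torch.nn.Module):  # hypothetical toy module for illustration
    def forward(self, x):
        return torch.relu(x) + 1


gm = torch.fx.symbolic_trace(M())
mode = FakeTensorMode()

# propagate() converts the real example input to a fake tensor and runs the
# FX interpreter under the fake mode, populating node.meta["val"] on every
# node of the graph (and, with this PR, on the nodes of any invoke_subgraph
# subgraphs as well).
FakeTensorProp(gm, mode=mode).propagate(torch.randn(4, 8))

for node in gm.graph.nodes:
    print(node.name, node.meta.get("val"))
```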
Review comment: I would prefer some automated way of doing this, e.g. GraphTransformObserver accepts view_to_reshape and then automatically applies the pass to subgraphs. How many more times do we need to do this manually?
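To make that suggestion concrete, here is a hypothetical sketch (not an existing PyTorch API, and not what GraphTransformObserver currently does) of a helper that applies a given graph pass to the outer graph and recursively to every subgraph referenced by invoke_subgraph, mirroring the seen_subgraphs deduplication in the diff above:

```python
from typing import Callable

import torch
from torch.fx import GraphModule


def apply_pass_recursively(
    gm: GraphModule, graph_pass: Callable[[GraphModule], None]
) -> None:
    # Deduplicate on the subgraph identifier (args[1]), since several
    # invoke_subgraph calls may share the same subgraph.
    seen: set[str] = set()
    for node in gm.graph.nodes:
        if (
            node.op == "call_function"
            and node.target is torch.ops.higher_order.invoke_subgraph
            and isinstance(node.args[1], str)
            and node.args[1] not in seen
        ):
            seen.add(node.args[1])
            # args[0] is a get_attr node whose target names the subgraph module.
            subgraph = getattr(gm, node.args[0].target)
            apply_pass_recursively(subgraph, graph_pass)
    graph_pass(gm)
    gm.recompile()
```

Whether such a helper belongs in GraphTransformObserver or in a separate utility is exactly the open question raised in this thread.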