-
Notifications
You must be signed in to change notification settings - Fork 24.7k
[invoke_subgraph] Force the output stride to be same as eager #152806
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: gh/anijain2305/753/base
Are you sure you want to change the base?
Conversation
[ghstack-poisoned]
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/152806
Note: Links to docs will display an error until the docs builds have been completed. ❗ 1 Active SEVsThere are 1 currently active SEVs. If your PR is affected, please view them below: ❌ 1 New Failure, 3 Unrelated FailuresAs of commit ba8920b with merge base fdadda2 ( NEW FAILURE - The following job has failed:
BROKEN TRUNK - The following jobs failed but were present on the merge base:👉 Rebase onto the `viable/strict` branch to avoid these failures
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
…ger" cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov [ghstack-poisoned]
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lgtm but i'll wait for someone from inductor to review
…ger" cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov [ghstack-poisoned]
…ger" cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov [ghstack-poisoned]
…ger" cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov [ghstack-poisoned]
…ger" cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov [ghstack-poisoned]
example_stride = handle_sym_expr(fake_outputs[idx].stride()) | ||
new_outputs.append(cls.require_exact_strides(output, example_stride)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure this is right. Can Inductor passes change the fake_outputs in a way that they differ from eager?
If so we need to record the meta vals at the time of tracing, before passes run, and then use the metadata on them.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess this applies to inputs of the invoke subgraph then as well. Currently, we rely on the meta vals of the inputs of invoke subgraph, which could be different from eager because of graph passes
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you remind me why we want to force the inputs and output strides to be the same as eager? If we were not doing invoke_subgraph, inductor is allowed to change intermediates in the graph to have whatever strides it wants, with some exceptions.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is to reduce compile time. We compile a subgraph once and then call the same subgraph output code on second call. Since the input strides can be different for different subgraph calls, we restride the input to a fixed value at the beginning of each subgraph.
This allows us to reuse the output code of a subgraph. This is very important for compile time, otherwise the major benefits of invoke subgraph are not realized.
It is possible that the restriding is not to eager strides but to some strides after inductor graph passes are run. Nevertheless, it's a fixed and valid input strides.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We have some infrastructure to do this already (for inputs), check out
pytorch/torch/fx/experimental/proxy_tensor.py
Lines 1127 to 1134 in bc11afd
if _should_save_eager_input_vals(target, (args, kwargs)): | |
# NOTE "eager_input_vals" | |
# We save the original (args, kwargs) FakeTensor values for nodes | |
# that have exact stride requirements. This is useful downstream. | |
# We use this information inside Inductor to ensure that inputs to | |
# stride-sensitive operators have the correct strides. | |
arg_inp, kwarg_inp = torch.fx.node.map_aggregate((args, kwargs), map_fn) # type: ignore[misc, arg-type] | |
node.meta["eager_input_vals"] = (arg_inp, kwarg_inp) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yea - let's use the above mechanism
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I can use this for input
. Is there anything for the output
strides? The pointer is only for the inputs, but I also want to constrain the outputs.
@@ -7515,6 +7519,17 @@ def create_output(output: IRNode, ind: int): | |||
skip_size_stride_alignment_checks=True, | |||
) | |||
|
|||
# Force the output strides to be same as the original strides |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This needs a test at the very least. You can add an invoke_subgraph node, then do a graph pass that changes the outputs in the invoke_subgraph subgraph, and then check to make sure the strides are still what you expect.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, I was not able to get a test working.
I was looking at a regression when I wrap the whole model with the invoke subgraph. When I diffed the output code, I saw an extra kernel after the invoke subgraph call, even though there was no operation outside of the invoke subgraph call. So this PR was my attempt to make the stride of invoke subgraph same as eager output to avoid that extra kernel. This fixed the regression. But after your comment about passes changing meta vals, I am not sure if this is correct (or what should be the solution to avoid the extra kernel)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
let me know when you want review ! stride issues can be very tricky and it's worth pushing on this imo.
Definitely on my todo list .. just need time to understand the inductor codebase more to do this. |
Stack from ghstack (oldest at bottom):
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov