Improvements for associative_scan - slicing of xs #138858
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/138858
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (1 Unrelated Failure)
As of commit 9e01fff with merge base fb36daa:
FLAKY - The following job failed but was likely due to flakiness present on trunk.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Looks good overall.
- Can we clean up the comments in the test files so that all of them are up to date?
- The tests have a similar structure; can we come up with a _run_test method like what we did in test/inductor/test_control_flow.py? It's OK to make exceptions for non-standard tests, such as those that check for raised errors. This would make things much cleaner and easier to maintain and to add new tests to.
 if len(leaves) != len(out_leaves):
     raise RuntimeError(
         "The number of leaves of the pytree of the output of the operator needs to match the length of the pytree of the input"
     )
-if any(x.shape != shape for x in out_leaves):
+if any(x.shape != sliced_shape for x in out_leaves):
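For readers following the diff: sliced_shape refers to the shape of an xs leaf sliced along the scan dimension, not the full input shape. A minimal sketch of the idea, assuming the slice drops the scan dimension via select (the PR's actual helper and variable names may differ):

import torch

# Hypothetical stand-ins for the PR's variables: one xs leaf and the scan dim.
xs_leaf = torch.randn(4, 8, 3)
dim = 1

# Each call of combine_fn sees a slice along `dim`, so the expected output
# shape is the input shape with the scan dimension removed.
sliced_shape = xs_leaf.select(dim, 0).shape
print(sliced_shape)  # torch.Size([4, 3])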
We probably should also check strides, requires_grad, device, and dtype.
Done, I added checks for strides, device, and dtype. The checks for requires_grad will be added in the PR that contains the Autograd implementation.
test/functorch/test_control_flow.py (Outdated)
reverse=reverse,
)

self.assertEqual(result[1], expected_result[1])
What happens for result[0]?
Good catch. I only checked result[1] in that version, but in the new commit all results are checked.
@pytorchbot label "topic: not user facing"
89aa61e to 59b164b (Compare)
Thank you @ydwu4 for the review. I addressed the comments here.
    or x.device != x_sliced.device
    or x.stride() != x_sliced.stride()
    for x, x_sliced in zip(out_leaves, sliced_leaves)
):
    raise RuntimeError(
        "The pytree of the output of the operator needs to match the xs pytree"
The error message should make it clear what's not matched. A simple way would be to split this into 4 separate ifs.
I replaced the generic RuntimeError with an error that provides details about the metadata of the tensors. In particular:
raise RuntimeError(
f"The metadata of the output of the operator needs to match the meta data of the xs pytree"
f"\n xs metadata : {[(x.shape, x.dtype, x.device, x.stride()) for x in sliced_leaves]}"
f"\n operator output metadata: {[(x.shape, x.dtype, x.device, x.stride()) for x in out_leaves]}"
)
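For comparison, the per-field variant the reviewer suggested could look roughly like the sketch below. This is only an illustration (the PR kept the single combined message above); the helper name is made up, and the variable names mirror the diff:

def _check_out_matches_xs(out_leaves, sliced_leaves):
    # Raise a dedicated error per mismatched property instead of one combined message.
    for x, x_sliced in zip(out_leaves, sliced_leaves):
        if x.shape != x_sliced.shape:
            raise RuntimeError(f"Output shape {x.shape} does not match xs slice shape {x_sliced.shape}")
        if x.dtype != x_sliced.dtype:
            raise RuntimeError(f"Output dtype {x.dtype} does not match xs dtype {x_sliced.dtype}")
        if x.device != x_sliced.device:
            raise RuntimeError(f"Output device {x.device} does not match xs device {x_sliced.device}")
        if x.stride() != x_sliced.stride():
            raise RuntimeError(f"Output strides {x.stride()} do not match xs strides {x_sliced.stride()}")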
# Therefore, the parallelization is realized with vmap on `dim`
combine_fn = functools.partial(
    wrap_combine_fn_flat,
    combine_fn=torch.vmap(combine_fn, dim, dim),
Btw, does this also vmap over the additional inputs? (We shouldn't, right?)
No, we cannot. This is actually a problem that we need to tackle once additional inputs are supported. Do you have a better idea?
We could set the in_dims entry for the additional_inputs to None.
True, I will add a TODO for myself and keep it in mind for the Autograd implementation.
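For illustration, a minimal sketch of that suggestion using torch.vmap's in_dims; the combine_fn and the additional input here are made-up placeholders, not the PR's code:

import torch

def combine_fn(x, y, scale):
    # `scale` plays the role of an additional input that must not be batched.
    return (x + y) * scale

xs_a = torch.randn(4, 8)
xs_b = torch.randn(4, 8)
scale = torch.tensor(2.0)

# in_dims=(0, 0, None): map over dim 0 of the scanned operands, but pass the
# additional input unchanged to every call of combine_fn.
batched = torch.vmap(combine_fn, in_dims=(0, 0, None), out_dims=0)
out = batched(xs_a, xs_b, scale)
print(out.shape)  # torch.Size([4, 8])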
if combine_mode == "generic":
    # The generic_associative_scan implementation calls the combine_fn with a batch along the scan dimension
I think the main difficulty for me is understanding how vmap interacts with the recursive call in generic_associative_scan. Maybe we can use 2-D or 3-D tensor inputs to illustrate how it works.
I added an example for clarification.
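The example added in the PR is not reproduced here, but a rough 2-D sketch of the mechanism is: generic_associative_scan hands the wrapped combine_fn a batch along the scan dimension, and torch.vmap turns that into one call of the user's combine_fn per slice.

import torch

def combine_fn(x, y):
    # The user-facing combine_fn only ever sees slices along the scan dim,
    # e.g. shape (8,) when the input has shape (batch, 8) and dim=0.
    assert x.dim() == 1
    return x + y

dim = 0
wrapped = torch.vmap(combine_fn, in_dims=dim, out_dims=dim)

# One recursive step of generic_associative_scan pairs up chunks along the
# scan dim; two chunks of shape (2, 8) stand in for such a step here.
lhs = torch.randn(2, 8)
rhs = torch.randn(2, 8)
out = wrapped(lhs, rhs)  # combine_fn is called once per slice of shape (8,)
print(out.shape)         # torch.Size([2, 8])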
test/functorch/test_control_flow.py (Outdated)
torch._dynamo.reset()
super().setUp()

def _run_test(self, model, inputs, **kwargs):
Somehow I feel the interface of _run_test is difficult to understand. Specifically, I don't feel comfortable copying existing tests.
One thing we could do is write down the kwargs explicitly: e.g., make combine_mode, dim, reverse, compile_mode, and combine_fn explicit kwargs. Or should they just be initialized in the model we're going to run?
We can also remove the fake_combine_fn and make that test a standalone test, because I don't know when I should provide one.
We discussed the interface change of _run_test offline, and I incorporated it accordingly. Let me know what you think.
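Not the final interface that landed, but a rough sketch of the explicit-kwargs direction discussed above might look like this. The names, defaults, and the assumption that the scan-specific options (combine_fn, dim, reverse, combine_mode) are initialized on the model itself are all illustrative:

import torch

# Sketch only: a TestCase helper with explicit kwargs instead of an opaque **kwargs.
def _run_test(self, model, inputs, *, compile_mode="none", device=torch.device("cpu")):
    inputs = [inp.to(device) for inp in inputs]
    expected = model(*inputs)  # eager reference result

    if compile_mode == "none":
        fn = model
    elif compile_mode == "eager":
        fn = torch.compile(model, backend="eager", fullgraph=True)
    elif compile_mode == "compile":
        fn = torch.compile(model, fullgraph=True)
    else:  # "compile_dynamic_shape"
        fn = torch.compile(model, fullgraph=True, dynamic=True)

    result = fn(*inputs)
    self.assertEqual(result, expected)
    return result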
# @parametrize("compile_mode", ["none", "eager", "compile", "compile_dynamic_shape"]) | ||
# @parametrize("reverse", [False, True]) | ||
# @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) | ||
@unittest.expectedFailure |
Left a TODO here?
There is an issue with using map inside associative_scan. We discussed this offline and I left a TODO.
Looking good! Left a few comments. Wait for CI to pass.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
In this PR, the combine_fn is consistently called with a slice along the scan dim. It implements part of pytorch#136966.
Pull Request resolved: pytorch#138858
Approved by: https://github.com/ydwu4
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @aakhundov @rec @ydwu4