Improvements for associative_scan - Autograd #136966
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/136966.
Note: Links to docs will display an error until the docs builds have completed.
❗ 1 Active SEV: there is currently 1 active SEV; if your PR is affected, please view it below.
❌ 15 New Failures as of commit fc19798 with merge base a16476b.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Force-pushed from a9366e6 to 3ab2330.
@pytorchbot label "topic: not user facing"
Force-pushed from 15de9c4 to 8d669c1.
Binary operator not working yet
*) Added more documentation
*) Created a function to compute the gradient for one leaf and vmapped it to compute all leaves in parallel
Force-pushed from 8d669c1 to fc19798.
results = [
-    torch.stack([e[leave_ind] for e in op(result_flat)], dim)
+    torch.concatenate([e[leave_ind] for e in op(result_flat)], dim)
Why is it a concatenate? If we associative_scan over a (4, 2, 3) tensor with dim=0, each subgraph should work on a slice of shape (2, 3), and the end result should be of shape (4, 2, 3). Is anything wrong with this interface? After the change, does the subgraph take (1, 2, 3), or does the result become (8, 3)?
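For reference, a minimal loop-based sketch of the shape semantics in question (the helper name reference_associative_scan and the use of torch.stack are illustrative assumptions, not the PR's implementation):

```python
import torch

def reference_associative_scan(combine_fn, x, dim=0):
    # Move the scan dim to the front; each combine_fn call then sees slices
    # of the remaining shape, e.g. (2, 3) for an input of shape (4, 2, 3).
    x = torch.movedim(x, dim, 0)
    outs = [x[0]]
    for i in range(1, x.shape[0]):
        outs.append(combine_fn(outs[-1], x[i]))
    # Stacking along a new dim keeps the scan dim: (4, 2, 3) in, (4, 2, 3) out.
    # Concatenating along an existing dim would instead yield e.g. (8, 3).
    return torch.movedim(torch.stack(outs, dim=0), 0, dim)

x = torch.arange(24, dtype=torch.float32).reshape(4, 2, 3)
out = reference_associative_scan(torch.add, x, dim=0)
print(out.shape)  # torch.Size([4, 2, 3])
```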
Oh, I feel this one should be deleted? _fake_scan is now in _higher_order_ops/scan.py.
from .builder import wrap_fx_proxy

args, kwargs = LazyVariableTracker.realize_all((args, kwargs))
-def arg_extractor(combine_fn, xs, dim):
-    return combine_fn, xs, dim
+def arg_extractor(combine_fn, xs):
I feel we should split this diff into two: first the dim change, then the autograd.
dim = utils.canonicalize_dim(ndim, dim)
# Move scan dim to 0 and always perform scan on dim 0
orig_scan_dim = dim
leaves = [shift_source_dim_to_target_dim(elem, int(dim), 0) for elem in leaves]
We might replace shift_source_dim_to_target_dim with torch.movedim(elem, int(dim), 0).
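A minimal sketch of the suggested replacement, assuming elem and dim as in the diff above (the scan call itself is elided):

```python
import torch

elem = torch.randn(4, 2, 3)
dim = 1

# torch.movedim is a built-in equivalent of the shift_source_dim_to_target_dim helper:
moved = torch.movedim(elem, int(dim), 0)  # bring the scan dim to the front
# ... perform the scan on dim 0 here ...
restored = torch.movedim(moved, 0, dim)   # move the scan dim back afterwards
assert torch.equal(restored, elem)
```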
result_flat = [torch.flip(elem, [0]) for elem in result_flat]

result_flat = [
    shift_source_dim_to_target_dim(elem, 0, orig_scan_dim) for elem in result_flat
]
shift_source_dim_to_target_dim -> movedim
return pytree.tree_unflatten(result_flat, spec)

# TODO: Provide inductor support for generic scan
What's missing in inductor for generic scan? Is it the test failure we talked about?
return (*outs,)

@staticmethod
def backward(ctx, *flat_grads_unmasked):
I trust you on this.
I didn't look into the details of the backward implementation. Some general thoughts for testing this better: can we use scan to implement a baseline version first, and add many more tests to verify correctness (e.g. nesting cond, scan, and associative_scan with autograd; more types of ops inside the body of associative_scan, such as different kinds of view ops, non-contiguous inputs and outputs)?
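A minimal sketch of the kind of baseline check suggested here, assuming torch.add as the combine_fn so that torch.cumsum (whose autograd is already well tested) can serve as the baseline; the function name baseline_cumsum_scan is illustrative, not part of the PR's test suite:

```python
import torch

def baseline_cumsum_scan(x, dim=0):
    # For combine_fn = torch.add, an associative scan reduces to a cumulative sum.
    return torch.cumsum(x, dim=dim)

# gradcheck needs double-precision inputs with requires_grad=True.
x = torch.randn(4, 2, 3, dtype=torch.float64, requires_grad=True)
assert torch.autograd.gradcheck(lambda inp: baseline_cumsum_scan(inp, dim=0), (x,))

# The same gradcheck call can then be pointed at associative_scan with
# combine_fn = torch.add to compare the two implementations.
```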
In this PR, the combine_fn is consistently called with a slice along the scan dim. It implements part of #136966. Pull Request resolved: #138858. Approved by: https://github.com/ydwu4
This is part of a series of PRs to improve the functionality of associative_scan. This specific PR implements Autograd support for associative_scan. It has been derived from #129307.
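A hedged usage sketch of what the Autograd support enables; the import path and signature of associative_scan are assumptions based on this PR's frontend and may differ across PyTorch versions (the call may also need to run under torch.compile depending on the build):

```python
import torch
from torch._higher_order_ops.associative_scan import associative_scan

def combine_fn(a, b):
    # Element-wise addition: the scan becomes a cumulative sum.
    return a + b

x = torch.randn(4, 2, 3, requires_grad=True)
out = associative_scan(combine_fn, x, dim=0)  # shape (4, 2, 3), same as x
out.sum().backward()                          # backward through the scan (this PR)
print(x.grad.shape)                           # torch.Size([4, 2, 3])
```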
@ydwu4
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @rec