Improvements for associative_scan - Autograd by bohnstingl · Pull Request #136966 · pytorch/pytorch · GitHub

Improvements for associative_scan - Autograd #136966


Closed
bohnstingl wants to merge 16 commits

Conversation

@bohnstingl (Collaborator) commented Sep 29, 2024

This is part of a series of PRs to improve the functionality of associative_scan. This specific PR implements Autograd support for associative_scan. It was derived from #129307.

@ydwu4

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @rec

@bohnstingl bohnstingl requested a review from zou3519 as a code owner September 29, 2024 22:42
pytorch-bot bot commented Sep 29, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/136966

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

❌ 15 New Failures

As of commit fc19798 with merge base a16476b:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@zou3519 requested review from ydwu4 and removed the review request for zou3519 September 30, 2024 19:48
@zou3519 added the triaged label (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module) Sep 30, 2024
@bohnstingl bohnstingl force-pushed the generic_associative_scan_6 branch from a9366e6 to 3ab2330 Compare October 2, 2024 08:37
@bohnstingl (Collaborator, Author)

@pytorchbot label "topic: not user facing"

@pytorch-bot bot added the topic: not user facing label Oct 17, 2024
@bohnstingl bohnstingl force-pushed the generic_associative_scan_6 branch from 15de9c4 to 8d669c1 Compare October 18, 2024 17:47
@bohnstingl bohnstingl force-pushed the generic_associative_scan_6 branch from 8d669c1 to fc19798 Compare October 22, 2024 23:11

results = [
-    torch.stack([e[leave_ind] for e in op(result_flat)], dim)
+    torch.concatenate([e[leave_ind] for e in op(result_flat)], dim)
Contributor:

Why is it a concatenate? If we associative_scan over a tensor of shape (4, 2, 3) along dim=0, each subgraph should work on a slice of shape (2, 3), and the end result should have shape (4, 2, 3). Is anything wrong with this interface? After the change, does the subgraph take (1, 2, 3), or does the result become (8, 3)?
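For reference, a minimal loop-based sketch of the shape semantics described above (illustrative only, not the HOP implementation; the add-style combine_fn and variable names are made up): scanning a (4, 2, 3) tensor along dim=0 combines (2, 3) slices step by step, and stacking the per-step results restores the (4, 2, 3) shape.

import torch

def combine_fn(a, b):
    # Each argument is a (2, 3) slice taken along the scan dim.
    return a + b

xs = torch.randn(4, 2, 3)
slices = list(xs.unbind(dim=0))         # four slices of shape (2, 3)
outs = [slices[0]]
for s in slices[1:]:
    outs.append(combine_fn(outs[-1], s))
result = torch.stack(outs, dim=0)       # stacking restores the scan dim
assert result.shape == (4, 2, 3)
torch.testing.assert_close(result, torch.cumsum(xs, dim=0))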

@ydwu4 (Contributor) commented Oct 23, 2024:

Oh, I feel this one should be deleted? _fake_scan is now in _higher_order_ops/scan.py.

from .builder import wrap_fx_proxy

args, kwargs = LazyVariableTracker.realize_all((args, kwargs))

-def arg_extractor(combine_fn, xs, dim):
-    return combine_fn, xs, dim
+def arg_extractor(combine_fn, xs):
Contributor:

I feel we should split this diff into two: first the dim change, then the autograd.

dim = utils.canonicalize_dim(ndim, dim)
# Move scan dim to 0 and always perform scan on dim 0
orig_scan_dim = dim
leaves = [shift_source_dim_to_target_dim(elem, int(dim), 0) for elem in leaves]
Contributor:

we might replace shift_source_dim_to_target_dim with torch.movedim(elem, int(dim), 0).
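A small sketch of the suggested replacement, assuming plain torch.movedim; shapes and names are illustrative only:

import torch

xs = torch.randn(2, 4, 3)
dim = 1  # user-specified scan dim

# Move the scan dim to the front so the scan itself always runs over dim 0.
front = torch.movedim(xs, dim, 0)        # shape (4, 2, 3)

# ... perform the scan over dim 0 here ...

# Afterwards, move the scan dim back to its original position.
restored = torch.movedim(front, 0, dim)  # shape (2, 4, 3)
assert restored.shape == xs.shape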

result_flat = [torch.flip(elem, [0]) for elem in result_flat]

result_flat = [
    shift_source_dim_to_target_dim(elem, 0, orig_scan_dim) for elem in result_flat
]
Contributor:

shift_source_dim_to_target_dim -> movedim


return pytree.tree_unflatten(result_flat, spec)


# TODO: Provide inductor support for generic scan
Contributor:

What's missing in inductor for generic scan? Is it the test failure we talked about?

return (*outs,)

@staticmethod
def backward(ctx, *flat_grads_unmasked):
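For readers unfamiliar with the pattern, the forward/backward pair above appears to follow the standard torch.autograd.Function structure. The sketch below is a generic, simplified illustration of that structure for a cumulative-product scan; it is not the PR's actual backward, and it assumes nonzero inputs:

import torch

class CumprodScan(torch.autograd.Function):
    # Illustrative only: cumulative product along dim 0 with a hand-written backward.

    @staticmethod
    def forward(ctx, xs):
        outs = torch.cumprod(xs, dim=0)
        ctx.save_for_backward(xs, outs)
        return outs

    @staticmethod
    def backward(ctx, grad_out):
        xs, outs = ctx.saved_tensors
        # grad_xs[i] = sum over j >= i of grad_out[j] * outs[j] / xs[i],
        # i.e. a reversed cumulative sum (the torch.flip trick seen in an earlier hunk).
        rev_cumsum = torch.flip(torch.cumsum(torch.flip(grad_out * outs, [0]), dim=0), [0])
        return rev_cumsum / xs

xs = (torch.randn(5, 3, dtype=torch.double).abs() + 0.1).requires_grad_(True)
torch.autograd.gradcheck(CumprodScan.apply, (xs,))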
Contributor:

I trust you on this.

I didn't look into the details of the backward implementation. Some general thoughts for testing this better: can we use scan to implement a baseline version first, and add many more tests to verify correctness? For example: nesting cond, scan, and associative_scan with autograd; more types of ops inside the body of associative_scan (e.g. different kinds of view ops); and non-contiguous inputs and outputs.
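A hedged sketch of the baseline idea suggested above: a simple loop-based reference scan whose outputs and gradients can be compared against the associative_scan implementation (the reference_scan helper and the mul combine_fn are hypothetical, for illustration only):

import torch

def reference_scan(combine_fn, xs, dim=0):
    # Loop-based baseline: combine slices one by one along `dim`.
    slices = list(xs.unbind(dim))
    outs = [slices[0]]
    for s in slices[1:]:
        outs.append(combine_fn(outs[-1], s))
    return torch.stack(outs, dim)

def mul(a, b):
    return a * b

xs = torch.randn(6, 4, dtype=torch.double, requires_grad=True)

# Sanity-check the baseline itself against a known scan (cumprod), in both
# forward values and gradients.
out = reference_scan(mul, xs)
(grad,) = torch.autograd.grad(out.sum(), xs)
ref = torch.cumprod(xs, dim=0)
(ref_grad,) = torch.autograd.grad(ref.sum(), xs)
torch.testing.assert_close(out, ref)
torch.testing.assert_close(grad, ref_grad)
# The same comparison could then be run against associative_scan's outputs
# and gradients for other combine_fns, including nested cond/scan bodies.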

pytorchmergebot pushed a commit that referenced this pull request Nov 5, 2024
In this PR, the combine_fn is consistently called with a slice along the scan dim. It implements part of #136966

Pull Request resolved: #138858
Approved by: https://github.com/ydwu4
@bohnstingl (Collaborator, Author)

Closing this PR, as it is split into several smaller PRs: #138858, #139864, #139939

@bohnstingl bohnstingl closed this Nov 6, 2024
atalman pushed a commit to atalman/pytorch that referenced this pull request Nov 11, 2024 (same commit message as above; Pull Request resolved: pytorch#138858)
Ryo-not-rio pushed a commit to Ryo-not-rio/pytorch that referenced this pull request Dec 2, 2024 (same commit message as above)
pobin6 pushed a commit to pobin6/pytorch that referenced this pull request Dec 5, 2024 (same commit message as above)
Labels
module: dynamo, module: inductor, open source, topic: not user facing, triaged
Projects
None yet
4 participants