Improvements for associative_scan - slicing of xs by bohnstingl · Pull Request #138858 · pytorch/pytorch · GitHub

Improvements for associative_scan - slicing of xs #138858

Closed · wants to merge 12 commits

Conversation

@bohnstingl (Collaborator) commented Oct 24, 2024

@bohnstingl bohnstingl requested a review from zou3519 as a code owner October 24, 2024 22:21
pytorch-bot bot commented Oct 24, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/138858

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit 9e01fff with merge base fb36daa:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@ydwu4 (Contributor) left a comment:

Looks good overall.

  1. Can we clean up the comments in the test files so that all of them are up to date?
  2. The tests have a similar structure; can we come up with a _run_test method like what we did in test/inductor/test_control_flow.py? It's OK to make exceptions for non-standard tests, like those testing raised errors. This would make things much cleaner and easier to maintain and to extend with new tests.

    if len(leaves) != len(out_leaves):
        raise RuntimeError(
            "The number of leaves of the pytree of the output of the operator needs to match the length of the pytree of the input"
        )
-   if any(x.shape != shape for x in out_leaves):
+   if any(x.shape != sliced_shape for x in out_leaves):
Contributor:

We probably should also check strides, requires_grad, device, and dtype.

Collaborator Author:

Done, I added checks for strides, device, and dtype. The checks for requires_grad will be added in the PR that contains the autograd support.

        reverse=reverse,
    )

    self.assertEqual(result[1], expected_result[1])
Contributor:

What happens for result[0]?

Collaborator Author:

Good catch, I only checked result[1] in that version; in the new commit, all results are checked.
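
For instance, the check can iterate over every output instead of a single index (an illustrative sketch, not the actual test code):

    for r, e in zip(result, expected_result):
        self.assertEqual(r, e)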

@bohnstingl (Collaborator Author):

@pytorchbot label "topic: not user facing"

@pytorch-bot pytorch-bot bot added the topic: not user facing label Oct 25, 2024
@bohnstingl (Collaborator Author):

Thank you @ydwu4 for the review; I incorporated the comments here. In particular, the new test implementation with the _run_tests function resulted in quite a bit of code de-duplication, and the tests are indeed much cleaner now.

@bohnstingl bohnstingl requested a review from ydwu4 October 26, 2024 23:05
    if any(
        x.shape != x_sliced.shape
        or x.dtype != x_sliced.dtype
        or x.device != x_sliced.device
        or x.stride() != x_sliced.stride()
        for x, x_sliced in zip(out_leaves, sliced_leaves)
    ):
        raise RuntimeError(
            "The pytree of the output of the operator needs to match the xs pytree"
        )
Contributor:

The error message should make it clear what's not matched. A simple way would be to split this into 4 separate ifs.

Collaborator Author:

I replaced the generic RuntimeError with an error that provides details about the metadata of the tensors. In particular:

    raise RuntimeError(
        f"The metadata of the output of the operator needs to match the metadata of the xs pytree"
        f"\n  xs metadata             : {[(x.shape, x.dtype, x.device, x.stride()) for x in sliced_leaves]}"
        f"\n  operator output metadata: {[(x.shape, x.dtype, x.device, x.stride()) for x in out_leaves]}"
    )

    # Therefore, the parallelization is realized with vmap on `dim`
    combine_fn = functools.partial(
        wrap_combine_fn_flat,
        combine_fn=torch.vmap(combine_fn, dim, dim),
Contributor:

Btw, does this also vmap over the additional inputs? (We shouldn't, right?)

Collaborator Author (@bohnstingl) commented Oct 28, 2024:

No, we cannot. This is actually a problem that we need to tackle once additional inputs are supported. Do you have a better idea?

Contributor (@ydwu4) commented Oct 30, 2024:

We could set the in_dim of the additional_inputs to be None.
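
For illustration, a minimal sketch (my own, not the PR's code) of how torch.vmap can batch over the scanned operands while leaving an additional input un-mapped by setting its in_dims entry to None:

    import torch

    # Hypothetical combine function: `a` and `b` are scanned operands,
    # `w` is an additional input that must be shared, not vmapped.
    def combine_fn(a, b, w):
        return a * w + b

    a = torch.randn(8, 5)   # dim 0 is the dimension vmap maps over
    b = torch.randn(8, 5)
    w = torch.randn(5)      # additional input, identical for every mapped slice

    # in_dims=(0, 0, None): map over dim 0 of `a` and `b`, pass `w` through as-is.
    batched = torch.vmap(combine_fn, in_dims=(0, 0, None), out_dims=0)
    out = batched(a, b, w)  # shape (8, 5)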

Collaborator Author:

True, I will add a TODO for myself and keep it in mind for the autograd implementation.

    if combine_mode == "generic":
        # The generic_associative_scan implementation calls the combine_fn with batches along the scan dimension
Contributor:

I think the main difficulty for me is understanding how vmap interacts with the recursive call in generic_associative_scan. Maybe use 2-D or 3-D tensor inputs to illustrate how it works.

Collaborator Author:

I added an example for clarification.
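
The example added in the PR is not reproduced here; as a stand-in, below is a minimal sketch (my own reconstruction of the usual even/odd recursion scheme with the scan fixed to dim 0, not the PyTorch source) showing why combine_fn always receives batched slices: each call combines whole stacks of elements at once, and the recursion operates on the pairwise results.

    import torch

    def generic_scan(combine_fn, xs):
        # xs has shape (n, ...); scan along dim 0. combine_fn is always
        # called on stacked slices, i.e. it is effectively batched over
        # the leading dimension of its arguments.
        n = xs.shape[0]
        if n < 2:
            return xs
        # 1) Combine adjacent pairs in one batched call: f(x0, x1), f(x2, x3), ...
        reduced = combine_fn(xs[0:-1:2], xs[1::2])  # shape (n // 2, ...)
        # 2) Recurse on the pairwise results; scanned[i] holds the prefix
        #    over inputs 0 .. 2i+1.
        scanned = generic_scan(combine_fn, reduced)
        # 3) Odd output positions come straight from the recursion; even
        #    positions >= 2 need one more batched combine with the
        #    even-indexed inputs.
        out = torch.empty_like(xs)
        out[0] = xs[0]
        out[1::2] = scanned
        if n > 2:
            out[2::2] = combine_fn(scanned[: (n - 1) // 2], xs[2::2])
        return out

    # Cumulative sum of a 2-D tensor along dim 0 as a sanity check.
    xs = torch.arange(12.0).reshape(4, 3)
    assert torch.allclose(generic_scan(torch.add, xs), xs.cumsum(0))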

        torch._dynamo.reset()
        super().setUp()

    def _run_test(self, model, inputs, **kwargs):
Contributor:

Somehow I find the interface of _run_test difficult to understand. Specifically, I don't feel comfortable when copying existing tests.

One thing we could do is write down the kwargs explicitly: e.g., make combine_mode, dim, reverse, compile_mode, and combine_fn explicit kwargs. Or should they just be initialized in the model we're going to run?

We can remove the fake_combine_fn and make that test a standalone test, because I don't know when I should provide one.

Collaborator Author:

We discussed the interface change of _run_test offline, and I incorporated it accordingly. Let me know what you think.
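
For reference, a hedged sketch of what the reworked helper might look like (names and defaults are assumptions, not the merged test code): combine_fn, dim, and reverse are initialized in the model itself, the execution mode is an explicit kwarg, and every output leaf is compared.

    import torch
    import torch.utils._pytree as pytree

    def run_test(model, inputs, *, compile_mode="none", device=torch.device("cpu")):
        inputs = [x.to(device) for x in inputs]
        expected = model(*inputs)  # eager reference
        if compile_mode == "none":
            fn = model
        else:
            fn = torch.compile(
                model,
                backend="eager" if compile_mode == "eager" else "inductor",
                dynamic=(compile_mode == "compile_dynamic_shape"),
            )
        result = fn(*inputs)
        # Compare every output leaf, not just a single index.
        flat_result, _ = pytree.tree_flatten(result)
        flat_expected, _ = pytree.tree_flatten(expected)
        for r, e in zip(flat_result, flat_expected):
            torch.testing.assert_close(r, e)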

# @parametrize("compile_mode", ["none", "eager", "compile", "compile_dynamic_shape"])
# @parametrize("reverse", [False, True])
# @parametrize("device", [torch.device("cpu"), torch.device("cuda")])
@unittest.expectedFailure
Contributor:

Left a TODO here?

Collaborator Author:

There is an issue with using map inside the associative_scan. We discussed this offline and I left a TODO.

@ezyang ezyang added the triaged label Oct 28, 2024
@zou3519 zou3519 removed their request for review October 29, 2024 20:53
@bohnstingl bohnstingl requested a review from ydwu4 October 30, 2024 21:23
@ydwu4 (Contributor) left a comment:

Looking good! Left a few comments. Will wait for CI to pass.

@ydwu4 (Contributor) commented Nov 5, 2024:

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk label Nov 5, 2024
@pytorchmergebot (Collaborator):

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced debugging: check the merge workflow status here.

atalman pushed a commit to atalman/pytorch that referenced this pull request Nov 11, 2024
In this PR, the combine_fn is consistently called with a slice along the scan dim. It implements part of pytorch#136966

Pull Request resolved: pytorch#138858
Approved by: https://github.com/ydwu4
Ryo-not-rio pushed a commit to Ryo-not-rio/pytorch that referenced this pull request Dec 2, 2024
In this PR, the combine_fn is consistently called with a slice along the scan dim. It implements part of pytorch#136966

Pull Request resolved: pytorch#138858
Approved by: https://github.com/ydwu4
pobin6 pushed a commit to pobin6/pytorch that referenced this pull request Dec 5, 2024
In this PR, the combine_fn is consistently called with a slice along the scan dim. It implements part of pytorch#136966

Pull Request resolved: pytorch#138858
Approved by: https://github.com/ydwu4
Labels
ciflow/trunk · Merged · module: dynamo · module: inductor · open source · topic: not user facing · triaged
Projects: None yet

6 participants