Parallel Associative Scan #95408
I think
Some more context: this method would be incredibly useful for training a class of modern recurrent networks based on linear state-space models, which have achieved state-of-the-art results on long-sequence prediction tasks, e.g. the Long Range Arena. More details are available in Appendix H of this paper, which used the JAX associative_scan method to train them originally.
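(For concreteness: the core recurrence in these models, h_t = a_t * h_{t-1} + b_t, can be expressed as an associative combine over pairs (a, b), which is exactly what an associative scan parallelizes. Below is a minimal PyTorch sketch of that operator with a sequential reference; the names are illustrative, not an existing API.)

```python
import torch

def combine(left, right):
    # Associative operator for h_t = a_t * h_{t-1} + b_t: composing two steps
    # (a1, b1) then (a2, b2) yields another step of the same form.
    a1, b1 = left
    a2, b2 = right
    return a2 * a1, a2 * b1 + b2

T, d = 6, 3
a, b = torch.rand(T, d), torch.randn(T, d)

# Sequential reference scan; a parallel associative scan evaluates the same
# combine in O(log T) depth because combine() is associative.
acc, hs = (a[0], b[0]), [b[0]]
for t in range(1, T):
    acc = combine(acc, (a[t], b[t]))
    hs.append(acc[1])  # the b-component of the running pair is h_t (zero initial state)
h = torch.stack(hs)

# Check against the direct recurrence
prev, h_ref = torch.zeros(d), []
for t in range(T):
    prev = a[t] * prev + b[t]
    h_ref.append(prev)
print(torch.allclose(h, torch.stack(h_ref)))  # True
```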
@abdulfatir @PeaBrane I think this is what was used in TF & JAX: https://github.com/eamartin/parallelizing_linear_rnns
This would indeed be very useful (e.g. https://arxiv.org/abs/2305.13048).
@abdulfatir
and the times are
Can the associative scan from Triton be realized in PyTorch with compilation?
Yeah, that's the plan @bohnstingl :)
Is there any ETA on when this will be available? I have some torch code that requires the associative scan, and I'm deciding whether to rewrite it in JAX or wait for a torch associative scan.
Not sure if this helps, but I ported the JAX associative_scan. It only needs the pytree things to be converted to the PyTorch internal ones to lose the JAX dependency.
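(If anyone wants to attempt that conversion, the mapping is roughly one-to-one; the sketch below uses torch.utils._pytree, which is a private module, so the names and argument order are an assumption that may shift between releases.)

```python
import torch
import torch.utils._pytree as pytree

# Rough correspondence (assumption, not an official mapping):
#   jax.tree_util.tree_flatten(tree)             -> pytree.tree_flatten(tree)
#   jax.tree_util.tree_unflatten(treedef, lvs)   -> pytree.tree_unflatten(leaves, spec)  (argument order differs)
#   jax.tree_util.tree_map(f, tree)              -> pytree.tree_map(f, tree)

tree = {"a": torch.arange(3), "b": (torch.ones(2), torch.zeros(1))}
leaves, spec = pytree.tree_flatten(tree)
doubled = pytree.tree_unflatten([2 * x for x in leaves], spec)
print(doubled["a"])  # tensor([0, 2, 4])
```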
@i404788 that's excellent! Do you have some benchmarks for your port?
@harpone I didn't benchmark it, but I've adapted @bohnstingl's script (torch.compile commented out because my GPU is too old 😅):
Output (n_timings=100, seq_len=1024, batch=1, dim=1):
Output (n_timings=100, seq_len=1024, batch=4, dim=1):
Someone with more VRAM should probably test it with different configs, since the time doesn't seem to change much between configs.
In case anyone is interested in an implementation of the Mamba selective scan without using a parallel scan, there is a way to do it with two cumsums. I made a fork of mamba_minimal and implemented my method in this commit. Based on the texts generated by the demo notebook, it seems to be functional. The core code is just this:

```python
def selective_scan(self, u, dt, A, B, C, D):
    # Discretization terms: dA = dt * A, dB_u = dt * u * B
    dA = torch.einsum('bld,dn->bldn', dt, A)
    dB_u = torch.einsum('bld,bld,bln->bldn', dt, u, B)
    # Suffix sums of dA (excluding the current step), exponentiated:
    # dA_cumsum[t] = exp(sum of dA over steps after t)
    dA_cumsum = F.pad(dA[:, 1:], (0, 0, 0, 0, 0, 1)).flip(1).cumsum(1).exp().flip(1)
    # Weight the inputs, take a prefix sum, then undo the weighting to recover
    # the linear recurrence h_t = exp(dA_t) * h_{t-1} + dB_u_t
    x = dB_u * dA_cumsum
    x = x.cumsum(1) / (dA_cumsum + 1e-12)
    y = torch.einsum('bldn,bln->bld', x, C)
    return y + u * D
```

Even though the implementation may not be optimal, it should be somewhat comparable to the original implementation, assuming that
Edit: I realized this is potentially just heisen_sequence in non-log space, which is perhaps also related to @Algomancer's approach.
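(For intuition about the trick above, here is a small self-contained check, on a single 1-D slice with made-up values, that the padded/flipped cumsum reproduces the sequential recurrence h_t = exp(dA_t) * h_{t-1} + dB_u_t; purely illustrative, not part of the fork.)

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
L = 8
dA = -torch.rand(L)       # negative, so exp(dA) < 1, as in Mamba's discretization
dB_u = torch.randn(L)

# Sequential reference: h_t = exp(dA_t) * h_{t-1} + dB_u_t, with h_{-1} = 0
h, ref = torch.zeros(()), []
for t in range(L):
    h = torch.exp(dA[t]) * h + dB_u[t]
    ref.append(h)
ref = torch.stack(ref)

# Two-cumsum version: dA_cumsum[t] = exp(sum of dA over steps after t)
dA_cumsum = F.pad(dA[1:], (0, 1)).flip(0).cumsum(0).exp().flip(0)
x = (dB_u * dA_cumsum).cumsum(0) / (dA_cumsum + 1e-12)
print(torch.allclose(x, ref, atol=1e-5))  # True
```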
@PeaBrane do you have sample benchmarking code?
@Chillee I just went into the mamba_minimal and mamba_tiny repos separately and ran the following script:

```python
import time

import numpy as np
import torch

from model import Mamba

pretrained_model_name = 'state-spaces/mamba-370m'
model = Mamba.from_pretrained(pretrained_model_name).cuda()
input = (torch.rand(1, 256) * 50000).long().cuda()

times = []
for i in range(100):
    start = time.time()
    output = model(input)
    if i < 10:  # skip the first iterations as warmup
        continue
    times.append(time.time() - start)

print(np.array(times).mean())
```

You can also wrap the model around
You are missing syncs there. To avoid these and other issues, consider using the benchmark suite within PyTorch: https://pytorch.org/tutorials/recipes/recipes/benchmark.html. On a different note: no, we do not currently allow fusions between scans and matmuls, or between matmuls. As a side note, I would consider using
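(For reference, a minimal sketch of the suggested approach with torch.utils.benchmark, which handles CUDA synchronization for you; the model and input here are placeholders, not the Mamba setup above.)

```python
import torch
import torch.utils.benchmark as benchmark

# Placeholder workload; substitute the real model and inputs.
model = torch.nn.Linear(256, 256).cuda()
x = torch.randn(8, 256, device="cuda")

timer = benchmark.Timer(
    stmt="model(x)",
    globals={"model": model, "x": x},
)
# timeit() takes care of CUDA synchronization, so the measurement reflects
# kernel execution time rather than just launch overhead.
print(timer.timeit(100))
```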
What heuristics? This doesn't seem ideal 🤔
The heuristic that may choose to dispatch between
The absence of this operation has prevented many state-of-the-art reinforcement learning memory models from being implemented efficiently in torch. Consequently, this has caused many users to migrate to JAX to achieve state-of-the-art performance. If you want to read more about this use case, please refer to pytorch/rl#2325
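(For context on that use case: many of those models, and GAE itself, reduce to a first-order linear recurrence, advantage_t = delta_t + gamma*lambda*(1 - done_t) * advantage_{t+1}, which is exactly the shape of problem an associative scan parallelizes. A sequential PyTorch sketch for illustration only, not TorchRL's implementation.)

```python
import torch

def gae_sequential(deltas, dones, gamma=0.99, lam=0.95):
    # Reverse-time linear recurrence:
    #   adv[t] = deltas[t] + gamma * lam * (1 - dones[t]) * adv[t + 1]
    # Being first-order and linear, it could be evaluated in O(log T) with a
    # parallel associative scan run in reverse; here it is the plain O(T) loop.
    adv = torch.zeros_like(deltas)
    running = torch.zeros_like(deltas[0])
    for t in reversed(range(deltas.shape[0])):
        running = deltas[t] + gamma * lam * (1.0 - dones[t]) * running
        adv[t] = running
    return adv

deltas, dones = torch.randn(128), torch.zeros(128)
print(gae_sequential(deltas, dones)[:3])
```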
@matteobettini Thank you for bringing this up. We have been working on a generic_scan version and it has progressed quite a bit. Would this be helpful for TorchRL as well? I am also wondering about the RNNs that you mentioned. There is also work to make torch.while available. Would this maybe also help? In that case, the RNN could be captured as the
Hey @bohnstingl, thanks for the answer. The generic scan looks promising. Ideally, what we need is a parallel associative scan that we are able to differentiate through and that has CUDA support. I am not sure if the operation of these models is "pure", but I imagine so. Regarding the while loop, I think that would definitely help with the implementation of any recurrent model as a for loop.
@matteobettini I think the parallel associative scan is merged in; it only supports pointwise combine functions, but that's not a problem for common models like S5, IIUC.
As far as I know it does not support autograd? Which makes it still useful for computing GAE or cumsums, but not in nn models. Please correct me if I am wrong.
Not yet, I believe. However, the PR that I mentioned will enable autograd.
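(For anyone who wants to try the merged prototype in a nightly build, usage looks roughly like the sketch below. This assumes the experimental entry point torch._higher_order_ops.associative_scan with a (combine_fn, xs, dim) calling convention; the API is still in flux, may require CUDA tensors and torch.compile, and the import path/signature should be checked against the current source rather than taken from here.)

```python
import torch
# Experimental, private API; location and signature may have changed.
from torch._higher_order_ops.associative_scan import associative_scan

def add(x, y):
    # Pointwise combine function: cumulative sum along the scanned dimension.
    return x + y

xs = torch.arange(8, dtype=torch.float32, device="cuda")
ys = associative_scan(add, xs, dim=0)  # expected: [0, 1, 3, 6, 10, 15, 21, 28]
print(ys)
```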
@bohnstingl Thanks for the update
Any news on this? I am interested in
Yes, together with @ydwu4 and others, we have made quite some progress on the
Thank you for the update @bohnstingl! I'm really looking forward to backward support reaching the main branch and a subsequent feature release in the future (fingers crossed for 2.7)! Great work, everyone working on this, as I feel like scan is one of the most important operations missing from torch right now!
@bohnstingl Yes, I am trying to eventually not use the original selective scan interface that it will require
@bohnstingl Do you have this code in a GitHub repo?
Hi @bhack,
Let me know what you think. In addition, what worries me a bit is the backward pass that I have currently implemented. It follows this blog post, but it may not be ideal. Any thoughts on this are highly appreciated.
🚀 The feature, motivation and pitch
It would be great to have a general parallel prefix sum (associative scan) operation in PyTorch, something like associative_scan in JAX or scan_associative in TensorFlow Probability. This operation is key for the parallelization of some algorithms in CRFs, filtering/smoothing in state space models, etc.
Alternatives
I found this implementation but it's only for computing the prefix sum and not for general associative binary operations. It would be great to have native support for arbitrary binary operators.
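(To make the request concrete, the desired semantics for an arbitrary associative operator can be written as a sequential reference; the function below is only an illustration of what a native, O(log n)-depth parallel version would compute, not a proposed implementation.)

```python
import torch

def associative_scan_reference(combine_fn, xs, dim=0):
    # Sequential reference: ys[i] = combine_fn(ys[i-1], xs[i]).
    # A native implementation could evaluate this in O(log n) parallel steps,
    # which is valid because combine_fn is associative.
    slices = list(torch.unbind(xs, dim=dim))
    ys = [slices[0]]
    for x in slices[1:]:
        ys.append(combine_fn(ys[-1], x))
    return torch.stack(ys, dim=dim)

# Works for any associative binary operator, not just addition:
xs = torch.tensor([1.0, 3.0, 2.0, 5.0])
print(associative_scan_reference(torch.maximum, xs))  # tensor([1., 3., 3., 5.])
print(associative_scan_reference(torch.add, xs))      # tensor([ 1.,  4.,  6., 11.])
```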
Additional context
No response
cc @ezyang @gchanan @zou3519 @soumith @msaroufim @wconstab @ngimel @bdhirsh @anijain2305 @Chillee @samdow @kshitij12345 @janeyx99