Making Mamba a first-class citizen in PyTorch #120189
Comments
The document, as written, is a bit too optimistic.
@lezcano Yes, we only discussed SSMs with diagonal matrices in the doc, but we should be able to extend this to more general SSMs afterwards.
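For context, here is a minimal, hypothetical sketch (plain PyTorch, not taken from the linked doc) of why the diagonal-SSM recurrence fits an associative scan: each step is the affine map `h -> a*h + b`, and composing affine maps is associative, which is what allows a parallel scan.

```python
import torch

def combine(left, right):
    # Composing the affine maps h -> a_l*h + b_l and then h -> a_r*h + b_r
    # gives h -> (a_r*a_l)*h + (a_r*b_l + b_r); this operator is associative.
    a_l, b_l = left
    a_r, b_r = right
    return a_r * a_l, a_r * b_l + b_r

def diag_ssm_scan(a, b):
    # Sequential reference for h_t = a_t * h_{t-1} + b_t; a parallel kernel
    # would apply `combine` tree-wise via an associative scan primitive.
    acc = (torch.ones_like(a[0]), torch.zeros_like(b[0]))  # identity map
    out = []
    for t in range(a.shape[0]):
        acc = combine(acc, (a[t], b[t]))
        out.append(acc[1])
    return torch.stack(out)

a, b = torch.rand(8, 4), torch.rand(8, 4)  # per-step decay and input terms
h = diag_ssm_scan(a, b)                    # h[t] == a[t] * h[t-1] + b[t]
```

Extending beyond diagonal matrices would presumably replace the element-wise products in `combine` with matrix products, at which point the combine step becomes a matmul.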
This PR implements a `reverse` argument for associative scan, similar to the JAX implementation. While this could be implemented with the `tl.flip` command, @Jokeren advised me that this would be very inefficient and that it should be done in the associative scan itself.

The implementation can be summarized as `flip(scan(flip(x)))`. However, the flip needs to happen along three axes: warps, lanes, and chunks. To flip the chunks, I simply reverse the vector of values. To flip the lanes, I use a butterfly shuffle to reverse them efficiently. To flip the warps (needed for the slow case), I flip the indexing of the warps themselves. I also modified the scan tests to cover the new reverse implementation.

## Why is this needed?

This was originally needed for the implementation of the Mamba model (https://srush.github.io/annotated-mamba/hard.html) to compute the backward pass of the model. I thought hard about whether this could be avoided with some kind of recomputation, but a reverse accumulation seems necessary in order to take a dot product inside the kernel. (Perhaps also relevant to pytorch/pytorch#120189.)
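As a rough sanity check of the `flip(scan(flip(x)))` identity described above (plain PyTorch, not the Triton kernel from this PR), using cumulative sum as the associative operator:

```python
import torch

def reverse_cumsum(x, dim=0):
    # Reverse (suffix) scan expressed as flip -> scan -> flip; the kernel
    # performs the equivalent flips in-register across chunks, lanes, and warps
    # instead of materializing them.
    return torch.flip(torch.cumsum(torch.flip(x, dims=(dim,)), dim=dim), dims=(dim,))

x = torch.arange(1.0, 5.0)      # [1, 2, 3, 4]
print(torch.cumsum(x, dim=0))   # forward scan: [1, 3, 6, 10]
print(reverse_cumsum(x))        # reverse scan: [10, 9, 7, 4]

# The lane flip can use a butterfly (XOR) shuffle because, for a power-of-two
# width n, XOR-ing an index with n - 1 reverses it: i ^ (n - 1) == n - 1 - i.
n = 32
assert all(i ^ (n - 1) == n - 1 - i for i in range(n))
```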
I hope we can also include Mamba2 coverage: https://arxiv.org/abs/2405.21060
Associative scan was merged in Triton in April. What is the status in PyTorch?
🚀 The feature, motivation and pitch
Mamba is a new SSM (State Space Model) architecture developed to address Transformers’ computational inefficiency on long sequences. It has recently attracted increasing attention due to its faster inference and linear scaling in sequence length. We are exploring how to support Mamba as a first-class citizen in PyTorch.
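For intuition (a schematic with assumed shapes, not the actual Mamba kernel), the SSM recurrence visits each time step once and carries only a fixed-size hidden state, which is where the linear scaling in sequence length comes from:

```python
import torch

def ssm_forward(A, B, C, x):
    # Schematic single-channel SSM: h_t = A_t * h_{t-1} + B_t * x_t, y_t = <C_t, h_t>.
    # One pass over the sequence => O(L) time with an O(N) state, in contrast
    # to attention's O(L^2) pairwise interactions.
    L, N = A.shape                      # L: sequence length, N: state size
    h = torch.zeros(N)
    ys = []
    for t in range(L):
        h = A[t] * h + B[t] * x[t]      # diagonal A_t, input-dependent B_t
        ys.append(torch.dot(C[t], h))   # readout
    return torch.stack(ys)

L, N = 16, 8
y = ssm_forward(torch.rand(L, N), torch.rand(L, N), torch.rand(L, N), torch.rand(L))
```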
To better understand the gaps and coordinate these ongoing efforts, we created the following doc to track the requested features and issues. Feel free to comment if you have any feedback!
https://docs.google.com/document/d/1rNNByFrOjOQOBM6ZZqnOqc-LGMRmdnQifKfbY_KalnM/edit?usp=sharing
cc @ezyang @msaroufim @bdhirsh @anijain2305 @zou3519 @Chillee @ydwu4 @peterbell10 @lezcano @aakhundov @chauhang