Making Mamba first-class citizen in PyTorch #120189

Open
yanboliang opened this issue Feb 19, 2024 · 5 comments
Labels
- module: higher order operators (torch.cond and similar)
- module: pt2-dispatcher (PT2 dispatcher-related issues, e.g. aotdispatch, functionalization, faketensor, custom-op)
- needs research (We need to decide whether or not this merits inclusion, based on research world)
- oncall: pt2
- triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

@yanboliang (Contributor) commented Feb 19, 2024

🚀 The feature, motivation and pitch

Mamba is a new SSM (State Space Model) architecture developed to address Transformers' computational inefficiency on long sequences. It has attracted increasing attention recently due to its faster inference and linear scaling in sequence length. We are exploring how to support Mamba as a first-class citizen in PyTorch.
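
For context, the core computation is a diagonal, input-dependent linear recurrence over the sequence. Below is a minimal sequential sketch in PyTorch; the names and shapes are purely illustrative, not the reference Mamba code.

```python
import torch

def diagonal_ssm_reference(a, bx, c):
    """Sequential reference for a diagonal state-space recurrence:
        h_t = a_t * h_{t-1} + (B_t x_t)   (elementwise, since A_t is diagonal)
        y_t = <c_t, h_t>
    This runs in O(T) dependent steps; the goal discussed in this issue is to
    express the same computation as a parallel associative scan that
    torch.compile can lower efficiently.
    """
    T, D = bx.shape
    h = bx.new_zeros(D)
    ys = []
    for t in range(T):
        h = a[t] * h + bx[t]          # state update
        ys.append((c[t] * h).sum())   # output projection
    return torch.stack(ys)

# Toy shapes; the real model batches this and makes a, B, C input-dependent.
T, D = 8, 4
y = diagonal_ssm_reference(torch.rand(T, D), torch.rand(T, D), torch.rand(T, D))
```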

To better understand the gaps and coordinate these ongoing efforts, we created the following doc to track the requested features and issues. Feel free to comment if you have any feedback!
https://docs.google.com/document/d/1rNNByFrOjOQOBM6ZZqnOqc-LGMRmdnQifKfbY_KalnM/edit?usp=sharing

cc @ezyang @msaroufim @bdhirsh @anijain2305 @zou3519 @Chillee @ydwu4 @peterbell10 @lezcano @aakhundov @chauhang

@lezcano (Collaborator) commented Feb 19, 2024

The document, as written, is a bit too optimistic.

The current tl.associative_scan (and as such @peterbell10's implementation in #119430) only supports pointwise accumulation functions, so we will only be able to implement SSMs with diagonal matrices, where matrix multiplication reduces to pointwise multiplication.

We should be able to do this once:

  1. [HOP][inductor] Add higher order associative scan operator #119430 lands
  2. We extend our scan operation to support multiple inputs and outputs (this shouldn't be too difficult).
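
Concretely, the reason diagonal matrices suffice for a pointwise scan is that each step is the elementwise affine map h -> a_t * h + b_t, and composing two such maps is again elementwise affine. A minimal sketch of that combine function in plain PyTorch follows (illustrative code, not the #119430 API); carrying the (a, b) pair through the scan is exactly what the multiple-input/output extension in point 2 is needed for.

```python
import torch

def combine(left, right):
    # Compose two elementwise-affine maps h -> a*h + b; applying `left` first
    # and then `right` gives h -> (a2*a1)*h + (a2*b1 + b2). This composition is
    # associative, so a parallel scan over (a_t, b_t) pairs recovers the
    # recurrence h_t = a_t * h_{t-1} + b_t.
    a1, b1 = left
    a2, b2 = right
    return a2 * a1, a2 * b1 + b2

# Check against the sequential recurrence with h_{-1} = 0.
T, D = 16, 4
a, b = torch.rand(T, D), torch.rand(T, D)

h, expected = torch.zeros(D), []
for t in range(T):
    h = a[t] * h + b[t]
    expected.append(h.clone())

acc, scanned = (a[0], b[0]), [b[0]]
for t in range(1, T):
    acc = combine(acc, (a[t], b[t]))
    scanned.append(acc[1])

assert torch.allclose(torch.stack(expected), torch.stack(scanned))
```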

@yanboliang (Contributor, Author) commented Feb 19, 2024

@lezcano Yes, we only discussed SSMs for diagonal matrices in the doc, but we should be able to extend to more general SSMs after this.

@ezyang added the triaged, needs research, and oncall: pt2 labels on Feb 20, 2024
@zou3519 added the module: higher order operators and module: pt2-dispatcher labels on Feb 22, 2024
Jokeren pushed a commit to triton-lang/triton that referenced this issue Feb 23, 2024
This PR implements a `reverse` argument for associative scan similar to
the jax implementation. While this can be implemented using the tl.flip
command, @Jokeren advised me that this would be very inefficient and
that this should be done in the associative scan itself.

The implementation can be summarized as `flip(scan(flip(x)))`. However
the flip needs to happen along three axes: warp, lanes, chunks. To flip
the chunks, I simply reverse the vector of values. To flip the lanes, I
use a butterfly shuffle to efficiently reverse the lanes. To flip the
warp (needed for the slow case) I flip the indexing of the warps
themselves.

I additionally modified the scan tests to include the new reverse
implementation.

## Why is this needed? 
This was needed originally for the implementation of the Mamba model
(https://srush.github.io/annotated-mamba/hard.html) to compute the
backward pass of the models. I thought pretty hard about whether this
could be done by any kind of recomputation, but it seems pretty
necessary to be able to do a reverse accumulation in order to take a dot
product in the kernel. (perhaps also relevant to
pytorch/pytorch#120189 )
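
The flip(scan(flip(x))) identity the commit describes can be sanity-checked in plain PyTorch; this is only a minimal sketch of the identity, while the PR's actual contribution is performing the flips inside the Triton scan (at the chunk/lane/warp level) rather than materializing flipped tensors.

```python
import torch

def reverse_cumsum_via_flip(x):
    # A reverse (right-to-left) scan expressed as flip(scan(flip(x))).
    # The Triton PR folds these flips into the scan itself instead of
    # flipping in memory, which is why using tl.flip alone would be slow.
    return torch.cumsum(x.flip(-1), dim=-1).flip(-1)

x = torch.arange(5.0)                            # [0, 1, 2, 3, 4]
expected = torch.tensor([10., 10., 9., 7., 4.])  # suffix sums of x
assert torch.equal(reverse_cumsum_via_flip(x), expected)
```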

binarman pushed a commit to binarman/triton that referenced this issue Apr 2, 2024 (same commit message as above)
@bhack (Contributor) commented Jun 11, 2024

I hope we could also include Mamba2 coverage: https://arxiv.org/abs/2405.21060

state-spaces/mamba#355

@bhack (Contributor) commented Jul 29, 2024

The associative scan reverse support was merged in Triton in April (triton-lang/triton#3177), and a fix for a minor bug is WIP at triton-lang/triton#4362.

What is the status on the PyTorch side?

@bhack (Contributor) commented Dec 14, 2024

/cc @bohnstingl @ydwu4

#95408 (comment)
