Parallel Associative Scan · Issue #95408 · pytorch/pytorch · GitHub

Open
abdulfatir opened this issue Feb 23, 2023 · 56 comments
Labels
feature A request for a proper, new feature. module: functorch Pertaining to torch.func or pytorch/functorch module: pt2-dispatcher PT2 dispatcher-related issues (e.g., aotdispatch, functionalization, faketensor, custom-op) oncall: pt2 triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@abdulfatir
abdulfatir commented Feb 23, 2023

🚀 The feature, motivation and pitch

It would be great to have a general parallel prefix sum (associative scan) operation in PyTorch, something like associative_scan in JAX or scan_associative in TensorFlow Probability. This operation is key for the parallelization of some algorithms in CRFs, filtering/smoothing in state space models, etc.
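For readers unfamiliar with the primitive: an associative scan takes an arbitrary associative binary operator and returns all prefix combinations of a sequence. Below is a minimal sequential reference of the intended semantics (my own sketch of the contract that jax.lax.associative_scan and scan_associative expose, not a proposed API):

import torch

def associative_scan_reference(op, xs, dim=0):
    # Sequential reference: out[i] = op(op(...op(xs[0], xs[1])...), xs[i]).
    # Because op is associative, a native implementation could evaluate this in O(log n) depth.
    elems = list(torch.unbind(xs, dim=dim))
    acc = elems[0]
    outs = [acc]
    for x in elems[1:]:
        acc = op(acc, x)
        outs.append(acc)
    return torch.stack(outs, dim=dim)

# With op = add this reduces to an ordinary prefix sum:
xs = torch.randn(8)
torch.testing.assert_close(associative_scan_reference(torch.add, xs), torch.cumsum(xs, 0))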

Alternatives

I found this implementation but it's only for computing the prefix sum and not for general associative binary operations. It would be great to have native support for arbitrary binary operators.

Additional context

No response

cc @ezyang @gchanan @zou3519 @soumith @msaroufim @wconstab @ngimel @bdhirsh @anijain2305 @Chillee @samdow @kshitij12345 @janeyx99

@vadimkantorov
Contributor
vadimkantorov commented Feb 23, 2023

@abdulfatir
Author

I think scan and associative_scan are two different things. associative_scan implements a general parallel prefix-sum-like algorithm.

@abdulfatir
Author

@PeaBrane

Some more context

This method would be incredibly useful for training a class of modern recurrent networks based on linear state-space models, which have achieved state-of-the-art results on long-sequence prediction tasks, e.g. the Long Range Arena.

More details are available in Appendix H of this paper, which originally used the JAX associative_scan method for training.
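For concreteness, these models reduce to a first-order linear recurrence x_t = A_t * x_{t-1} + b_t, which becomes an associative scan over (A, b) pairs. A hedged sketch of the standard combine operator in the diagonal/element-wise case (names are mine and the code is checked against a plain loop; it is not code from the paper):

import torch

def combine(left, right):
    # Associative operator for first-order linear recurrences:
    # (A_i, b_i) ∘ (A_j, b_j) = (A_j * A_i, A_j * b_i + b_j)
    A_i, b_i = left
    A_j, b_j = right
    return A_j * A_i, A_j * b_i + b_j

def scan_linear_recurrence(A, b):
    # Sequential left fold of the operator above; a parallel associative scan
    # would compute all prefixes of this fold in O(log T) depth.
    carry = (A[0], b[0])
    xs = [b[0]]
    for t in range(1, A.shape[0]):
        carry = combine(carry, (A[t], b[t]))
        xs.append(carry[1])
    return torch.stack(xs)

T, N = 16, 4
A, b = torch.rand(T, N), torch.randn(T, N)

# Plain recurrence for comparison: x_t = A_t * x_{t-1} + b_t, with x_0 = b_0
x, ref = b[0], [b[0]]
for t in range(1, T):
    x = A[t] * x + b[t]
    ref.append(x)
torch.testing.assert_close(scan_linear_recurrence(A, b), torch.stack(ref))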

@harpone
harpone commented Apr 23, 2023

@abdulfatir @PeaBrane I think this is what was used in TF & JAX: https://github.com/eamartin/parallelizing_linear_rnns

@hypnopump

this would indeed be very useful (ex. https://arxiv.org/abs/2305.13048)

@cpuhrsch cpuhrsch added triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module and removed triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module labels May 31, 2023
@janeyx99 janeyx99 added oncall: pt2 triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module and removed triage review labels Jun 5, 2023
@bohnstingl
Collaborator

@abdulfatir
Associative scan has been implemented in Triton; it works and is much quicker than the PT2 code that I could come up with.

import torch
import numpy as np
import time
import triton
import triton.language as tl
from triton.runtime.jit import TensorWrapper, reinterpret

int_dtypes = ['int8', 'int16', 'int32', 'int64']
uint_dtypes = ['uint8', 'uint16', 'uint32', 'uint64']
float_dtypes = ['float16', 'float32', 'float64']
dtypes = int_dtypes + uint_dtypes + float_dtypes
dtypes_with_bfloat16 = dtypes + ['bfloat16']
torch_dtypes = ['bool'] + int_dtypes + ['uint8'] + float_dtypes + ['bfloat16']
    
def to_triton(x: np.ndarray, device='cuda', dst_type=None):
    t = x.dtype.name
    if t in uint_dtypes:
        signed_type_name = t.lstrip('u')  # e.g. "uint16" -> "int16"
        x_signed = x.astype(getattr(np, signed_type_name))
        return reinterpret(torch.tensor(x_signed, device=device).contiguous(), getattr(tl, t))
    else:
        if dst_type and 'float8' in dst_type:
            return reinterpret(torch.tensor(x, device=device).contiguous(), getattr(tl, dst_type))
        if t == 'float32' and dst_type == 'bfloat16':
            return torch.tensor(x, device=device).contiguous().bfloat16()
        return torch.tensor(x, device=device).contiguous()
    
def to_numpy(x):
    if isinstance(x, TensorWrapper):
        return x.base.cpu().numpy().astype(getattr(np, torch_dtype_name(x.dtype)))
    elif isinstance(x, torch.Tensor):
        if x.dtype is torch.bfloat16:
            return x.cpu().float().numpy()
        return x.cpu().numpy()
    else:
        raise ValueError(f"Not a triton-compatible tensor: {x}")
    
if __name__ == "__main__":
    use_gpu = True

    if use_gpu:
        device = torch.device('cuda:0')
    else:
        device = None

    triton_times = []
    loop_times = []
    loop_comp_times = []
    vals_to_compare = []

    op = 'cumsum'
    num_warps = 16
    dtype_str = 'float32'
    axis = 0
    shape = (1024, 1)
    n_timings = 10

    x = np.random.rand(*shape).astype(dtype=np.float32)
    inp = torch.tensor(x, device=device, requires_grad=True, dtype=torch.float32)
    init = torch.zeros(shape[1], 1, device=device, requires_grad=True)
    inp_scan = inp

    @triton.jit
    def sum_op(a, b):
        return a + b

    @triton.jit
    def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, AXIS: tl.constexpr):
        range_m = tl.arange(0, BLOCK_M)
        range_n = tl.arange(0, BLOCK_N)
        x = tl.load(X + range_m[:, None] * BLOCK_N + range_n[None, :])
        #tl.device_print("z", x)
        z = tl.associative_scan(x, 0, sum_op)
        #tl.device_print("z", z)
        tl.store(Z + range_m[:, None] * BLOCK_N + range_n[None, :], z)

    z = np.empty_like(x)
    x_tri = to_triton(x, device=device)
    numpy_op = np.cumsum
    z_dtype_str = dtype_str
    z_ref = numpy_op(x, axis=axis).astype(getattr(np, z_dtype_str))
    # triton result
    z_tri = to_triton(z, device=device)
    val = kernel[(1,)](x_tri, z_tri, BLOCK_M=shape[0], BLOCK_N=shape[1], AXIS=axis, num_warps=num_warps)
    out_triton = to_numpy(z_tri)
    vals_to_compare.append(out_triton)

    for _ in range(n_timings):
        start = time.time_ns()
        kernel[(1,)](x_tri, z_tri, BLOCK_M=shape[0], BLOCK_N=shape[1], AXIS=axis, num_warps=num_warps)
        stop = time.time_ns()
        triton_times.append((stop - start) / (10 ** 9))

    def f(carry, x):
        return carry+x, carry+x

    def _fake_scan(f, init, x):
        zs = []
        carry = init
        for xp in x:
            carry, out = f(carry, xp)
            zs.append(out)
        return carry, torch.stack(zs)

    expected_carry_out, expected_ys = _fake_scan(f, init, inp_scan)
    out_loop = expected_ys[:, 0, :]
    vals_to_compare.append(out_loop.cpu().detach().numpy())

    for _ in range(n_timings):
        start = time.time_ns()
        expected_carry_out, expected_ys = _fake_scan(f, init, inp_scan)
        stop = time.time_ns()
        loop_times.append((stop - start) / (10 ** 9))

    _fake_scan_comp = torch.compile(_fake_scan, mode='reduce-overhead', fullgraph=True, dynamic=False)

    #Warm-up cycles
    for _ in range(5):
        expected_carry_out_comp, expected_ys_comp = _fake_scan_comp(f, init, inp_scan)

    expected_carry_out_comp, expected_ys_comp = _fake_scan_comp(f, init, inp_scan)
    out_loop_comp = expected_ys_comp[:, :, 0]
    vals_to_compare.append(out_loop_comp.cpu().detach().numpy())

    for _ in range(n_timings):
        start = time.time_ns()
        expected_carry_out_comp, expected_ys_comp = _fake_scan_comp(f, init, inp_scan)
        stop = time.time_ns()
        loop_comp_times.append((stop - start) / (10 ** 9))

    #Check all results for deviations
    for ind1, res1 in enumerate(vals_to_compare):
        for ind2, res2 in enumerate(vals_to_compare):
            if not np.allclose(res1, res2):
                print((ind1, res1))
                print((ind2, res2))
                raise Exception('Comparison of ' + str(ind1) + ' with ' + str(ind2) + ' failed!')

    print('Times regular loop ' + str(np.array(loop_times).mean()))
    print('Times compiled loop ' + str(np.array(loop_comp_times).mean()))
    print('Times triton ' + str(np.array(triton_times).mean()))
    print('Script ended')

and the times are

Times regular loop 0.0203269748
Times compiled loop 0.0073024944
Times triton 2.8002099999999997e-05
Script ended

Can the associative scan from triton be realized in PyTorch with compilation?

@Chillee
Collaborator
Chillee commented Jul 27, 2023

yeah that's the plan @bohnstingl :)

cc: @ezyang @zou3519 @peterbell10

@smorad
Contributor
smorad commented Aug 10, 2023

Is there any eta on when this will be available? I have some torch code that requires the associative scan, and I'm deciding whether to rewrite it in jax or wait for a torch associative scan.

@lucidrains

https://arxiv.org/abs/2311.01927

@i404788
i404788 commented Nov 7, 2023

Not sure if this helps but I ported the associative_scan algorithm from jax.lax to pytorch for my S5 impl:
https://github.com/i404788/s5-pytorch/blob/74e2fdae00b915a62c914bf3615c0b8a4279eb84/s5/jax_compat.py#L51-L134

It only needs the pytree utilities to be converted to PyTorch's internal ones to drop the JAX dependency.
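For anyone curious what such a port boils down to, here is a minimal recursive-doubling (Hillis–Steele style) sketch for a single tensor along dim 0. This is my own simplification; the linked port follows jax.lax's odd/even recursion and also handles pytrees of tensors:

import torch

def associative_scan_doubling(op, x, dim=0):
    # Hillis–Steele / recursive doubling: O(log n) steps of batched applications of op.
    n = x.shape[dim]
    x = x.movedim(dim, 0)
    shift = 1
    while shift < n:
        # Combine each element with the partial result `shift` positions earlier.
        combined = op(x[:-shift], x[shift:])
        x = torch.cat([x[:shift], combined], dim=0)
        shift *= 2
    return x.movedim(0, dim)

x = torch.randn(512, 3, dtype=torch.float64)
torch.testing.assert_close(associative_scan_doubling(torch.add, x), torch.cumsum(x, 0))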

@harpone
harpone commented Nov 7, 2023

@i404788 that's excellent! Do you have some benchmarks for your associative_scan vs a compiled loop? (with different size batch dimensions too preferably)

@smorad
Contributor
smorad commented Nov 7, 2023

https://arxiv.org/abs/2311.01927

There are a ton of models that need this. Basically any linear-complexity sequence model published in the past few years requires an associative scan. S5, FFM, etc. I gave up and reimplemented my whole library in jax because torch was missing this.

@i404788
i404788 commented Nov 7, 2023

@harpone I didn't benchmark it, but I've adapted @bohnstingl's script (torch.compile commented out because my GPU is too old 😅):

import torch
import numpy as np
import time
import triton
import triton.language as tl
from triton.runtime.jit import TensorWrapper, reinterpret
from s5.jax_compat import associative_scan

int_dtypes = ['int8', 'int16', 'int32', 'int64']
uint_dtypes = ['uint8', 'uint16', 'uint32', 'uint64']
float_dtypes = ['float16', 'float32', 'float64']
dtypes = int_dtypes + uint_dtypes + float_dtypes
dtypes_with_bfloat16 = dtypes + ['bfloat16']
torch_dtypes = ['bool'] + int_dtypes + ['uint8'] + float_dtypes + ['bfloat16']


def to_triton(x: np.ndarray, device='cuda', dst_type=None):
    t = x.dtype.name
    if t in uint_dtypes:
        signed_type_name = t.lstrip('u')  # e.g. "uint16" -> "int16"
        x_signed = x.astype(getattr(np, signed_type_name))
        return reinterpret(torch.tensor(x_signed, device=device).contiguous(), getattr(tl, t))
    else:
        if dst_type and 'float8' in dst_type:
            return reinterpret(torch.tensor(x, device=device).contiguous(), getattr(tl, dst_type))
        if t == 'float32' and dst_type == 'bfloat16':
            return torch.tensor(x, device=device).contiguous().bfloat16()
        return torch.tensor(x, device=device).contiguous()


def to_numpy(x):
    if isinstance(x, TensorWrapper):
        # FIXME: torch_dtype_name doesn't exist
        return x.base.cpu().numpy().astype(getattr(np, torch_dtype_name(x.dtype)))
    elif isinstance(x, torch.Tensor):
        if x.dtype is torch.bfloat16:
            return x.cpu().float().numpy()
        return x.cpu().numpy()
    else:
        raise ValueError(f"Not a triton-compatible tensor: {x}")


if __name__ == "__main__":
    use_gpu = True

    if use_gpu:
        device = torch.device('cuda:0')
    else:
        device = None

    triton_times = []
    loop_times = []
    loop_comp_times = []
    jax_compat_times = []

    print("Initializing")
    op = 'cumsum'
    num_warps = 16

    dim = 1
    seq_len = 1024
    batch = 1

    dtype_str = 'float32'
    axis = 0
    shape = (batch, seq_len, dim)
    n_timings = 100

    x = np.random.rand(*shape).astype(dtype=np.float32)
    inp = torch.tensor(x, device=device, requires_grad=True, dtype=torch.float32)
    init = torch.zeros(shape[1], 1, device=device, requires_grad=True)
    inp_scan = inp

    @triton.jit
    def sum_op(a, b):
        return a + b

    @triton.jit
    def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, AXIS: tl.constexpr):
        range_m = tl.arange(0, BLOCK_M)
        range_n = tl.arange(0, BLOCK_N)
        x = tl.load(X + range_m[:, None] * BLOCK_N + range_n[None, :])
        #tl.device_print("z", x)
        z = tl.associative_scan(x, 0, sum_op)
        #tl.device_print("z", z)
        tl.store(Z + range_m[:, None] * BLOCK_N + range_n[None, :], z)

    print("Triton")
    z = np.empty_like(x)
    x_tri = to_triton(x, device=device)
    numpy_op = np.cumsum
    z_dtype_str = dtype_str
    z_ref = numpy_op(x, axis=axis).astype(getattr(np, z_dtype_str))
    # triton result
    z_tri = to_triton(z, device=device)
    val = kernel[(1,)](x_tri, z_tri, BLOCK_M=shape[0], BLOCK_N=shape[1], AXIS=axis, num_warps=num_warps)
    out_triton = to_numpy(z_tri)

    for _ in range(n_timings):
        print('.', end='', flush=True)
        start = time.monotonic_ns()
        kernel[(1,)](x_tri, z_tri, BLOCK_M=shape[0], BLOCK_N=shape[1], AXIS=axis, num_warps=num_warps)
        stop = time.monotonic_ns()
        triton_times.append((stop - start) / (10 ** 9))

    print("\nFake scan")
    def f(carry, x):
        return carry+x, carry+x

    def _fake_scan(f, init, x):
        zs = []
        carry = init
        for xp in x:
            carry, out = f(carry, xp)
            zs.append(out)
        return carry, torch.stack(zs)

    expected_carry_out, expected_ys = _fake_scan(f, init, inp_scan)

    for _ in range(n_timings):
        print('.', end='', flush=True)
        start = time.monotonic_ns()
        expected_carry_out, expected_ys = _fake_scan(f, init, inp_scan)
        stop = time.monotonic_ns()
        loop_times.append((stop - start) / (10 ** 9))

    # _fake_scan_comp = torch.compile(_fake_scan, mode='reduce-overhead', fullgraph=True, dynamic=False)

    # # Warm-up cycles
    # print("\nFake scan-compiled")
    # for _ in range(5):
    #     expected_carry_out_comp, expected_ys_comp = _fake_scan_comp(f, init, inp_scan)

    # for _ in range(n_timings):
    #     print('.', end='', flush=True)
    #     start = time.monotonic_ns()
    #     expected_carry_out_comp, expected_ys_comp = _fake_scan_comp(f, init, inp_scan)
    #     stop = time.monotonic_ns()
    #     loop_comp_times.append((stop - start) / (10 ** 9))

    def sum_op2(a, b):
        return a+b, a + b

    # Warm-up
    print("\njax_compat")
    for _ in range(5):
        expected_ys_comp = associative_scan(sum_op2, inp_scan, axis=-1)

    for _ in range(n_timings):
        print('.', end='', flush=True)
        start = time.monotonic_ns()
        expected_ys_comp = associative_scan(sum_op2, inp_scan, axis=-1)
        stop = time.monotonic_ns()
        jax_compat_times.append((stop - start) / (10 ** 9))

    print()
    print('Times regular loop ' + str(np.array(loop_times).mean()))
    # print('Times compiled loop ' + str(np.array(loop_comp_times).mean()))
    print('Times triton ' + str(np.array(triton_times).mean()))
    print('Times jax_compat ' + str(np.array(jax_compat_times).mean()))
    print('Script ended')

Output (n_timings=100, seq_len=1024, batch=1, dim=1):

Initializing
Triton
....................................................................................................
Fake scan
....................................................................................................
jax_compat
....................................................................................................
Times regular loop 5.178356e-05
Times triton 3.0721640000000004e-05
Times jax_compat 5.173079999999998e-06
Script ended

Output (n_timings=100, seq_len=1024, batch=4, dim=1):

Initializing
Triton
....................................................................................................
Fake scan
....................................................................................................
jax_compat
....................................................................................................
Times regular loop 0.00010100867000000001
Times triton 3.103996e-05
Times jax_compat 5.17184e-06
Script ended

Someone with more VRAM should probably test it with different configs, since the time doesn't seem to change much between configs.

@PeaBrane
PeaBrane commented Feb 9, 2024

In case anyone is interested in an implementation of the Mamba selective scan without using a parallel scan, there is a way to do it with two cumsums. I made a fork from mamba_minimal and implemented my method in this commit. Based on the texts generated by the demo notebook, it seems to be functional.

The core code is just this:

def selective_scan(self, u, dt, A, B, C, D):
    dA = torch.einsum('bld,dn->bldn', dt, A)
    dB_u = torch.einsum('bld,bld,bln->bldn', dt, u, B)

    dA_cumsum = F.pad(dA[:, 1:], (0, 0, 0, 0, 0, 1)).flip(1).cumsum(1).exp().flip(1)
    x = dB_u * dA_cumsum
    x = x.cumsum(1) / (dA_cumsum + 1e-12)
    y = torch.einsum('bldn,bln->bld', x, C)

    return y + u * D
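For readers puzzled by why two cumsums suffice, this is my reading of the algebra behind the snippet (not the author's derivation). Writing $\bar{A}_t = e^{dA_t}$ and letting $\overline{Bu}_t$ denote dB_u, the recurrence $x_t = \bar{A}_t x_{t-1} + \overline{Bu}_t$ unrolls to

$$
x_t \;=\; \sum_{s \le t} \Big(\prod_{r=s+1}^{t} \bar{A}_r\Big)\, \overline{Bu}_s
\;=\; \frac{1}{S_t}\sum_{s \le t} S_s\, \overline{Bu}_s,
\qquad
S_t \;:=\; \prod_{r=t+1}^{L} \bar{A}_r \;=\; \exp\!\Big(\sum_{r=t+1}^{L} dA_r\Big),
$$

since $S_s / S_t = \prod_{r=s+1}^{t} \bar{A}_r$ for $s \le t$. One shifted, reversed cumsum of dA builds $S_t$ (the dA_cumsum tensor above), and one cumsum of $S_s\,\overline{Bu}_s$ builds the numerator. The division by $S_t$ is what the 1e-12 guards against, and it is also why this formulation is numerically more fragile than a log-space scan.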

Even though the implementation may not be optimal, it should be somewhat comparable to the original implementation, assuming that torch.compile can do kernel fusion properly. Here are the inference times for a (1, 256) input with mamba-370m architecture on an A30:

mamba_minimal: 1.293 s
mamba_cumsum: 0.0934 s
mamba_cumsum_compiled: 0.0903 s

Edit: I realized this is potentially just heisen_sequence in non-log space, which is perhaps also related to @Algomancer's approach.

@Chillee
Collaborator
Chillee commented Feb 9, 2024

@PeaBrane do you have sample benchmarking code?

@PeaBrane
PeaBrane commented Feb 9, 2024

@Chillee I just went into the mamba_minimal and mamba_tiny repos separately and ran the following script

import time

import numpy as np
import torch

from model import Mamba


pretrained_model_name = 'state-spaces/mamba-370m'
model = Mamba.from_pretrained(pretrained_model_name).cuda()
input = (torch.rand(1, 256) * 50000).long().cuda()

times = []
for i in range(100):
    start = time.time()
    output = model(input)
    if i < 10:
        continue
    times.append(time.time() - start)
    
print(np.array(times).mean())

You can also wrap the model in torch.compile for additional testing.
I did not have time to benchmark the original Mamba implementation, but probably should have.

@lezcano
Collaborator
lezcano commented Feb 9, 2024

You are missing syncs there. To avoid these and other issues, consider using the benchmark suite within PyTorch https://pytorch.org/tutorials/recipes/recipes/benchmark.html
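For reference, a minimal sketch of that suggestion using torch.utils.benchmark, reusing the model and input from the script above; the Timer inserts the necessary CUDA synchronizations for you:

import torch
import torch.utils.benchmark as benchmark

from model import Mamba  # same assumption as the script above

model = Mamba.from_pretrained('state-spaces/mamba-370m').cuda()
inp = (torch.rand(1, 256) * 50000).long().cuda()

timer = benchmark.Timer(
    stmt="model(inp)",
    globals={"model": model, "inp": inp},
)
# Runs the statement repeatedly and reports median/IQR, with proper CUDA syncs.
print(timer.blocked_autorange(min_run_time=2.0))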

On a different note, no, we do not currently allow fusions between scans and matmuls or between matmuls.

As a side note, I would consider using torch.matmul over einsum as it implements a number of extra heuristics. Also, if you are going to benchmark things that involve matmuls, you may want to pass max-autotune as the mode for torch.compile.

@Chillee
Collaborator
Chillee commented Feb 9, 2024

As a side note, I would consider using torch.matmul over einsum as it implements a number of extra heuristics.

What heuristics? This doesn't seem ideal 🤔

@lezcano
Collaborator
lezcano commented Feb 10, 2024

The heuristic that may choose to dispatch between mm and bmm for tensors of rank 3 and 2. But perhaps these are not hit within einsum?
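For concreteness, the rank-3 × rank-2 case in question; both spellings compute the same result, and the open question is only which mm/bmm path each ends up on:

import torch

a = torch.randn(8, 128, 64)   # rank 3
b = torch.randn(64, 32)       # rank 2

out_matmul = torch.matmul(a, b)                   # may be folded into a single mm on a (8*128, 64) view
out_einsum = torch.einsum('bij,jk->bik', a, b)
torch.testing.assert_close(out_matmul, out_einsum)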

@bohnstingl
Collaborator

@Chillee @lezcano, what is the status of this feature request? Is there anything I can help with? I would be very much interested in bringing the generic associative scan operation from Triton into PyTorch.

@matteobettini

The absence of this operation has prevented many state-of-the-art reinforcement learning memory models from being implemented efficiently in torch.

Consequently, this has caused many users to migrate to JAX to achieve state-of-the-art performance.

If you want to read more about this use case, please refer to pytorch/rl#2325

@bohnstingl
Collaborator

@matteobettini Thank you for bringing this up. We have been working on a generic_scan version and it has progressed quite a bit. Would this be helpful for TorchRL as well?

I am also wondering about the RNNs that you mentioned. There is also work to make torch.while available. Would this maybe also help? In that case, the RNN could be captured as the body of the loop and the loop would handle the variable time lengths.

@matteobettini

hey @bohnstingl, thanks for the answer.

The generic scan looks promising. Ideally, what we need is a parallel associative scan that we are able to differentiate through and that has CUDA support. I am not sure if the operation of these models is "pure", but I imagine so.

Regarding the while loop, I think that would definitely help with the implementation of any recurrent model as a for loop.
This could provide benefits for models that are not expressible as an associative scan (or even the ones that are, while waiting for the scan)

@vadimkantorov
Contributor
vadimkantorov commented Jul 29, 2024

@matteobettini I think the parallel associative scan is merged in; it only supports pointwise cells, but that's not a problem for common models like S5, IIUC.

@matteobettini
matteobettini commented Jul 29, 2024

@matteobettini I think the parallel associative scan is merged in; it only supports pointwise cells, but that's not a problem for common models like S5, IIUC.

As far as I know it does not support autograd?

Which makes it still useful for computing GAE or cumsums but not in nn models

Please correct me if I am wrong

@bohnstingl
Collaborator

Not yet, I believe. However, the PR that I mentioned will enable autograd.

@bohnstingl
Collaborator

With the great help of @ydwu4 and @Chillee, support for non-pointwise functions in associative_scan has landed via this PR. However, it does not yet support autograd; I am preparing an implementation of that as we speak.

@grizzlybearg

@bohnstingl Thanks for the update

@bhack
Contributor
bhack commented Dec 14, 2024

Any news on this? I am interested in selective_scan.

@bohnstingl
Collaborator

Yes, together with @ydwu4 and others we have made quite some progress on the associative_scan operation.
The code for the function in the latest main branch is here. There are two combine_modes, pointwise and generic. As the name suggests, pointwise should be used for pointwise functions and generic for any other associative functions.
Is there any specific model that you are looking for? If you are referring to the selective_scan operation that is used in Mamba for example, this can be implemented and I have a small code snippet for that already.

associative_scan currently supports only the forward mode, but autograd is already implemented and is awaiting merge.
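For anyone who wants to try it, here is a minimal usage sketch based on the call signature that appears later in this thread. The API lives under torch._higher_order_ops, is still a prototype, and may change between versions; I am assuming a CUDA device and a call under torch.compile, as in the snippets in this thread:

import torch
from torch._higher_order_ops.associative_scan import associative_scan

def add(x, y):
    return x + y

@torch.compile(fullgraph=True)
def prefix_sum(xs):
    # Scan along dim 1; "pointwise" is meant for element-wise combine functions like add.
    return associative_scan(add, xs, 1, combine_mode="pointwise")

xs = torch.randn(4, 1024, device="cuda")
out = prefix_sum(xs)
print((out - torch.cumsum(xs, 1)).abs().max())  # ~0, up to floating-point accumulation order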

@Avelina9X

Thank you for the update @bohnstingl! I'm really looking forward to the backward pass reaching the main branch and a subsequent feature release (fingers crossed for 2.7)! Great work to everyone working on this; I feel like scan is one of the most important operations missing from torch right now!

@bhack
Contributor
bhack commented Dec 14, 2024

@bohnstingl Yes, I am trying to eventually avoid the original selective scan interface, since it would require custom_ops for torch compile/export. See #130150 (comment)

@bhack
Contributor
bhack commented Jan 26, 2025

If you are referring to the selective_scan operation that is used in Mamba for example, this can be implemented and I have a small code snippet for that already.

@bohnstingl Do you have this code in a GitHub repo?

@bohnstingl
Collaborator

Hi @bhack,
So, I used this repo here from the awesome group of @tridao as a starting point and modified the selective_scan_ref function according to the snippet below:

# Imports needed by this snippet
import torch
import torch.nn.functional as F
from einops import rearrange, repeat


def selective_scan_torch(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
                         return_last_state=False):
    """
    u: r(B D L)
    delta: r(B D L)
    A: c(D N) or r(D N)
    B: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
    C: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
    D: r(D)
    z: r(B D L)
    delta_bias: r(D), fp32

    out: r(B D L)
    last_state (optional): r(B D dstate) or c(B D dstate)
    """
    from torch._higher_order_ops.associative_scan import (
        associative_scan,
    )
    
    def s5_operator(x, y):
        A_i, Bu_i = x
        A_j, Bu_j = y
        return A_j * A_i, A_j * Bu_i + Bu_j
    
    use_associative_scan = True
    
    def _selective_scan_torch(u, delta, A, B, C, D, z, delta_bias, delta_softplus):
        dtype_in = u.dtype
        u = u.float()
        delta = delta.float()
        if delta_bias is not None:
            delta = delta + delta_bias[..., None].float()
        if delta_softplus:
            delta = F.softplus(delta)
        batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1]
        is_variable_B = B.dim() >= 3
        is_variable_C = C.dim() >= 3
        if A.is_complex():
            if is_variable_B:
                B = torch.view_as_complex(rearrange(B.float(), "... (L two) -> ... L two", two=2))
            if is_variable_C:
                C = torch.view_as_complex(rearrange(C.float(), "... (L two) -> ... L two", two=2))
        else:
            B = B.float()
            C = C.float()

        deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A))
        if not is_variable_B:
            deltaB_u = torch.einsum('bdl,dn,bdl->bdln', delta, B, u)
        else:
            if B.dim() == 3:
                deltaB_u = torch.einsum('bdl,bnl,bdl->bdln', delta, B, u)
            else:
                B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1])
                deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u)
        if is_variable_C and C.dim() == 4:
            C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1])
        last_state_scan = None
        
        if use_associative_scan:
            _, x_scan = associative_scan(s5_operator, (deltaA, deltaB_u), 2, combine_mode=combine_mode)
            if not is_variable_C:
                raise NotImplementedError('This feature is not yet implemented!')
                y = torch.einsum('bdn,dn->bd', x_scan, C)
            else:
                if C.dim() == 3:
                    y_scan = torch.einsum('bdsn,bns->bds', x_scan, C)
                else:
                    raise NotImplementedError('This feature is not yet implemented!')
                    y = torch.einsum('bdns,bdns->bds', x_scan, C)
            last_state_scan = x_scan[:, :, -1, :]
            if y_scan.is_complex():
                y_scan = y_scan.real * 2
        else:
            pass
        
        out = y_scan if D is None else y_scan + u * rearrange(D, "d -> d 1")
        if z is not None:
            out = out * F.silu(z)
        out = out.to(dtype=dtype_in)
        
        return out, last_state_scan
    
    comment = '_compile'
    # combine_mode = 'generic'
    combine_mode = 'pointwise'
    
    if 'compile' in comment:
        _selective_scan_torch_cmp = torch.compile(_selective_scan_torch, fullgraph=True, mode='reduce-overhead')
    else:
        _selective_scan_torch_cmp = _selective_scan_torch
    
    out, last_state = _selective_scan_torch_cmp(u, delta, A, B, C, D, z, delta_bias, delta_softplus)
    
    return out if not return_last_state else (out, last_state)
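If it helps, a usage sketch with made-up shapes following the docstring (variable B and C of shape r(B N L); this assumes a CUDA device and a recent build where associative_scan is available):

# Hypothetical sizes: batch=2, inner dim D=4, state size N=8, length L=16
u = torch.randn(2, 4, 16, device="cuda")
delta = torch.rand(2, 4, 16, device="cuda")
A = -torch.rand(4, 8, device="cuda")
B = torch.randn(2, 8, 16, device="cuda")
C = torch.randn(2, 8, 16, device="cuda")
D = torch.randn(4, device="cuda")

out, last_state = selective_scan_torch(u, delta, A, B, C, D,
                                       delta_softplus=True, return_last_state=True)
print(out.shape, last_state.shape)  # expected: (2, 4, 16) and (2, 4, 8)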

Let me know what you think.

In addition, what worries me a bit is the backward pass that I have currently implemented. It follows this blog post, but it may not be ideal. Any thoughts on this are highly appreciated.
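For what it's worth, in the diagonal/linear case the adjoint is itself a first-order linear recurrence run in reverse, so it can reuse the same associative scan machinery. Sketching the standard result here (not necessarily what the blog post or the pending PR does): for the forward recurrence $h_t = a_t \odot h_{t-1} + b_t$ with loss $L(h_1, \dots, h_T)$ and direct gradients $g_t = \partial L / \partial h_t$,

$$
\bar{h}_T = g_T, \qquad
\bar{h}_t = g_t + a_{t+1} \odot \bar{h}_{t+1}, \qquad
\frac{\partial L}{\partial b_t} = \bar{h}_t, \qquad
\frac{\partial L}{\partial a_t} = \bar{h}_t \odot h_{t-1}.
$$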
