🚀 The feature, motivation and pitch
It would be great to have a general parallel prefix sum (associative scan) operation in PyTorch, similar to jax.lax.associative_scan in JAX or tfp.math.scan_associative in TensorFlow Probability. This operation is key to parallelizing some algorithms for CRFs, filtering and smoothing in state space models, etc.
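To make the request concrete, here is a minimal sketch of the semantics I have in mind, written in plain Python rather than PyTorch. The name `associative_scan_sketch` is hypothetical and not an existing API. It uses Hillis-Steele doubling steps as one possible scheme: the loop runs about log2(n) times, and all combinations within one step are independent of each other, which is what a parallel implementation could exploit.

```python
from typing import Callable, List, TypeVar

T = TypeVar("T")


def associative_scan_sketch(op: Callable[[T, T], T], xs: List[T]) -> List[T]:
    """Inclusive scan of xs under an associative binary operator op.

    Hypothetical helper for illustration only. It runs sequentially here,
    but every combination inside one doubling step could run in parallel.
    """
    out = list(xs)
    shift = 1
    while shift < len(out):
        # Combine each element with the one `shift` positions to its left.
        # The earlier element stays on the left, so op need not be commutative.
        out = [
            out[i] if i < shift else op(out[i - shift], out[i])
            for i in range(len(out))
        ]
        shift *= 2
    return out


# Ordinary prefix sum
print(associative_scan_sketch(lambda a, b: a + b, [1, 2, 3, 4, 5]))  # [1, 3, 6, 10, 15]
# Running maximum, using a different associative operator
print(associative_scan_sketch(max, [3, 1, 4, 1, 5]))  # [3, 3, 4, 4, 5]
```

The same function handles any associative operator, which is the point of the request: the prefix sum is just the special case where the operator is addition.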
Alternatives
I found this implementation, but it only computes the prefix sum and does not support general associative binary operations. Native support for arbitrary associative binary operators would be great.
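As an illustration of why a sum-only scan is not enough, here is a small sketch of a linear recurrence h_t = a_t * h_{t-1} + b_t, the kind of update that appears in state space models. The recurrence, the `combine` operator, and the sample values are assumptions chosen for illustration. Each step is an affine map, and composing affine maps is associative but is not addition, so a plain prefix sum cannot express it. The scan is done with `itertools.accumulate`, which is sequential and only demonstrates the operator.

```python
from itertools import accumulate


def combine(first, second):
    """Compose two affine maps h -> a*h + b, applying `first` then `second`."""
    a1, b1 = first
    a2, b2 = second
    return (a2 * a1, a2 * b1 + b2)


# Illustrative coefficients (made up for this example)
a = [2, 3, 1, 2]
b = [1, 0, 4, 1]

# Sequential reference: h_t = a_t * h_{t-1} + b_t, starting from h_0 = 0
h, expected = 0, []
for a_t, b_t in zip(a, b):
    h = a_t * h + b_t
    expected.append(h)

# Scan over (a_t, b_t) pairs; with h_0 = 0 the second component equals h_t
scanned = list(accumulate(zip(a, b), combine))
result = [b_acc for _, b_acc in scanned]

print(expected)  # [1, 3, 7, 15]
print(result)    # [1, 3, 7, 15]
```

Because `combine` is associative, a native associative scan could evaluate the same recurrence in parallel instead of step by step.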
Additional context
No response
cc @ezyang @gchanan @zou3519 @soumith @msaroufim @wconstab @ngimel @bdhirsh @anijain2305 @Chillee @samdow @kshitij12345 @janeyx99