Parallel Associative Scan · Issue #95408 · pytorch/pytorch · GitHub

Parallel Associative Scan #95408
Open
@abdulfatir

Description

🚀 The feature, motivation and pitch

It would be great to have a general parallel prefix sum (associative scan) operation in PyTorch, similar to associative_scan in JAX or scan_associative in TensorFlow Probability. This operation is key to parallelizing several algorithms, for example in CRFs and in filtering/smoothing for state-space models.
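For readers unfamiliar with the operation, here is a minimal pure-Python sketch of an inclusive associative scan over an arbitrary binary operator `fn`. It uses a recursive odd/even pairing scheme, similar in spirit to the one behind JAX's associative_scan; the function name and signature are hypothetical and are not a proposed PyTorch API. Python runs the comprehensions sequentially, but all `fn` calls within one recursion level are independent, which is the part a GPU kernel would parallelize.

```python
import operator
from typing import Callable, List, TypeVar

T = TypeVar("T")


def associative_scan(fn: Callable[[T, T], T], elems: List[T]) -> List[T]:
    """Inclusive scan: out[i] = elems[0] fn elems[1] fn ... fn elems[i].

    Hypothetical helper for illustration. `fn` must be associative; it does
    not need to be commutative, because operand order is always preserved.
    Total work is O(n) and recursion depth is O(log n).
    """
    n = len(elems)
    if n < 2:
        return list(elems)
    # Step 1: combine adjacent pairs. Each call is independent of the others.
    pairs = [fn(elems[i], elems[i + 1]) for i in range(0, n - 1, 2)]
    # Step 2: scan the half-length sequence; scanned[k] covers elems[0..2k+1].
    scanned = associative_scan(fn, pairs)
    # Step 3: odd positions are already done; even positions need one more fn.
    out = [elems[0]]
    for i in range(1, n):
        if i % 2 == 1:
            out.append(scanned[i // 2])
        else:
            out.append(fn(scanned[i // 2 - 1], elems[i]))
    return out


print(associative_scan(operator.add, [1, 2, 3, 4, 5]))  # [1, 3, 6, 10, 15]
print(associative_scan(operator.add, ["a", "b", "c"]))  # ['a', 'ab', 'abc']
```

The second call uses string concatenation, which is associative but not commutative, to show that the scheme keeps operands in their original order.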

Alternatives

I found this implementation, but it only computes the prefix sum, not scans over general associative binary operations. Native support for arbitrary binary operators would be great.
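To illustrate why arbitrary operators matter beyond prefix sums: the linear recurrence x_t = a_t * x_{t-1} + b_t, which underlies filtering in a scalar state-space model, is a scan over affine maps with function composition as the associative operator. The sketch below uses itertools.accumulate (which accepts any binary function) as a sequential reference; the hypothetical `compose` function is the piece a parallel associative scan would reuse unchanged. The scalar case is an illustrative simplification; in the matrix case `a` becomes a matrix and the products become matrix multiplications.

```python
from itertools import accumulate
from typing import List, Tuple

# (a, b) represents the affine map x -> a * x + b.
Affine = Tuple[float, float]


def compose(first: Affine, second: Affine) -> Affine:
    """Apply `first`, then `second`: x -> a2 * (a1 * x + b1) + b2."""
    a1, b1 = first
    a2, b2 = second
    return (a1 * a2, a2 * b1 + b2)


def linear_recurrence(a: List[float], b: List[float], x0: float = 0.0) -> List[float]:
    """Return [x_1, ..., x_T] for x_t = a_t * x_{t-1} + b_t (hypothetical helper)."""
    # prefix[t] is the composed map taking x0 directly to x_{t+1}.
    prefix = accumulate(zip(a, b), compose)  # sequential reference scan
    return [pa * x0 + pb for pa, pb in prefix]


print(linear_recurrence([0.5, 0.5, 0.5], [1.0, 1.0, 1.0]))  # [1.0, 1.5, 1.75]
```

Because `compose` is associative, the prefix maps could be computed with any parallel scan instead of accumulate, which is exactly what a prefix-sum-only implementation cannot express.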

Additional context

No response

cc @ezyang @gchanan @zou3519 @soumith @msaroufim @wconstab @ngimel @bdhirsh @anijain2305 @Chillee @samdow @kshitij12345 @janeyx99

Metadata

Assignees: no one assigned
Labels: feature, module: functorch, module: pt2-dispatcher, oncall: pt2, triaged
Type: none
Projects: none
Milestone: none
Relationships: none
Development: no branches or pull requests