RFC-0035-viterbi-decoding.md by CameronChurchwell · Pull Request #62 · pytorch/rfcs · GitHub
Open · wants to merge 10 commits into base: master
Changes from 1 commit
Torch-style comments
maxrmorrison committed Mar 2, 2024
commit fb66ee215e441ae8ef9c78ad6a012f053eba87fc
232 changes: 112 additions & 120 deletions RFC-0035-viterbi-decoding.md
@@ -37,9 +37,7 @@ We want to add Viterbi decoding to PyTorch. Viterbi decoding is a well-known algorithm

## **Motivation**

Viterbi decoding is a generally useful algorithm that is missing from the PyTorch library, with applications in automatic speech recognition, bioinformatics, digital communications, and more. However, Viterbi decoding is O(C^2T) for C classes and T timesteps, making it challenging to scale to large datasets and real-time applications. A commonly used implementation exists in Librosa ([`librosa.sequence.viterbi`](https://librosa.org/doc/main/generated/librosa.sequence.viterbi.html)), which uses just-in-time compilation via Numba; we use it both as a reference for correctness and as a baseline for benchmarking our proposed implementation. Our benchmark uses `C = 1,440` states and approximately `T ~= 20 million` time steps across approximately 40k files.
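
For orientation, here is a minimal, unoptimized log-space sketch of the algorithm (our illustration, not the proposed kernel; the function name and the assumption of log-probability inputs are ours). Each of the `T` steps maximizes over all `C × C` transitions, which is where the O(C^2T) cost comes from.

```
import torch

def viterbi_reference(observation, transition, initial):
    """Unoptimized Viterbi decoding for a single sequence.

    All inputs are assumed to be log-probabilities (our choice, for stability):
        observation: (T, S) per-frame state scores
        transition:  (S, S) where transition[i, j] = log P(state j | state i)
        initial:     (S,) initial state distribution
    """
    T, S = observation.shape
    pointers = torch.zeros(T, S, dtype=torch.long)
    score = initial + observation[0]
    for t in range(1, T):
        # candidates[i, j] is the score of reaching state j at frame t from state i
        candidates = score[:, None] + transition
        score, pointers[t] = candidates.max(dim=0)
        score = score + observation[t]
    # Walk the pointers backward from the best final state
    indices = torch.empty(T, dtype=torch.long)
    indices[-1] = score.argmax()
    for t in range(T - 2, -1, -1):
        indices[t] = pointers[t + 1, indices[t + 1]]
    return indices
```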

| Method | Timesteps decoded per second |
| ------------- | ------------- |
@@ -77,33 +75,48 @@ We propose a Python API and underlying C++/CUDA extensions for Viterbi decoding
```
def decode(
    observation: torch.Tensor,
    batch_frames: torch.Tensor,
    transition: torch.Tensor,
    initial: torch.Tensor
) -> torch.Tensor:
    """Decode a time-varying categorical distribution

    Args:
        observation: :math:`(N, T, S)` or :math:`(T, S)`
            where `S = the number of states`,
            `T = the length of the sequence`,
            and `N = batch size`.
            Time-varying categorical distribution
        batch_frames: :math:`(N)`
            Sequence length of each batch item
        transition: :math:`(S, S)`
            Categorical transition matrix
        initial: :math:`(S)`
            Categorical initial distribution

    Return:
        indices: :math:`(N, T)`
            The decoded bin indices

    Example::

        >>> observation = torch.tensor([[
        >>>     [0.25, 0.5, 0.25],
        >>>     [0.25, 0.25, 0.5],
        >>>     [0.33, 0.33, 0.33]
        >>> ]])
        >>> batch_frames = torch.tensor([3])
        >>> transition = torch.tensor([
        >>>     [0.5, 0.25, 0.25],
        >>>     [0.33, 0.34, 0.33],
        >>>     [0.25, 0.25, 0.5]
        >>> ])
        >>> initial = torch.tensor([0.4, 0.35, 0.25])
        >>> bins = torch.viterbi.decode(
        >>>     observation,
        >>>     batch_frames,
        >>>     transition,
        >>>     initial)
    """
```
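
Working the `Example` above by hand (our calculation, not an output produced by the implementation): at frame 0, state 1 scores highest (initial `0.35` × observation `0.5` = `0.175`); at frame 1 the best move is into state 2 (`0.175` × transition `0.33` × observation `0.5` ≈ `0.029`); and at frame 2 the optimal path stays in state 2. So `bins` should be `tensor([[1, 2, 2]])`.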

@@ -112,35 +125,29 @@

```
def make_trellis(
    observation: torch.Tensor,
    batch_frames: torch.Tensor,
    transition: torch.Tensor,
    initial: torch.Tensor
) -> torch.Tensor:
    """Perform first step of Viterbi decoding to construct the path trellis

    Args:
        observation: :math:`(N, T, S)` or :math:`(T, S)`
            where `S = the number of states`,
            `T = the length of the sequence`,
            and `N = batch size`.
            Time-varying categorical distribution
        batch_frames: :math:`(N)`
            Sequence length of each batch item
        transition: :math:`(S, S)`
            Categorical transition matrix
        initial: :math:`(S)`
            Categorical initial distribution

    Return:
        trellis: :math:`(N, T, S)`
            Matrix of minimum path indices for backtracing
    """
```
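
On our reading (an interpretation; the RFC does not spell this out), `trellis[b, t, j]` holds the index of the best predecessor state for state `j` at frame `t` of batch item `b`, i.e. the argmax of the Viterbi recurrence in log space:

```
trellis[b, t, j] = argmax_i ( score[b, t - 1, i] + log transition[i, j] )
score[b, t, j]   = max_i    ( score[b, t - 1, i] + log transition[i, j] ) + log observation[b, t, j]
```

`backtrace_trellis` then follows these pointers backward from the highest-scoring final state.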

@@ -150,30 +157,25 @@
```
def backtrace_trellis(
    trellis: torch.Tensor,
    batch_frames: torch.Tensor,
    transition: torch.Tensor,
    initial: torch.Tensor
) -> torch.Tensor:
    """Perform second step of Viterbi decoding to backtrace optimal path

    Args:
        trellis: :math:`(N, T, S)`
            Matrix of minimum path indices for backtracing
        batch_frames: :math:`(N)`
            Sequence length of each batch item
        transition: :math:`(S, S)`
            Categorical transition matrix
        initial: :math:`(S)`
            Categorical initial distribution

    Return:
        indices: :math:`(N, T)`
            The decoded bin indices
    """
```
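
Taken together, the two-step functions should compose to the one-shot call; a minimal sketch of the equivalence we would expect (assuming the `torch.viterbi` namespace used in the `decode` example above):

```
trellis = torch.viterbi.make_trellis(observation, batch_frames, transition, initial)
indices = torch.viterbi.backtrace_trellis(trellis, batch_frames, transition, initial)

# Should match the one-shot API
assert torch.equal(
    indices,
    torch.viterbi.decode(observation, batch_frames, transition, initial))
```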

@@ -186,86 +188,75 @@

```
class Decoder:

    def __init__(
        self,
        transition: torch.Tensor,
        initial: torch.Tensor
    ) -> None:
        """
        Args:
            transition: :math:`(S, S)`
                Categorical transition matrix
            initial: :math:`(S)`
                Categorical initial distribution
        """

    def decode(
        self,
        observation: torch.Tensor,
        batch_frames: torch.Tensor
    ) -> torch.Tensor:
        """Decode a time-varying categorical distribution

        Args:
            observation: :math:`(N, T, S)` or :math:`(T, S)`
                where `S = the number of states`,
                `T = the length of the sequence`,
                and `N = batch size`.
                Time-varying categorical distribution
            batch_frames: :math:`(N)`
                Sequence length of each batch item

        Return:
            indices: :math:`(N, T)`
                The decoded bin indices
        """

    def make_trellis(
        self,
        observation: torch.Tensor,
        batch_frames: torch.Tensor
    ) -> torch.Tensor:
        """Perform first step of Viterbi decoding to construct the path trellis

        Args:
            observation: :math:`(N, T, S)` or :math:`(T, S)`
                where `S = the number of states`,
                `T = the length of the sequence`,
                and `N = batch size`.
                Time-varying categorical distribution
            batch_frames: :math:`(N)`
                Sequence length of each batch item

        Return:
            trellis: :math:`(N, T, S)`
                Matrix of minimum path indices for backtracing
        """

    def backtrace_trellis(
        self,
        trellis: torch.Tensor,
        batch_frames: torch.Tensor
    ) -> torch.Tensor:
        """Perform second step of Viterbi decoding to backtrace optimal path

        Args:
            trellis: :math:`(N, T, S)`
                Matrix of minimum path indices for backtracing
            batch_frames: :math:`(N)`
                Sequence length of each batch item

        Return:
            indices: :math:`(N, T)`
                The decoded bin indices
        """
```
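
A hedged usage sketch of the stateful API (our reading; we assume the class exists so that `transition` and `initial` can be validated and moved to the target device once, then reused across many files):

```
decoder = torch.viterbi.Decoder(transition, initial)

# One-shot decoding
bins = decoder.decode(observation, batch_frames)

# Or the equivalent two-step form
trellis = decoder.make_trellis(observation, batch_frames)
bins = decoder.backtrace_trellis(trellis, batch_frames)
```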

@@ -291,6 +282,7 @@ Because we use only a single block per input sequence, we can process a batch of


## **Discussion questions**

* Are there desired changes in the naming conventions?
* Right now our implementation is written as a PyTorch extension. How can it be converted to something like a `TORCH_MODULE_FRAGMENT`?
* Are there recommended methods for ensuring compliance over a set of allowed dtypes? Our implementation currently works for `torch.float32`, but is not guaranteed to work for all types.