Merge branch 'develop' into better_bleu_doc · speechbrain/speechbrain@29fbe54 · GitHub
[go: up one dir, main page]

Skip to content

Commit 29fbe54

Browse files
authored
Merge branch 'develop' into better_bleu_doc
2 parents 190217c + 7724216 commit 29fbe54

File tree

9 files changed

+721
-25
lines changed

9 files changed

+721
-25
lines changed

recipes/LibriSpeech/ASR/transducer/hparams/conformer_transducer.yaml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ ctc_weight: 0.3 # Multitask with CTC for the encoder (0.0 = disabled)
5757
ce_weight: 0.0 # Multitask with CE for the decoder (0.0 = disabled)
5858
max_grad_norm: 5.0
5959
loss_reduction: 'batchmean'
60-
precision: fp32 # bf16, fp16 or fp32
60+
precision: fp16 # bf16, fp16 or fp32
6161

6262
# The batch size is used if and only if dynamic batching is set to False
6363
# Validation and testing are done with fixed batches and not dynamic batching.
@@ -136,6 +136,7 @@ output_neurons: 1000
136136
dec_dim: 512
137137
dec_emb_dropout: 0.2
138138
dec_dropout: 0.1
139+
attention_type: RoPEMHA
139140

140141
# Decoding parameters
141142
blank_index: 0
@@ -236,7 +237,7 @@ Transformer: !new:speechbrain.lobes.models.transformer.TransformerASR.Transforme
236237
dropout: !ref <transformer_dropout>
237238
activation: !ref <activation>
238239
encoder_module: conformer
239-
attention_type: RelPosMHAXL
240+
attention_type: !ref <attention_type>
240241
normalize_before: True
241242
causal: False
242243

recipes/LibriSpeech/ASR/transformer/README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,9 @@ Following table contains whisper-finetuning results for 1 epoch using Whisper mo
3737

3838
| Release | hyperparams file | Dev Clean WER (No LM, small beam) | Test Clean WER (Transformer LM) | Test Other WER (Transformer LM) | HuggingFace link | Model link | GPUs |
3939
|:-------------:|:-------------:|:-------------:|:---------------------------:| :-----:| :-----:| :-----:| :--------:|
40+
| 30-09-24 | conformer_large.yaml (new RoPE version) |1.85 with LM | 1.96 | 4.50 | Not Avail. | Not Avail. | 4xA40 46GB |
4041
| 23-05-23 | branchformer_large.yaml | 2.72 (1.9 with LM) | 2.04 | 4.13 | Not Avail. | [DropBox](https://www.dropbox.com/scl/fo/qhtds5rrdvhhhjywa7ovw/AMiIL5YvQENw5JKVpzXlP5o?rlkey=hz8vlpy3qf9kcyfx0cox089e6&st=ufckv6tb&dl=0) | 4xA100 80GB |
42+
| 10-02-25 | conformer_large.yaml | 1.85 with LM | 1.97 | 4.50 | N/A | N/A | 4xA100 80GB |
4143
| 23-05-23 | conformer_large.yaml | 2.62 (1.9 with LM) | 2.01 | 4.52 | [HuggingFace](https://huggingface.co/speechbrain/asr-conformer-transformerlm-librispeech) | [DropBox](https://www.dropbox.com/scl/fo/9we244tgdf47ay20hrdoz/AKnoqQ13nLwSv1ITeJEQ3wY?rlkey=05o5jiszr8rhj6dlprw87t2x4&st=u2odesyk&dl=0) | 4xA100 80GB |
4244
| 24-03-22 | transformer.yaml | 3.32 | 2.27 | 5.53 | [HuggingFace](https://huggingface.co/speechbrain/asr-transformer-transformerlm-librispeech) | [DropBox](https://www.dropbox.com/sh/653kq8h2k87md4p/AAByAaAryXtQKpRzYtzV9ih5a?dl=0) | 4xV100 32GB |
4345
| 24-03-22 | conformer_small.yaml | 4.05 | 2.49 | 6.1 (**only 13.3M parameters**) | [HuggingFace](https://huggingface.co/speechbrain/asr-conformersmall-transformerlm-librispeech) | [DropBox](https://www.dropbox.com/sh/s0x6ni124858b8i/AAALaCH6sGTMRUVTjh8Tm8Jwa?dl=0) | 1xV100 32GB |

recipes/LibriSpeech/ASR/transformer/hparams/conformer_large.yaml

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -55,22 +55,24 @@ max_grad_norm: 5.0
5555
loss_reduction: 'batchmean'
5656
sorting: random
5757
num_workers: 4
58-
precision: fp32 # bf16, fp16 or fp32
58+
precision: fp16 # bf16, fp16 or fp32
5959
avg_checkpoints: 10 # Number of checkpoints to average for evaluation
6060

6161
# stages related parameters
6262
lr_adam: 0.0008
63+
warmup: 50000
64+
augment_warmup: 8000
6365

6466
# Feature parameters
6567
sample_rate: 16000
6668
n_fft: 512
6769
n_mels: 80
6870
win_length: 32
6971

70-
# This setup works well for A100 80GB GPU, adapts it to your needs.
72+
# This setup works well for V100 32GB GPU, adapts it to your needs.
7173
# Or turn it off (but training speed will decrease)
7274
dynamic_batching: True
73-
max_batch_length_train: 500
75+
max_batch_length_train: 150
7476
max_batch_length_val: 100 # we reduce it as the beam is much wider (VRAM)
7577
num_bucket: 200
7678
shuffle: True # if true re-creates batches at each epoch shuffling examples.
@@ -153,7 +155,7 @@ Transformer: !new:speechbrain.lobes.models.transformer.TransformerASR.Transforme
153155
dropout: !ref <transformer_dropout>
154156
activation: !ref <activation>
155157
encoder_module: conformer
156-
attention_type: RelPosMHAXL
158+
attention_type: RoPEMHA
157159
normalize_before: True
158160
causal: False
159161

@@ -261,7 +263,7 @@ seq_cost: !name:speechbrain.nnet.losses.kldiv_loss
261263

262264
noam_annealing: !new:speechbrain.nnet.schedulers.NoamScheduler
263265
lr_initial: !ref <lr_adam>
264-
n_warmup_steps: 30000
266+
n_warmup_steps: !ref <warmup>
265267

266268
checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
267269
checkpoints_dir: !ref <save_folder>

recipes/LibriSpeech/ASR/transformer/train.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,9 +62,15 @@ def compute_forward(self, batch, stage):
6262
feats = self.modules.normalize(feats, wav_lens, epoch=current_epoch)
6363

6464
# Add feature augmentation if specified.
65+
augment_warmup = 0
66+
if hasattr(self.hparams, "augment_warmup"):
67+
augment_warmup = self.hparams.augment_warmup
6568
if stage == sb.Stage.TRAIN and hasattr(self.hparams, "fea_augment"):
66-
feats, fea_lens = self.hparams.fea_augment(feats, wav_lens)
67-
tokens_bos = self.hparams.fea_augment.replicate_labels(tokens_bos)
69+
if self.optimizer_step > augment_warmup:
70+
feats, fea_lens = self.hparams.fea_augment(feats, wav_lens)
71+
tokens_bos = self.hparams.fea_augment.replicate_labels(
72+
tokens_bos
73+
)
6874

6975
# forward modules
7076
src = self.modules.CNN(feats)
@@ -118,7 +124,13 @@ def compute_objectives(self, predictions, batch, stage):
118124
if stage == sb.Stage.TRAIN:
119125
# Labels must be extended if parallel augmentation or concatenated
120126
# augmentation was performed on the input (increasing the time dimension)
121-
if hasattr(self.hparams, "fea_augment"):
127+
augment_warmup = 0
128+
if hasattr(self.hparams, "augment_warmup"):
129+
augment_warmup = self.hparams.augment_warmup
130+
if (
131+
hasattr(self.hparams, "fea_augment")
132+
and self.optimizer_step > augment_warmup
133+
):
122134
(
123135
tokens,
124136
tokens_lens,

speechbrain/lobes/models/transformer/Conformer.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
* Jianyuan Zhong 2020
66
* Samuele Cornell 2021
77
* Sylvain de Langen 2023
8+
* Shucong Zhang 2024
89
"""
910

1011
import warnings
@@ -21,6 +22,7 @@
2122
MultiheadAttention,
2223
PositionalwiseFeedForward,
2324
RelPosMHAXL,
25+
RoPEMHA,
2426
)
2527
from speechbrain.nnet.hypermixing import HyperMixing
2628
from speechbrain.nnet.normalization import LayerNorm
@@ -407,6 +409,12 @@ def __init__(
407409
num_heads=nhead,
408410
fix_tm_hidden_size=False,
409411
)
412+
elif attention_type == "RoPEMHA":
413+
self.mha_layer = RoPEMHA(
414+
num_heads=nhead,
415+
embed_dim=d_model,
416+
dropout=dropout,
417+
)
410418

411419
self.convolution_module = ConvolutionModule(
412420
d_model, kernel_size, bias, activation, dropout, causal=causal
@@ -728,7 +736,7 @@ def forward(
728736
if self.attention_type == "RelPosMHAXL":
729737
if pos_embs is None:
730738
raise ValueError(
731-
"The chosen attention type for the Conformer is RelPosMHAXL. For this attention type, the positional embeddings are mandatory"
739+
f"The chosen attention type for the Conformer is {self.attention_type}. For this attention type, the positional embeddings are mandatory"
732740
)
733741

734742
output = src
@@ -794,10 +802,13 @@ def forward_streaming(
794802
The attention values.
795803
"""
796804

797-
if self.attention_type == "RelPosMHAXL":
805+
if (
806+
self.attention_type == "RelPosMHAXL"
807+
or self.attention_type == "RoPEMHA"
808+
):
798809
if pos_embs is None:
799810
raise ValueError(
800-
"The chosen attention type for the Conformer is RelPosMHAXL. For this attention type, the positional embeddings are mandatory"
811+
f"The chosen attention type for the Conformer is {self.attention_type}. For this attention type, the positional embeddings are mandatory"
801812
)
802813

803814
output = src

speechbrain/lobes/models/transformer/Transformer.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
Authors
33
* Jianyuan Zhong 2020
44
* Samuele Cornell 2021
5+
* Shucong Zhang 2024
56
"""
67

78
import math
@@ -137,7 +138,12 @@ def __init__(
137138
self.output_hidden_states = output_hidden_states
138139
self.layerdrop_prob = layerdrop_prob
139140

140-
assert attention_type in ["regularMHA", "RelPosMHAXL", "hypermixing"]
141+
assert attention_type in [
142+
"regularMHA",
143+
"RelPosMHAXL",
144+
"hypermixing",
145+
"RoPEMHA",
146+
]
141147
assert positional_encoding in ["fixed_abs_sine", None]
142148

143149
assert (
@@ -157,6 +163,11 @@ def __init__(
157163
d_model, max_length
158164
)
159165

166+
if attention_type == "RoPEMHA":
167+
self.positional_encoding_decoder = PositionalEncoding(
168+
d_model, max_length
169+
)
170+
160171
# initialize the encoder
161172
if num_encoder_layers > 0:
162173
if custom_src_module is not None:
@@ -374,6 +385,12 @@ def __init__(
374385
num_heads=nhead,
375386
fix_tm_hidden_size=False,
376387
)
388+
elif attention_type == "RoPEMHA":
389+
self.self_att = sb.nnet.attention.RoPEMHA(
390+
d_model,
391+
nhead,
392+
dropout,
393+
)
377394

378395
if ffn_type == "regularFFN":
379396
self.pos_ffn = sb.nnet.attention.PositionalwiseFeedForward(
@@ -704,7 +721,6 @@ def __init__(
704721
vdim=vdim,
705722
dropout=dropout,
706723
)
707-
708724
elif attention_type == "RelPosMHAXL":
709725
self.self_attn = sb.nnet.attention.RelPosMHAXL(
710726
d_model, nhead, dropout, mask_pos_future=causal
@@ -787,7 +803,6 @@ def forward(
787803
tgt1 = tgt
788804

789805
# multi-head attention over the target sequence and encoder states
790-
791806
tgt2, multihead_attention = self.multihead_attn(
792807
query=tgt1,
793808
key=memory,

speechbrain/lobes/models/transformer/TransformerASR.py

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
* Jianyuan Zhong 2020
55
* Titouan Parcollet 2024
66
* Luca Della Libera 2024
7+
* Shucong Zhang 2024
78
"""
89

910
from dataclasses import dataclass
@@ -362,12 +363,15 @@ def forward(self, src, tgt, wav_len=None, pad_idx=0):
362363

363364
src = self.custom_src_module(src)
364365
# add pos encoding to queries if are sinusoidal ones else
365-
if self.attention_type == "hypermixing":
366+
if (
367+
self.attention_type == "hypermixing"
368+
or self.attention_type == "RoPEMHA"
369+
):
366370
pos_embs_encoder = None
367371
elif self.attention_type == "RelPosMHAXL":
368372
pos_embs_encoder = self.positional_encoding(src)
369373
elif self.positional_encoding_type == "fixed_abs_sine":
370-
src = src + self.positional_encoding(src) # add the encodings here
374+
src = src + self.positional_encoding(src)
371375
pos_embs_encoder = None
372376

373377
outputs = self.encoder(
@@ -388,9 +392,12 @@ def forward(self, src, tgt, wav_len=None, pad_idx=0):
388392

389393
tgt = self.custom_tgt_module(tgt)
390394

391-
if self.attention_type == "RelPosMHAXL":
395+
if (
396+
self.attention_type == "RelPosMHAXL"
397+
or self.attention_type == "RoPEMHA"
398+
):
392399
tgt = tgt + self.positional_encoding_decoder(tgt)
393-
pos_embs_encoder = None # self.positional_encoding(src)
400+
pos_embs_encoder = None
394401
pos_embs_target = None
395402
elif (
396403
self.positional_encoding_type == "fixed_abs_sine"
@@ -439,15 +446,19 @@ def decode(self, tgt, encoder_out, enc_len=None):
439446
src_key_padding_mask = (1 - length_to_mask(enc_len)).bool()
440447

441448
tgt = self.custom_tgt_module(tgt)
442-
if self.attention_type == "RelPosMHAXL":
449+
450+
if (
451+
self.attention_type == "RelPosMHAXL"
452+
or self.attention_type == "RoPEMHA"
453+
):
443454
tgt = tgt + self.positional_encoding_decoder(tgt)
444-
pos_embs_encoder = None # self.positional_encoding(src)
455+
pos_embs_encoder = None
445456
pos_embs_target = None
446457
elif (
447458
self.positional_encoding_type == "fixed_abs_sine"
448459
or self.attention_type == "hypermixing"
449460
):
450-
tgt = tgt + self.positional_encoding(tgt) # add the encodings here
461+
tgt = tgt + self.positional_encoding(tgt)
451462
pos_embs_target = None
452463
pos_embs_encoder = None
453464

@@ -506,7 +517,10 @@ def encode(
506517
)
507518

508519
src = self.custom_src_module(src)
509-
if self.attention_type == "hypermixing":
520+
if (
521+
self.attention_type == "hypermixing"
522+
or self.attention_type == "RoPEMHA"
523+
):
510524
pos_embs_source = None
511525
elif self.attention_type == "RelPosMHAXL":
512526
pos_embs_source = self.positional_encoding(src)
@@ -612,6 +626,8 @@ def encode_streaming(self, src, context: TransformerASRStreamingContext):
612626
src = self.custom_src_module(src)
613627
if self.attention_type == "RelPosMHAXL":
614628
pos_embs_source = self.positional_encoding(pos_encoding_dummy)
629+
elif self.attention_type == "RoPEMHA":
630+
pos_embs_source = None
615631

616632
elif self.positional_encoding_type == "fixed_abs_sine":
617633
src = src + self.positional_encoding(pos_encoding_dummy)

0 commit comments

Comments
 (0)