
Commit 3e027ee

Author: khanld
[feat] add chunkformer training config and results
1 parent 55781fc

File tree

5 files changed: +164 -147 lines

examples/librispeech/s0/README.md

Lines changed: 34 additions & 0 deletions
@@ -313,3 +313,37 @@ test other
| ctc_greedy_search | 8.73 | 9.82 | 9.83 |
| ctc prefix beam search | 8.70 | 9.81 | 9.79 |
| attention rescoring | 8.05 | 9.08 | 9.10 |


## ChunkFormer U2++ Result

* Model info:
  * Encoder params: 32,356,096
  * Downsample rate: dw_striding 8x
  * encoder_dim 256, head 4, linear_units 2048
  * num_blocks 12, cnn_module_kernel 15
* Feature info: fbank feature, cmvn, dither, online speed perturb
* Training info:
  * train_u2++_chunkformer_small.yaml, kernel size 15
  * dynamic batch size 120,000 frames, 2 GPUs, acc_grad 4, 200 epochs, dither 1.0
  * adamw, lr 1e-3, warmuplr, warmup_steps 25000
  * specaug and speed perturb
* Decoding info: ctc_weight 0.3, reverse_weight 0.5, average_num 100, beam size 10

#### Full context training -> chunk context inference:
⚠️ The attention decoder does **not** support chunk-context inference: its cross-attention was trained with full context, so there is a train/inference mismatch. Chunk-context training is required to resolve it.

| Decoding Mode           | Dev Clean | Dev Other | Test Clean | Test Other |
|-------------------------|-----------|-----------|------------|------------|
| CTC Greedy Search       | 3.05      | 8.84      | 3.27       | 8.54       |
| CTC Prefix Beam Search  | 3.04      | 8.83      | 3.26       | 8.54       |
| Attention Decoder       | 4.58      | 9.62      | 5.07       | 9.22       |
| Attention Rescoring     | 2.83      | 8.39      | 2.97       | 8.02       |

#### Full context training -> full context inference:
| Decoding Mode           | Dev Clean | Dev Other | Test Clean | Test Other |
|-------------------------|-----------|-----------|------------|------------|
| CTC Greedy Search       | 3.08      | 8.82      | 3.24       | 8.55       |
| CTC Prefix Beam Search  | 3.06      | 8.80      | 3.23       | 8.53       |
| Attention Decoder       | 2.92      | 8.28      | 3.03       | 8.05       |
| Attention Rescoring     | 2.80      | 8.37      | 2.94       | 8.03       |
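For orientation, here is a minimal sketch of how attention rescoring with the decoding settings above (ctc_weight 0.3, reverse_weight 0.5) typically combines the CTC n-best scores with the left-to-right and right-to-left decoder branches in WeNet-style recipes. The function name, argument names, and tensor layout are illustrative assumptions, not the repo's actual API.

```python
import torch

def rescore_hypotheses(ctc_nbest, decoder_out, r_decoder_out,
                       ctc_weight=0.3, reverse_weight=0.5, eos=2):
    """Illustrative rescoring of CTC n-best hypotheses (sketch, not the repo code).

    ctc_nbest:     list of (token_id_list, ctc_score) from CTC prefix beam search
    decoder_out:   (N, L+1, V) log-probs from the left-to-right attention decoder
    r_decoder_out: (N, L+1, V) log-probs from the right-to-left decoder branch
    """
    best_score, best_index = -float("inf"), 0
    for i, (tokens, ctc_score) in enumerate(ctc_nbest):
        # Left-to-right decoder score: sum of log-probs of the tokens plus <eos>.
        score = sum(decoder_out[i, j, w].item() for j, w in enumerate(tokens))
        score += decoder_out[i, len(tokens), eos].item()
        # Right-to-left branch scores the reversed token sequence.
        r_score = sum(r_decoder_out[i, j, w].item()
                      for j, w in enumerate(reversed(tokens)))
        r_score += r_decoder_out[i, len(tokens), eos].item()
        # Interpolate the two decoder directions, then add the weighted CTC score.
        score = (1 - reverse_weight) * score + reverse_weight * r_score
        score += ctc_weight * ctc_score
        if score > best_score:
            best_score, best_index = score, i
    return ctc_nbest[best_index][0], best_score
```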
examples/librispeech/s0/conf/train_u2++_chunkformer_small.yaml

Lines changed: 116 additions & 0 deletions
@@ -0,0 +1,116 @@
# network architecture
# encoder related
encoder: chunkformer
encoder_conf:
    output_size: 256    # dimension of attention
    attention_heads: 4
    linear_units: 2048  # the number of units of position-wise feed forward
    num_blocks: 12      # the number of encoder blocks
    dropout_rate: 0.1
    positional_dropout_rate: 0.1
    attention_dropout_rate: 0.1
    input_layer: dw_striding    # encoder input type, you can choose conv2d, conv2d6 and conv2d8
    normalize_before: true
    cnn_module_kernel: 15
    use_cnn_module: True
    activation_type: 'swish'
    pos_enc_layer_type: 'chunk_rel_pos'
    selfattention_layer_type: 'chunk_rel_seflattn'
    dynamic_conv: false
    cnn_module_norm: 'layer_norm' # using nn.LayerNorm makes model converge faster

# decoder related
decoder: bitransformer
decoder_conf:
    attention_heads: 4
    linear_units: 2048
    num_blocks: 3
    r_num_blocks: 3
    dropout_rate: 0.1
    positional_dropout_rate: 0.1
    self_attention_dropout_rate: 0.1
    src_attention_dropout_rate: 0.1

tokenizer: bpe
tokenizer_conf:
    symbol_table_path: 'data/lang_char/train_960_bpe5000_units.txt'
    split_with_space: false
    bpe_path: 'data/lang_char/train_960_bpe5000.model'
    non_lang_syms_path: null
    is_multilingual: false
    num_languages: 1
    special_tokens:
        <blank>: 0
        <unk>: 1
        <sos>: 2
        <eos>: 2

ctc: ctc
ctc_conf:
    ctc_blank_id: 0

cmvn: global_cmvn
cmvn_conf:
    cmvn_file: 'data/train_960/global_cmvn'
    is_json_cmvn: true

# hybrid CTC/attention
model: asr_model
model_conf:
    ctc_weight: 0.3
    lsm_weight: 0.1     # label smoothing option
    length_normalized_loss: false
    reverse_weight: 0.3

# dataset related
dataset: asr
dataset_conf:
    filter_conf:
        max_length: 40960
        min_length: 0
        token_max_length: 400
        token_min_length: 1
        # min_output_input_ratio: 0.0005
        # max_output_input_ratio: 0.1
    resample_conf:
        resample_rate: 16000
    speed_perturb: true
    fbank_conf:
        num_mel_bins: 80
        frame_shift: 10
        frame_length: 25
        dither: 1.0
    spec_aug: true
    spec_aug_conf:
        num_t_mask: 2
        num_f_mask: 2
        max_t: 50
        max_f: 10
    spec_sub: false
    spec_sub_conf:
        num_t_sub: 3
        max_t: 30
    shuffle: true
    shuffle_conf:
        shuffle_size: 1000
    sort: false
    sort_conf:
        sort_size: 2000  # sort_size should be less than shuffle_size
    batch_conf:
        batch_type: 'dynamic' # static or dynamic
        max_frames_in_batch: 120000
    # At inference, pad_feat should be False to activate
    # masked batch and chunk-context decoding
    pad_feat: True

grad_clip: 5
accum_grad: 4
max_epoch: 200
log_interval: 100

optim: adamw
optim_conf:
    lr: 0.001
scheduler: warmuplr     # pytorch v1.1.0+ required
scheduler_conf:
    warmup_steps: 25000
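The batch_conf above uses dynamic batching: utterances are accumulated until the batch, padded to its longest utterance, would exceed max_frames_in_batch (120,000 frames). A minimal sketch of that policy is below; the 'num_frames' field and function name are assumptions for illustration, and the real dataset pipeline differs in detail.

```python
def dynamic_batches(samples, max_frames_in_batch=120000):
    """Group samples so that, after padding each batch to its longest utterance,
    the total frame count stays under max_frames_in_batch (sketch only)."""
    batch, longest = [], 0
    for sample in samples:
        n = sample["num_frames"]          # hypothetical per-sample frame count
        longest = max(longest, n)
        # Adding this sample would exceed the frame budget: emit the batch first.
        if batch and longest * (len(batch) + 1) > max_frames_in_batch:
            yield batch
            batch, longest = [], n
        batch.append(sample)
    if batch:
        yield batch
```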

wenet/chunkformer/embedding.py

Lines changed: 4 additions & 6 deletions
@@ -25,16 +25,14 @@ def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None:
         self.max_len = max_len
         self.extend_pe(max_len)
 
-    def extend_pe(self, size: int, left_context: Union[int, torch.Tensor] = 0) -> None:
+    def extend_pe(self, size: int) -> None:
         """Reset the positional encodings."""
-        x_size_1 = size + left_context
-
         # Suppose `i` means to the position of query vector and `j` means the
         # position of key vector. We use position relative positions when keys
         # are to the left (i>j) and negative relative positions otherwise (i<j).
-        pe_positive = torch.zeros(x_size_1, self.d_model)
-        pe_negative = torch.zeros(x_size_1, self.d_model)
-        position = torch.arange(0, x_size_1, dtype=torch.float32).unsqueeze(1)
+        pe_positive = torch.zeros(size, self.d_model)
+        pe_negative = torch.zeros(size, self.d_model)
+        position = torch.arange(0, size, dtype=torch.float32).unsqueeze(1)
         div_term = torch.exp(
             torch.arange(0, self.d_model, 2, dtype=torch.float32)
             * -(math.log(10000.0) / self.d_model)
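The change drops the left_context argument, so extend_pe only needs `size` positions in each direction. For reference, a self-contained sketch of the relative positional encoding table this kind of extend_pe builds (the standard Transformer-XL/Conformer pattern); the exact lines, dtype, and device handling in this file may differ.

```python
import math
import torch

def build_rel_pos_table(size: int, d_model: int) -> torch.Tensor:
    """Sketch of a relative positional encoding table over offsets
    size-1 .. -(size-1); returns a (1, 2*size - 1, d_model) tensor."""
    pe_positive = torch.zeros(size, d_model)
    pe_negative = torch.zeros(size, d_model)
    position = torch.arange(0, size, dtype=torch.float32).unsqueeze(1)
    div_term = torch.exp(
        torch.arange(0, d_model, 2, dtype=torch.float32)
        * -(math.log(10000.0) / d_model))
    pe_positive[:, 0::2] = torch.sin(position * div_term)
    pe_positive[:, 1::2] = torch.cos(position * div_term)
    pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
    pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
    # Flip so index 0 holds the largest positive offset, then append the
    # strictly negative offsets (skipping the duplicate offset 0).
    pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
    pe_negative = pe_negative[1:].unsqueeze(0)
    return torch.cat([pe_positive, pe_negative], dim=1)
```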

wenet/chunkformer/encoder.py

Lines changed: 10 additions & 10 deletions
@@ -257,8 +257,8 @@ def forward_parallel_chunk(
 
         conv_lorder = self.cnn_module_kernel // 2
 
-        upper_bounds = []
-        lower_bounds = []
+        upper_bounds_att = []
+        lower_bounds_att = []
         upper_bounds_conv = []
         lower_bounds_conv = []
         x_pad = []
@@ -279,13 +279,13 @@ def forward_parallel_chunk(
 
             # attention boundaries
             max_len = 1 + (xs_origin_len - context) // subsampling
-            upper_bound = chunk_size + right_context_size + torch.arange(
+            upper_bound_att = chunk_size + right_context_size + torch.arange(
                 0,
                 1 + (xs_origin_len + n_frames_pad - context) // subsampling,
                 1 + (size - context) // subsampling, device=device
             )
-            lower_bound = upper_bound - max_len
-            upper_bound += offs
+            lower_bound_att = upper_bound_att - max_len
+            upper_bound_att += offs
 
             # convolution boundaries
             upper_bound_conv = chunk_size + conv_lorder + torch.arange(
@@ -301,8 +301,8 @@ def forward_parallel_chunk(
 
 
             xs_lens += [size] * (n_chunk - 1) + [size - n_frames_pad]
-            upper_bounds.append(upper_bound)
-            lower_bounds.append(lower_bound)
+            upper_bounds_att.append(upper_bound_att)
+            lower_bounds_att.append(lower_bound_att)
             upper_bounds_conv.append(upper_bound_conv)
             lower_bounds_conv.append(lower_bound_conv)
             x_pad.append(x)
@@ -312,8 +312,8 @@ def forward_parallel_chunk(
         xs = torch.cat(x_pad, dim=0).to(device)
         xs_lens = torch.tensor(xs_lens).to(device)
         masks = ~make_pad_mask(xs_lens, xs.size(1)).unsqueeze(1) # (B, 1, T)
-        upper_bounds = torch.cat(upper_bounds).unsqueeze(1).to(device)
-        lower_bounds = torch.cat(lower_bounds).unsqueeze(1).to(device)
+        upper_bounds_att = torch.cat(upper_bounds_att).unsqueeze(1).to(device)
+        lower_bounds_att = torch.cat(lower_bounds_att).unsqueeze(1).to(device)
         upper_bounds_conv = torch.cat(upper_bounds_conv).unsqueeze(1).to(device)
         lower_bounds_conv = torch.cat(lower_bounds_conv).unsqueeze(1).to(device)
 
@@ -346,7 +346,7 @@ def forward_parallel_chunk(
             left_context_size + chunk_size + right_context_size,
             device=masks.device
         ).unsqueeze(0).repeat(xs.size(0), 1)
-        att_mask = (lower_bounds <= att_mask) & (att_mask < upper_bounds)
+        att_mask = (lower_bounds_att <= att_mask) & (att_mask < upper_bounds_att)
         att_mask = att_mask.flip(-1).unsqueeze(1)
 
         r_att_cache = []
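The renamed *_att bounds (now distinct from the convolution bounds) define, per chunk, the range of positions each query may attend to; the mask is built with a broadcast comparison against a position index. A small, self-contained toy illustration of that masking pattern is below; the bound values are made up, while in the encoder they come from the chunk/context geometry computed above.

```python
import torch

# Each row i may attend only to positions p with lower[i] <= p < upper[i].
num_chunks, width = 3, 8                       # width ~ left + chunk + right context
positions = torch.arange(width).unsqueeze(0).repeat(num_chunks, 1)  # (num_chunks, width)
lower = torch.tensor([0, 2, 4]).unsqueeze(1)   # hypothetical lower bounds, (num_chunks, 1)
upper = torch.tensor([4, 6, 8]).unsqueeze(1)   # hypothetical upper bounds, (num_chunks, 1)

att_mask = (lower <= positions) & (positions < upper)  # broadcasts to (num_chunks, width)
print(att_mask.int())
# tensor([[1, 1, 1, 1, 0, 0, 0, 0],
#         [0, 0, 1, 1, 1, 1, 0, 0],
#         [0, 0, 0, 0, 1, 1, 1, 1]])
```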

wenet/chunkformer/encoder_layer copy.py

Lines changed: 0 additions & 131 deletions
This file was deleted.
