Commit 24a19c5
Author: khanld
[feat] add masked batch and limited context decoding
1 parent: 4739a05

File tree: 7 files changed, +114 -45 lines

wenet/bin/recognize.py

Lines changed: 1 addition & 1 deletion
@@ -281,7 +281,7 @@ def main():
     with torch.no_grad():
         for batch_idx, batch in enumerate(test_data_loader):
             keys = batch["keys"]
-            feats = batch["feats"].to(device)
+            feats = batch["feats"].to(device) if type(batch["feats"]) is torch.Tensor else batch["feats"]
             target = batch["target"].to(device)
             feats_lengths = batch["feats_lengths"].to(device)
             target_lengths = batch["target_lengths"].to(device)
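Why this guard is needed: when the new pad_feat option (added in wenet/dataset/dataset.py below) is turned off, batch["feats"] arrives as a Python list of per-utterance tensors rather than one padded tensor, and a list has no .to(device). A minimal sketch of the same dispatch (using isinstance rather than the commit's `type(...) is` check; shapes are illustrative):

import torch

def move_feats(feats, device):
    # tensors are moved wholesale; lists are passed through untouched,
    # since the chunked encoder moves the data to the device itself
    if isinstance(feats, torch.Tensor):
        return feats.to(device)  # padded batch: (B, T, D)
    return feats                 # masked batch: list of (T_i, D) tensors

print(move_feats(torch.randn(2, 100, 80), "cpu").shape)  # torch.Size([2, 100, 80])
print(len(move_feats([torch.randn(80, 80), torch.randn(100, 80)], "cpu")))  # 2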

wenet/chunkformer/attention.py

Lines changed: 18 additions & 12 deletions
@@ -103,10 +103,12 @@ def forward(self, query: torch.Tensor,
             k = torch.cat([key_cache, k], dim=2)
             v = torch.cat([value_cache, v], dim=2)
 
-        # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
-        # non-trivial to calculate `next_cache_start` here.
-
-        new_cache = torch.cat((k, v), dim=-1)
+            # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
+            # non-trivial to calculate `next_cache_start` here.
+            new_cache = torch.cat((k, v), dim=-1)
+        else:
+            # streaming long-form transcription is disabled if the input cache is empty; only long-form transcription and masked batch are supported
+            new_cache = cache
 
         n_batch_pos = pos_emb.size(0)
         p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)

@@ -139,7 +141,7 @@ def forward_parallel_chunk(self, query: torch.Tensor,
                            key: torch.Tensor, value: torch.Tensor,
                            mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
                            pos_emb: torch.Tensor = torch.empty(0),
-                           cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
+                           cache: torch.Tensor = torch.zeros((0, 0, 0)),
                            right_context_size: int = 0,
                            left_context_size: int = 0,
                            truncated_context_size: int = 0

@@ -153,20 +155,20 @@ def forward_parallel_chunk(self, query: torch.Tensor,
                 (#batch, time1, time2), (0, 0, 0) means fake mask.
             pos_emb (torch.Tensor): Positional embedding tensor
                 (#batch, time2, size).
-            cache (torch.Tensor): Cache tensor (B, 1, head, cache_t, d_k * 2),
-                where `cache_t == chunk_size * num_decoding_left_chunks`
+            cache (torch.Tensor): Cache tensor (cache_t, head, d_k * 2),
+                where `cache_t == left_context_size`
                 and `head * d_k == size`
         Returns:
             torch.Tensor: Output tensor (#batch, time1, d_model).
-            torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
-                where `cache_t == chunk_size * num_decoding_left_chunks`
+            torch.Tensor: Cache tensor (cache_t, head, d_k * 2)
+                where `cache_t == left_context_size`
                 and `head * d_k == size`
         """
         q, k, v = self.forward_qkv(query, key, value)
 
         q = q.transpose(1, 2)  # (batch, time1, head, d_k)
-
-        if cache.size(2) <= 0:
+        cache_t = cache.size(0)
+        if cache_t == 0:
             cache = torch.zeros((left_context_size, self.h, self.d_k * 2), device=q.device, dtype=q.dtype)
 
         kv = torch.cat([k, v], dim=-1)  # (B, head, time1, d_k * 2)

@@ -175,7 +177,11 @@ def forward_parallel_chunk(self, query: torch.Tensor,
 
         #----------Overlapping Chunk Transformation-----------------------------------
         kv = torch.cat([cache, kv], dim=0)
-        new_cache = kv[:truncated_context_size + cache.size(0)][-cache.size(0):].cpu()
+
+        if cache_t > 0:
+            new_cache = kv[:truncated_context_size + cache.size(0)][-cache.size(0):]
+        else:
+            new_cache = torch.zeros((0, 0, 0), device=q.device, dtype=q.dtype)
         kv = torch.nn.functional.pad(kv, (0, 0, 0, 0, 0, right_context_size))
         kv = kv.unfold(0, left_context_size + q.shape[1] + right_context_size, q.shape[1])
         #-----------------------------------------------------------------------------
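For orientation, a toy sketch (invented sizes, not the repo's tensors) of the Overlapping Chunk Transformation the new cache logic feeds: prepend the left-context cache, zero-pad the right context, then let unfold cut the flat time axis into overlapping windows, one per chunk:

import torch

chunk, left, right = 4, 2, 2
kv = torch.arange(12, dtype=torch.float32).unsqueeze(-1)  # (time, feat): 3 chunks of 4
cache = torch.zeros(left, 1)                              # fake (empty-valued) left-context cache
kv = torch.cat([cache, kv], dim=0)                        # (left + time, feat)
kv = torch.nn.functional.pad(kv, (0, 0, 0, right))        # zero right context
windows = kv.unfold(0, left + chunk + right, chunk)       # (n_chunks, feat, window)
print(windows.shape)  # torch.Size([3, 1, 8]): each chunk sees left and right context

With a real (non-empty) cache, the slice kv[:truncated_context_size + cache.size(0)][-cache.size(0):] above keeps exactly the frames the next long-form call will need as left context.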

wenet/chunkformer/convolution.py

Lines changed: 16 additions & 10 deletions
@@ -185,26 +185,27 @@ def forward_parallel_chunk(
         self,
         x: torch.Tensor,
         mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
-        cache: torch.Tensor = torch.zeros((0, 0, 0)),
+        cache: torch.Tensor = torch.zeros((0, 0)),
         truncated_context_size: int = 0
 
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Compute convolution module.
         Args:
-            x (torch.Tensor): Input tensor (#batch, time, channels).
+            x (torch.Tensor): Input tensor (time, channels).
             mask_pad (torch.Tensor): used for batch padding (#batch, 1, time),
                 (0, 0, 0) means fake mask.
             cache (torch.Tensor): left context cache, it is only
-                used in causal convolution (#batch, channels, cache_t),
-                (0, 0, 0) means fake cache.
+                used in causal convolution (channels, cache_t),
+                (0, 0) means fake cache.
         Returns:
-            torch.Tensor: Output tensor (#batch, time, channels).
+            torch.Tensor: Output tensor (time, channels).
         """
         # exchange the temporal dimension and the feature dimension
         x = x.transpose(1, 2)  # (#batch, channels, time)
         lorder = self.kernel_size // 2
         chunk_size = x.shape[-1]
-        if cache.size(0) == 0:
+        cache_t = cache.size(-1)
+        if cache_t == 0:
             cache = torch.zeros(self.channels, lorder).to(x.device)
         # GLU mechanism
         x = self.pointwise_conv1(x)  # (batch, 2*channel, dim)

@@ -213,9 +214,15 @@ def forward_parallel_chunk(
         #----------Overlapping Chunk Transformation-----------------------------------
         x = x.transpose(0, 1).reshape(self.channels, -1)  # [C, n_chunk * T]
         x = torch.cat([cache, x], dim=-1)
-        new_cache = x[:, :truncated_context_size + cache.size(-1)][:, -cache.size(-1):].cpu()
+
+        # streaming long-form transcription is disabled if the input cache is empty; only long-form transcription and masked batch are supported
+        if cache_t > 0:
+            new_cache = x[:, :truncated_context_size + cache.size(-1)][:, -cache.size(-1):]
+        else:
+            new_cache = torch.zeros((0, 0))
+
         x = nn.functional.pad(x, (0, lorder), 'constant', 0.0)
-        x = x.unfold(-1, chunk_size + 2 * lorder, chunk_size).transpose(0, 1)  # [n_chunk + 1, C, cnn_cache_size]
+        x = x.unfold(-1, chunk_size + 2 * lorder, chunk_size).transpose(0, 1)  # [n_chunk + 1, C, chunk_size + 2 * lorder]
         #-----------------------------------------------------------------------------
 
         if mask_pad.size(2) > 0:  # time > 0

@@ -232,7 +239,6 @@ def forward_parallel_chunk(
         x = self.pointwise_conv2(x)
         # mask batch padding
         if mask_pad.size(2) > 0:  # time > 0
-            # x.masked_fill_(~mask_pad[:, :, self.lorder:], 0.0)
-            x.masked_fill_(~mask_pad[:, :, self.lorder:-self.lorder], 0.0)
+            x.masked_fill_(~mask_pad[:, :, lorder:-lorder], 0.0)
 
         return x.transpose(1, 2), new_cache
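A hedged sketch of the cache bookkeeping above (sizes invented): with a non-empty left-context cache, i.e. streaming long-form decoding, the next cache is the trailing cache_t columns of the truncated context; with an empty cache, i.e. masked batch decoding, no state is carried between calls. Dropping the old .cpu() call also keeps the cache on the compute device.

import torch

channels, cache_t, truncated_context_size = 4, 3, 8
x = torch.randn(channels, 20)           # (C, n_chunk * T) after the reshape
cache = torch.randn(channels, cache_t)  # non-empty left-context conv cache

x = torch.cat([cache, x], dim=-1)       # (C, cache_t + n_chunk * T)
if cache_t > 0:
    # trailing cache_t columns of the truncated context become the next cache
    new_cache = x[:, :truncated_context_size + cache_t][:, -cache_t:]
else:
    new_cache = torch.zeros((0, 0))     # empty cache in, no state carried out
print(new_cache.shape)                  # torch.Size([4, 3])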

wenet/chunkformer/encoder.py

Lines changed: 71 additions & 16 deletions
@@ -177,22 +177,62 @@ def __init__(
             ) for _ in range(num_blocks)
         ])
 
+    def forward(self,
+                xs: torch.Tensor,
+                xs_lens: torch.Tensor,
+                decoding_chunk_size: int = 0,
+                num_decoding_left_chunks: int = -1,
+                **kwargs):
+        """
+        Main forward function that dispatches to either the standard
+        forward pass or the parallel chunk version, based on the
+        model's training mode.
+        """
+        if self.training:
+            return super().forward(
+                xs=xs,
+                xs_lens=xs_lens,
+                decoding_chunk_size=decoding_chunk_size,
+                num_decoding_left_chunks=num_decoding_left_chunks,
+                **kwargs
+            )  # call the parent class's forward method
+        else:
+            if decoding_chunk_size > 0 and num_decoding_left_chunks > 0:
+                # if both decoding_chunk_size and num_decoding_left_chunks
+                # are set, use the parallel chunk version
+                return self.forward_parallel_chunk(
+                    xs=xs,
+                    xs_origin_lens=xs_lens,
+                    chunk_size=decoding_chunk_size,
+                    left_context_size=num_decoding_left_chunks,
+                    right_context_size=num_decoding_left_chunks,  # we assume left and right context sizes are the same
+                )
+            else:
+                # otherwise, use the standard forward pass
+                return super().forward(
+                    xs=xs,
+                    xs_lens=xs_lens,
+                    decoding_chunk_size=decoding_chunk_size,
+                    num_decoding_left_chunks=num_decoding_left_chunks,
+                    **kwargs
+                )
+
     def forward_parallel_chunk(
         self,
         xs,
         xs_origin_lens,
         chunk_size: int = -1,
         left_context_size: int = -1,
         right_context_size: int = -1,
-        att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
-        cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
+        att_cache: torch.Tensor = torch.zeros((0, 0, 0)),
+        cnn_cache: torch.Tensor = torch.zeros((0, 0)),
         truncated_context_size: int = 0,
         offset: torch.Tensor = torch.zeros(0),
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Embed positions in tensor.
 
         Args:
-            xs: padded input tensor (B, T, D)
+            xs: list of B input tensors (T, D)
             xs_lens: input length (B)
             decoding_chunk_size: decoding chunk size for dynamic chunk
                 0: default for training, use random dynamic chunk.

@@ -208,10 +248,11 @@ def forward_parallel_chunk(
             masks: torch.Tensor batch padding mask after subsample
                 (B, 1, T' ~= T/subsample_rate)
         """
-        assert offset.shape[0] == len(xs), f"{offset.shape[0]} - {len(xs)}"
+        if offset.shape[0] == 0:
+            offset = torch.zeros(len(xs), dtype=torch.long, device=xs_origin_lens.device)
 
         # --------------------------Chunk Batching-------------------------------------------
-        subsampling = self.embed.subsampling_factor
+        subsampling = self.embed.subsampling_rate
         context = self.embed.right_context + 1  # Add current frame
         size = (chunk_size - 1) * subsampling + context
         step = subsampling * chunk_size

@@ -258,6 +299,7 @@ def forward_parallel_chunk(
 
         xs = torch.cat(x_pad, dim=0).to(device)
         xs_lens = torch.tensor(xs_lens).to(device)
+        masks = ~make_pad_mask(xs_lens, xs.size(1)).unsqueeze(1)  # (B, 1, T)
         upper_bounds = torch.cat(upper_bounds).unsqueeze(1).to(device)
         lower_bounds = torch.cat(lower_bounds).unsqueeze(1).to(device)
         upper_bounds_conv = torch.cat(upper_bounds_conv).unsqueeze(1).to(device)

@@ -269,9 +311,7 @@ def forward_parallel_chunk(
             xs = self.global_cmvn(xs)
 
 
-        xs, pos_emb, xs_lens = self.embed(xs, xs_lens, offset=left_context_size, right_context_size=right_context_size)
-        masks = ~make_pad_mask(xs_lens, xs.size(1)).unsqueeze(1)  # (B, 1, T)
-
+        xs, pos_emb, masks = self.embed(xs, masks, offset=left_context_size, right_context_size=right_context_size)
 
         mask_pad = torch.arange(0, conv_lorder + chunk_size + conv_lorder, device=masks.device).unsqueeze(0).repeat(xs.size(0), 1)  # [B, left_context_size + chunk_size]
         mask_pad = (lower_bounds_conv <= mask_pad) & (mask_pad < upper_bounds_conv)

@@ -280,7 +320,6 @@ def forward_parallel_chunk(
         att_mask = (lower_bounds <= att_mask) & (att_mask < upper_bounds)
         att_mask = att_mask.flip(-1).unsqueeze(1)
 
-
         r_att_cache = []
         r_cnn_cache = []
         for i, layer in enumerate(self.encoders):

@@ -296,19 +335,35 @@ def forward_parallel_chunk(
             r_att_cache.append(new_att_cache)
             r_cnn_cache.append(new_cnn_cache)
 
-        del att_cache
-        del cnn_cache
         if self.normalize_before:
             xs = self.after_norm(xs)
 
-        xs_lens = self.embed.calc_length(xs_origin_lens)
-        offset += xs_lens
-
 
         # NOTE(xcsong): shape(r_att_cache) is (elayers, head, ?, d_k * 2),
         #   ? may be larger than cache_t1, it depends on required_cache_size
         r_att_cache = torch.stack(r_att_cache, dim=0)
         # NOTE(xcsong): shape(r_cnn_cache) is (e, b=1, hidden-dim, cache_t2)
         r_cnn_cache = torch.stack(r_cnn_cache, dim=0)
-        return xs, xs_lens, n_chunks, r_att_cache, r_cnn_cache, offset
-
+
+        # there would be no need to reconstruct (pad) in greedy search,
+        # but we reconstruct here for compatibility with WeNet
+        xs_lens = self.embed.calc_length(xs_origin_lens)
+        xs, masks = self.reconstruct(xs, xs_lens, n_chunks)
+        offset += xs_lens
+
+        return xs, masks
+
+    def reconstruct(
+        self,
+        xs,
+        xs_lens,
+        n_chunks
+    ):
+        xs = xs.split(n_chunks, dim=0)
+        xs = [x.reshape(-1, self._output_size)[:x_len] for x, x_len in zip(xs, xs_lens)]
+
+        xs = torch.nn.utils.rnn.pad_sequence(xs,
+                                             batch_first=True,
+                                             padding_value=0)
+        masks = ~make_pad_mask(xs_lens, xs.size(1)).unsqueeze(1).to(xs.device)  # (B, 1, T)
+        return xs, masks
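A standalone sketch of what reconstruct() does (toy sizes; self._output_size and make_pad_mask are replaced by plain stand-ins): the encoder batches all chunks of all utterances together as (sum(n_chunks), chunk_size, D), so the output is split back per utterance, un-chunked along time, trimmed to the true length, and re-padded into the WeNet-style (B, T, D) layout:

import torch
from torch.nn.utils.rnn import pad_sequence

D = 8
n_chunks = [3, 2]                      # chunks per utterance
xs_lens = torch.tensor([10, 7])        # true frame counts after subsampling
xs = torch.randn(sum(n_chunks), 4, D)  # (total_chunks, chunk_size=4, D)

per_utt = xs.split(n_chunks, dim=0)    # [(3, 4, D), (2, 4, D)]
per_utt = [x.reshape(-1, D)[:l] for x, l in zip(per_utt, xs_lens)]
out = pad_sequence(per_utt, batch_first=True, padding_value=0)
print(out.shape)  # torch.Size([2, 10, 8])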

wenet/dataset/dataset.py

Lines changed: 5 additions & 3 deletions
@@ -132,24 +132,26 @@ def Dataset(data_type,
 
     batch_conf = conf.get('batch_conf', {})
     batch_type = batch_conf.get('batch_type', 'static')
+    pad_feat = batch_conf.get('pad_feat', 'True')
+
     assert batch_type in ['static', 'bucket', 'dynamic']
     if batch_type == 'static':
         assert 'batch_size' in batch_conf
         batch_size = batch_conf.get('batch_size', 16)
-        dataset = dataset.batch(batch_size, wrapper_class=processor.padding)
+        dataset = dataset.batch(batch_size, wrapper_class=lambda batch: processor.padding(batch, pad_feat))
     elif batch_type == 'bucket':
         assert 'bucket_boundaries' in batch_conf
         assert 'bucket_batch_sizes' in batch_conf
         dataset = dataset.bucket_by_sequence_length(
             processor.feats_length_fn,
             batch_conf['bucket_boundaries'],
             batch_conf['bucket_batch_sizes'],
-            wrapper_class=processor.padding)
+            wrapper_class=lambda batch: processor.padding(batch, pad_feat))
     else:
         max_frames_in_batch = batch_conf.get('max_frames_in_batch', 12000)
         dataset = dataset.dynamic_batch(
             processor.DynamicBatchWindow(max_frames_in_batch),
-            wrapper_class=processor.padding,
+            wrapper_class=lambda batch: processor.padding(batch, pad_feat)
         )
 
     return dataset
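How the flag is presumably driven from a config (the YAML fragment in the comment is an assumption for illustration, not part of this commit). Note the committed default is the string 'True', which is truthy, so padding stays on unless the config overrides it; the lambda freezes pad_feat at dataset-construction time, so every batch wrapper sees the same value:

# assumed YAML, illustrative only:
#   dataset_conf:
#     batch_conf:
#       batch_type: 'static'
#       batch_size: 16
#       pad_feat: false    # yield "feats" as a list for masked batch decoding

batch_conf = {'batch_type': 'static', 'batch_size': 16, 'pad_feat': False}
pad_feat = batch_conf.get('pad_feat', 'True')  # committed default: truthy string 'True'

def padding(batch, pad_feat=True):  # stand-in for processor.padding
    return {'pad_feat': bool(pad_feat), 'batch_size': len(batch)}

wrapper = lambda batch: padding(batch, pad_feat)
print(wrapper([0, 1, 2]))  # {'pad_feat': False, 'batch_size': 3}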

wenet/dataset/processor.py

Lines changed: 2 additions & 2 deletions
@@ -527,7 +527,7 @@ def spec_trim(sample, max_t=20):
     return sample
 
 
-def padding(data):
+def padding(data, pad_feat=True):
     """ Padding the data into training data
 
     Args:

@@ -565,7 +565,7 @@ def padding(data):
 
     batch = {
         "keys": sorted_keys,
-        "feats": padded_feats,
+        "feats": padded_feats if pad_feat else sorted_feats,
         "target": padding_labels,
         "feats_lengths": feats_lengths,
         "target_lengths": label_lengths,

wenet/transformer/asr_model.py

Lines changed: 1 addition & 1 deletion
@@ -301,7 +301,7 @@ def decode(
 
         Returns: dict results of all decoding methods
         """
-        assert speech.shape[0] == speech_lengths.shape[0]
+        assert len(speech) == len(speech_lengths)
         assert decoding_chunk_size != 0
         encoder_out, encoder_mask = self._forward_encoder(
             speech, speech_lengths, decoding_chunk_size,
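Why the looser assert still checks the same invariant (a sketch, not repo code): len() yields the batch size for both a padded (B, T, D) tensor and a list of B tensors, whereas .shape exists only on the former:

import torch

speech_lengths = torch.tensor([80, 100])
speech_tensor = torch.randn(2, 100, 80)                    # padded batch
speech_list = [torch.randn(80, 80), torch.randn(100, 80)]  # masked batch

assert len(speech_tensor) == len(speech_lengths)  # size of dim 0 -> 2
assert len(speech_list) == len(speech_lengths)    # list length -> 2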
