@@ -177,22 +177,62 @@ def __init__(
177177 ) for _ in range (num_blocks )
178178 ])
179179
180+ def forward (self ,
181+ xs : torch .Tensor ,
182+ xs_lens : torch .Tensor ,
183+ decoding_chunk_size : int = 0 ,
184+ num_decoding_left_chunks : int = - 1 ,
185+ ** kwargs ):
186+ """
187+ Main forward function that dispatches to either the standard
188+ forward pass or the parallel chunk version based on the
189+ model's training mode.
190+ """
191+ if self .training :
192+ return super ().forward (
193+ xs = xs ,
194+ xs_lens = xs_lens ,
195+ decoding_chunk_size = decoding_chunk_size ,
196+ num_decoding_left_chunks = num_decoding_left_chunks ,
197+ ** kwargs
198+ ) # Call the parent class's forward method
199+ else :
200+ if decoding_chunk_size > 0 and num_decoding_left_chunks > 0 :
201+ # If both decoding_chunk_size and num_decoding_left_chunks
202+ # are set, use the parallel chunk version
203+ return self .forward_parallel_chunk (
204+ xs = xs ,
205+ xs_origin_lens = xs_lens ,
206+ chunk_size = decoding_chunk_size ,
207+ left_context_size = num_decoding_left_chunks ,
208+ right_context_size = num_decoding_left_chunks , # we assume left and right context are the same
209+ )
210+ else :
211+ # Otherwise, use the standard forward pass
212+ return super ().forward (
213+ xs = xs ,
214+ xs_lens = xs_lens ,
215+ decoding_chunk_size = decoding_chunk_size ,
216+ num_decoding_left_chunks = num_decoding_left_chunks ,
217+ ** kwargs
218+ )
219+
180220 def forward_parallel_chunk (
181221 self ,
182222 xs ,
183223 xs_origin_lens ,
184224 chunk_size : int = - 1 ,
185225 left_context_size : int = - 1 ,
186226 right_context_size : int = - 1 ,
187- att_cache : torch .Tensor = torch .zeros ((0 , 0 , 0 , 0 )),
188- cnn_cache : torch .Tensor = torch .zeros ((0 , 0 , 0 , 0 )),
227+ att_cache : torch .Tensor = torch .zeros ((0 , 0 , 0 )),
228+ cnn_cache : torch .Tensor = torch .zeros ((0 , 0 )),
189229 truncated_context_size :int = 0 ,
190230 offset : torch .Tensor = torch .zeros (0 ),
191231 ) -> Tuple [torch .Tensor , torch .Tensor ]:
192232 """Encode the input in parallel chunks (streaming inference).
193233
194234 Args:
195- xs: padded input tensor (B, T, D)
235+ xs: list of B input tensors, each of shape (T, D)
196236 xs_origin_lens: input lengths (B)
197237 chunk_size: decoding chunk size for dynamic chunk
198238 0: default for training, use random dynamic chunk.
@@ -208,10 +248,11 @@ def forward_parallel_chunk(
208248 masks: torch.Tensor batch padding mask after subsample
209249 (B, 1, T' ~= T/subsample_rate)
210250 """
211- assert offset .shape [0 ] == len (xs ), f"{ offset .shape [0 ]} - { len (xs )} "
251+ if offset .shape [0 ] == 0 :
252+ offset = torch .zeros (len (xs ), dtype = torch .long , device = xs_origin_lens .device )
212253
213254 # --------------------------Chunk Batching-------------------------------------------
214- subsampling = self .embed .subsampling_factor
255+ subsampling = self .embed .subsampling_rate
215256 context = self .embed .right_context + 1 # Add current frame
216257 size = (chunk_size - 1 ) * subsampling + context
217258 step = subsampling * chunk_size
@@ -258,6 +299,7 @@ def forward_parallel_chunk(
258299
259300 xs = torch .cat (x_pad , dim = 0 ).to (device )
260301 xs_lens = torch .tensor (xs_lens ).to (device )
302+ masks = ~ make_pad_mask (xs_lens , xs .size (1 )).unsqueeze (1 ) # (B, 1, T)
261303 upper_bounds = torch .cat (upper_bounds ).unsqueeze (1 ).to (device )
262304 lower_bounds = torch .cat (lower_bounds ).unsqueeze (1 ).to (device )
263305 upper_bounds_conv = torch .cat (upper_bounds_conv ).unsqueeze (1 ).to (device )
@@ -269,9 +311,7 @@ def forward_parallel_chunk(
269311 xs = self .global_cmvn (xs )
270312
271313
272- xs , pos_emb , xs_lens = self .embed (xs , xs_lens , offset = left_context_size , right_context_size = right_context_size )
273- masks = ~ make_pad_mask (xs_lens , xs .size (1 )).unsqueeze (1 ) # (B, 1, T)
274-
314+ xs , pos_emb , masks = self .embed (xs , masks , offset = left_context_size , right_context_size = right_context_size )
275315
276316 mask_pad = torch .arange (0 , conv_lorder + chunk_size + conv_lorder , device = masks .device ).unsqueeze (0 ).repeat (xs .size (0 ), 1 ) # [B, left_context_size + chunksize]
277317 mask_pad = (lower_bounds_conv <= mask_pad ) & (mask_pad < upper_bounds_conv )
@@ -280,7 +320,6 @@ def forward_parallel_chunk(
280320 att_mask = (lower_bounds <= att_mask ) & (att_mask < upper_bounds )
281321 att_mask = att_mask .flip (- 1 ).unsqueeze (1 )
282322
283-
284323 r_att_cache = []
285324 r_cnn_cache = []
286325 for i , layer in enumerate (self .encoders ):
@@ -296,19 +335,35 @@ def forward_parallel_chunk(
296335 r_att_cache .append (new_att_cache )
297336 r_cnn_cache .append (new_cnn_cache )
298337
299- del att_cache
300- del cnn_cache
301338 if self .normalize_before :
302339 xs = self .after_norm (xs )
303340
304- xs_lens = self .embed .calc_length (xs_origin_lens )
305- offset += xs_lens
306-
307341
308342 # NOTE(xcsong): shape(r_att_cache) is (elayers, head, ?, d_k * 2),
309343 # ? may be larger than cache_t1, it depends on required_cache_size
310344 r_att_cache = torch .stack (r_att_cache , dim = 0 )
311345 # NOTE(xcsong): shape(r_cnn_cache) is (e, b=1, hidden-dim, cache_t2)
312346 r_cnn_cache = torch .stack (r_cnn_cache , dim = 0 )
313- return xs , xs_lens , n_chunks , r_att_cache , r_cnn_cache , offset
314-
347+
348+ # It would be no need to reconstruct (padding) in greedy search
349+ # but for compatibility with Wenet, we reconstruct it here
350+ xs_lens = self .embed .calc_length (xs_origin_lens )
351+ xs , masks = self .reconstruct (xs , xs_lens , n_chunks )
352+ offset += xs_lens
353+
354+ return xs , masks
355+
356+ def reconstruct (
357+ self ,
358+ xs ,
359+ xs_lens ,
360+ n_chunks
361+ ):
362+ xs = xs .split (n_chunks , dim = 0 )
363+ xs = [x .reshape (- 1 , self ._output_size )[:x_len ] for x , x_len in zip (xs , xs_lens )]
364+
365+ xs = torch .nn .utils .rnn .pad_sequence (xs ,
366+ batch_first = True ,
367+ padding_value = 0 )
368+ masks = ~ make_pad_mask (xs_lens , xs .size (1 )).unsqueeze (1 ).to (xs .device ) # (B, 1, T)
369+ return xs , masks
0 commit comments