 
 import torch
 from torch import nn
-from wenet.utils.class_utils import WENET_NORM_CLASSES
-
 
 class ChunkConvolutionModule(nn.Module):
     """ConvolutionModule in ChunkFormer model."""
@@ -58,11 +56,13 @@ def __init__(self,
             padding = 0
             self.lorder = kernel_size - 1
         elif dynamic_conv:
+            # kernel_size should be an odd number for non-causal convolution
             assert (kernel_size - 1) % 2 == 0
             padding = 0
-            self.lorder = (kernel_size - 1)// 2
+            self.lorder = (kernel_size - 1) // 2
         else:
             # kernel_size should be an odd number for non-causal convolution
+            assert (kernel_size - 1) % 2 == 0
             padding = (kernel_size - 1) // 2
             self.lorder = 0
         self.depthwise_conv = nn.Conv1d(
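Side note on the padding arithmetic in these branches: the causal branch left-pads by `kernel_size - 1` at runtime, the dynamic branch keeps `padding = 0` and instead carries `lorder = (kernel_size - 1) // 2` context frames per chunk, and the default branch pads symmetrically inside the conv. A minimal sketch (illustrative sizes, not code from this PR) showing that each choice keeps the output length equal to the input length:

```python
import torch
from torch import nn

kernel_size = 15
x = torch.randn(1, 8, 100)  # (batch, channels, time)

# causal: left-pad by kernel_size - 1 outside the conv, padding=0 inside
causal = nn.Conv1d(8, 8, kernel_size, padding=0, groups=8)
y = causal(nn.functional.pad(x, (kernel_size - 1, 0)))
assert y.size(2) == x.size(2)

# non-causal: symmetric padding of (kernel_size - 1) // 2 inside the conv
sym = nn.Conv1d(8, 8, kernel_size, padding=(kernel_size - 1) // 2, groups=8)
assert sym(x).size(2) == x.size(2)
```

The `assert` added in the `else` branch makes the odd-kernel requirement explicit for the symmetric case too, instead of silently producing an off-by-one output length.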
@@ -92,7 +92,7 @@ def __init__(self,
             bias=bias,
         )
         self.activation = activation
-
+
     def forward(
         self,
         x: torch.Tensor,
@@ -145,13 +145,17 @@ def forward(
             size = self.lorder + decoding_chunk_size
             step = decoding_chunk_size
 
-            n_frames_pad = (step - ((x.size(2) - size) % step)) % step
-            x = torch.nn.functional.pad(x, (0, n_frames_pad))  # (batch, 2*channel, dim + n_frames_pad)
+            n_frames_pad = (step - ((x.size(2) - size) % step)) % step
+            # (batch, 2*channel, dim + n_frames_pad)
+            x = torch.nn.functional.pad(x, (0, n_frames_pad))
 
             n_chunks = ((x.size(2) - size) // step) + 1
-            x = x.unfold(-1, size=size, step=step)  # [B, C, n_chunks, size]
-            x = x.transpose(1, 2)  # [B, n_chunks, C, size]
-            x = x.reshape(-1, x.size(2), x.size(3))  # [B * n_chunks, C, size]
+            # [B, C, n_chunks, size]
+            x = x.unfold(-1, size=size, step=step)
+            # [B, n_chunks, C, size]
+            x = x.transpose(1, 2)
+            # [B * n_chunks, C, size]
+            x = x.reshape(-1, x.size(2), x.size(3))
 
             # pad right for dynamic conv
             x = nn.functional.pad(x, (0, self.lorder), 'constant', 0.0)
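The `n_frames_pad` formula rounds the time axis up to the next whole window: it is the smallest right-padding that makes `(T - size)` divisible by `step`, so `unfold` covers every frame. A small sketch with assumed sizes (not from this diff) checking the resulting chunk count:

```python
import torch

lorder, decoding_chunk_size = 7, 16
size, step = lorder + decoding_chunk_size, decoding_chunk_size

x = torch.randn(2, 8, 100)  # (batch, channels, time)
n_frames_pad = (step - ((x.size(2) - size) % step)) % step  # -> 3
x = torch.nn.functional.pad(x, (0, n_frames_pad))           # time: 103

n_chunks = ((x.size(2) - size) // step) + 1                 # -> 6
x = x.unfold(-1, size=size, step=step)                      # [B, C, n_chunks, size]
assert x.shape == (2, 8, n_chunks, size)
```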
@@ -161,19 +165,22 @@ def forward(
         x = self.depthwise_conv(x)
 
         if self.dynamic_conv:
-            # x size: [B * n_chunk, C, decoding_chunk_size]
-            x = x.reshape(-1, n_chunks, x.size(1), x.size(2))  # [B, n_chunk, C, decoding_chunk_size]
-            x = x.transpose(1, 2)  # [B, C, n_chunks, decoding_chunk_size]
-            x = x.reshape(x.size(0), x.size(1), -1)  # [B, C, n_chunks * decoding_chunk_size]
-            x = x[..., :x.size(2) - n_frames_pad]  # (batch, channel, dim)
+            # [B, n_chunk, C, decoding_chunk_size]
+            x = x.reshape(-1, n_chunks, x.size(1), x.size(2))
+            # [B, C, n_chunks, decoding_chunk_size]
+            x = x.transpose(1, 2)
+            # [B, C, n_chunks * decoding_chunk_size]
+            x = x.reshape(x.size(0), x.size(1), -1)
+            # remove padding
+            x = x[..., :x.size(2) - n_frames_pad]
 
         if self.use_layer_norm:
             x = x.transpose(1, 2)
         x = self.activation(self.norm(x))
         if self.use_layer_norm:
             x = x.transpose(1, 2)
         x = self.pointwise_conv2(x)
-
+
         # mask batch padding
         if mask_pad.size(2) > 0:  # time > 0
             x.masked_fill_(~mask_pad.to(torch.bool), 0.0)
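After the depthwise conv consumes the `lorder` frames of right context, each chunk is back to `decoding_chunk_size` frames, and this reshape sequence is the exact inverse of the earlier split; the final slice drops the `n_frames_pad` frames added before unfolding. A shape-only sketch under assumed sizes (none of these values come from the PR):

```python
import torch

B, C, n_chunks, T_chunk, n_frames_pad = 2, 8, 6, 16, 3
x = torch.randn(B * n_chunks, C, T_chunk)  # output of the depthwise conv

x = x.reshape(B, n_chunks, C, T_chunk)   # [B, n_chunks, C, T_chunk]
x = x.transpose(1, 2)                    # [B, C, n_chunks, T_chunk]
x = x.reshape(B, C, -1)                  # [B, C, n_chunks * T_chunk]
x = x[..., :x.size(2) - n_frames_pad]    # drop the right padding
assert x.shape == (B, C, n_chunks * T_chunk - n_frames_pad)
```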
@@ -202,7 +209,7 @@ def forward_parallel_chunk(
         """
         # exchange the temporal dimension and the feature dimension
         x = x.transpose(1, 2)  # (#batch, channels, time)
-        lorder = self.kernel_size // 2
+        lorder = self.kernel_size // 2
         chunk_size = x.shape[-1]
         cache_t = cache.size(-1)
         if cache_t == 0:
@@ -211,19 +218,21 @@ def forward_parallel_chunk(
         x = self.pointwise_conv1(x)  # (batch, 2*channel, dim)
         x = nn.functional.glu(x, dim=1)  # (batch, channel, dim)
 
-        #----------Overlapping Chunk Transformation-----------------------------------
-        x = x.transpose(0, 1).reshape( self.channels, -1)  # [C, n_chunk * T]
+        # ----------Overlapping Chunk Transformation----------------------------------
+        x = x.transpose(0, 1).reshape(self.channels, -1)  # [C, n_chunk * T]
         x = torch.cat([cache, x], dim=-1)
 
-        # streaming long-form transcription is disabled if input cache is empty, only support long-form transcription and masked batch
+        # Streaming long-form transcription is disabled if the input cache is empty
         if cache_t > 0:
-            new_cache = x[:, :truncated_context_size + cache.size(-1)][:, -cache.size(-1):]
+            new_cache = x[:, :truncated_context_size + cache.size(-1)]
+            new_cache = new_cache[:, -cache.size(-1):]
         else:
             new_cache = torch.zeros((0, 0))
 
         x = nn.functional.pad(x, (0, lorder), 'constant', 0.0)
-        x = x.unfold(-1, chunk_size + 2 * lorder, chunk_size).transpose(0, 1)  # [n_chunk +1, C, chunk_size + 2 * lorder]
-        #-----------------------------------------------------------------------------
+        x = x.unfold(-1, chunk_size + 2 * lorder, chunk_size).transpose(0, 1)
+        # [n_chunk + 1, C, chunk_size + 2 * lorder]
+        # ----------------------------------------------------------------------------
 
         if mask_pad.size(2) > 0:  # time > 0
             x = torch.where(mask_pad, x, 0)
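For context on the overlapping chunk transformation: the batch of chunks is flattened into one long sequence per channel, prefixed with the cache, right-padded by `lorder`, and re-windowed so that every window carries `2 * lorder` extra frames of context. A sketch with an assumed cache length of `chunk_size + lorder` frames, chosen purely so the window count works out to `n_chunk + 1` as in the shape comment above; the real cache size depends on the model configuration:

```python
import torch

C, chunk_size, lorder, n_chunk = 8, 16, 7, 4
cache = torch.zeros(C, chunk_size + lorder)  # assumed cache length
x = torch.randn(C, n_chunk * chunk_size)     # [C, n_chunk * T]
x = torch.cat([cache, x], dim=-1)

x = torch.nn.functional.pad(x, (0, lorder), 'constant', 0.0)
x = x.unfold(-1, chunk_size + 2 * lorder, chunk_size).transpose(0, 1)
assert x.shape == (n_chunk + 1, C, chunk_size + 2 * lorder)
```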