
Commit 55781fc

khanld committed
[refactor] flake8 refactor for chunkformer
1 parent 24a19c5 commit 55781fc

File tree: 8 files changed (+207, -146 lines)


wenet/bin/recognize.py

Lines changed: 4 additions & 1 deletion
@@ -281,7 +281,10 @@ def main():
     with torch.no_grad():
         for batch_idx, batch in enumerate(test_data_loader):
             keys = batch["keys"]
-            feats = batch["feats"].to(device) if type(batch["feats"]) is torch.Tensor else batch["feats"]
+            if type(batch["feats"]) is torch.Tensor:
+                feats = batch["feats"].to(device)
+            else:
+                feats = batch["feats"]
             target = batch["target"].to(device)
             feats_lengths = batch["feats_lengths"].to(device)
             target_lengths = batch["target_lengths"].to(device)
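
For reference, the new if/else keeps the semantics of the original one-line conditional: a batched feature tensor is moved to the device, anything else (for example a list of per-utterance tensors) is passed through untouched. A minimal standalone sketch of that pattern, using isinstance as linters generally prefer; the helper name and shapes below are illustrative, not part of the commit:

import torch

def move_feats(feats, device):
    # Hypothetical helper mirroring the hunk above: only move real tensors,
    # pass any other container (e.g. a list of tensors) through unchanged.
    if isinstance(feats, torch.Tensor):
        return feats.to(device)
    return feats

batch = {"feats": torch.randn(2, 100, 80)}   # (batch, frames, mel bins)
feats = move_feats(batch["feats"], torch.device("cpu"))
print(feats.shape)                           # torch.Size([2, 100, 80])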

wenet/chunkformer/attention.py

Lines changed: 33 additions & 28 deletions
@@ -1,7 +1,7 @@
 """Multi-Head Attention layer definition."""

 import math
-from typing import Tuple, Union
+from typing import Tuple

 import torch
 from torch import nn

@@ -107,7 +107,6 @@ def forward(self, query: torch.Tensor,
             # non-trivial to calculate `next_cache_start` here.
             new_cache = torch.cat((k, v), dim=-1)
         else:
-            # streaming long-form transcription is disabled if input cache is empty, only support long-form transcription and masked batch
             new_cache = cache

         n_batch_pos = pos_emb.size(0)

@@ -137,15 +136,18 @@ def forward(self, query: torch.Tensor,

         return self.forward_attention(v, scores, mask), new_cache

-    def forward_parallel_chunk(self, query: torch.Tensor,
-                               key: torch.Tensor, value: torch.Tensor,
-                               mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
-                               pos_emb: torch.Tensor = torch.empty(0),
-                               cache: torch.Tensor = torch.zeros((0, 0, 0)),
-                               right_context_size: int = 0,
-                               left_context_size: int = 0,
-                               truncated_context_size: int = 0
-                               ) -> Tuple[torch.Tensor, torch.Tensor]:
+    def forward_parallel_chunk(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
+        pos_emb: torch.Tensor = torch.empty(0),
+        cache: torch.Tensor = torch.zeros((0, 0, 0)),
+        right_context_size: int = 0,
+        left_context_size: int = 0,
+        truncated_context_size: int = 0
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
         Args:
             query (torch.Tensor): Query tensor (#batch, time1, size).

@@ -169,32 +171,39 @@ def forward_parallel_chunk(self, query: torch.Tensor,
         q = q.transpose(1, 2)  # (batch, time1, head, d_k)
         cache_t = cache.size(0)
         if cache_t == 0:
-            cache = torch.zeros((left_context_size, self.h, self.d_k * 2), device=q.device, dtype=q.dtype)
+            cache = torch.zeros(
+                (left_context_size, self.h, self.d_k * 2),
+                device=q.device, dtype=q.dtype
+            )
+        # (B, head, time1, d_k * 2),
+        kv = torch.cat([k, v], dim=-1)
+        # [n_chunk * chunk_size, head, F]
+        kv = kv.transpose(1, 2).reshape(-1, self.h, self.d_k * 2)

-        kv = torch.cat([k, v], dim=-1)  # (B, head, time1, d_k * 2),
-        kv = kv.transpose(1, 2).reshape(-1, self.h, self.d_k * 2)  # [n_chunk * chunk_size, head, F]

-
-        #----------Overlapping Chunk Transformation-----------------------------------
+        # ----------Overlapping Chunk Transformation-----------------------------------
         kv = torch.cat([cache, kv], dim=0)

         if cache_t > 0:
             new_cache = kv[:truncated_context_size + cache.size(0)][-cache.size(0):]
         else:
+            # Streaming long-form transcription is disabled if input cache is empty,
             new_cache = torch.zeros((0, 0, 0), device=q.device, dtype=q.dtype)
         kv = torch.nn.functional.pad(kv, (0, 0, 0, 0, 0, right_context_size))
-        kv = kv.unfold(0, left_context_size + q.shape[1] + right_context_size, q.shape[1])
-        #-----------------------------------------------------------------------------
-
+        kv = kv.unfold(
+            0,
+            left_context_size + q.shape[1] + right_context_size,
+            q.shape[1]
+        )
+        # -----------------------------------------------------------------------------

-        kv = kv.transpose(2, 3)  #[n_chunk + 1, head, F, left_context_size]
+        # [n_chunk + 1, head, F, left_context_size]
+        kv = kv.transpose(2, 3)
         k, v = torch.split(
             kv, kv.size(-1) // 2, dim=-1)
-
+
         # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
         # non-trivial to calculate `next_cache_start` here.
-
-
         n_batch_pos = pos_emb.size(0)
         p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
         p = p.transpose(1, 2)  # (batch, head, time1, d_k)

@@ -216,14 +225,10 @@ def forward_parallel_chunk(self, query: torch.Tensor,
         # (batch, head, time1, time2)
         matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))

-        # Remove rel_shift since it is useless in speech recognition,
-        # and it requires special attention for streaming.
-
+        # Add relative shift with right context inclusion,it can stream
         matrix_bd = self.rel_shift(matrix_bd, left_context_size, right_context_size)

         scores = (matrix_ac + matrix_bd) / math.sqrt(
             self.d_k)  # (batch, head, time1, time2)

-
         return self.forward_attention(v, scores, mask), new_cache
-
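
To make the pad + unfold calls above easier to follow, here is a toy, self-contained sketch of the overlapping-chunk idea: the flattened key/value frames are prefixed with the left-context cache, right-padded, and unfolded into windows of left_context + chunk_size + right_context frames with a stride of chunk_size, so every chunk sees its neighbours without a Python loop. All names and sizes below are illustrative, not the module's real shapes:

import torch

# Toy illustration of the "Overlapping Chunk Transformation" above.
chunk_size, left_context, right_context = 4, 2, 2
n_chunks, feat_dim = 3, 1

frames = torch.arange(n_chunks * chunk_size, dtype=torch.float32).unsqueeze(-1)  # [T, F]
cache = torch.zeros(left_context, feat_dim)                 # empty left-context cache

kv = torch.cat([cache, frames], dim=0)                      # prepend left context
kv = torch.nn.functional.pad(kv, (0, 0, 0, right_context))  # append right context
windows = kv.unfold(0, left_context + chunk_size + right_context, chunk_size)

print(windows.shape)  # torch.Size([3, 1, 8]): one window per chunk
print(windows[1, 0])  # frames 2..9: chunk 1 plus 2 frames of context on each side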

wenet/chunkformer/convolution.py

Lines changed: 31 additions & 22 deletions
@@ -19,8 +19,6 @@

 import torch
 from torch import nn
-from wenet.utils.class_utils import WENET_NORM_CLASSES
-

 class ChunkConvolutionModule(nn.Module):
     """ConvolutionModule in ChunkFormer model."""

@@ -58,11 +56,13 @@ def __init__(self,
             padding = 0
             self.lorder = kernel_size - 1
         elif dynamic_conv:
+            # kernel_size should be an odd number for none causal convolution
             assert (kernel_size - 1) % 2 == 0
             padding = 0
-            self.lorder = (kernel_size - 1)//2
+            self.lorder = (kernel_size - 1) // 2
         else:
             # kernel_size should be an odd number for none causal convolution
+            assert (kernel_size - 1) % 2 == 0
             padding = (kernel_size - 1) // 2
             self.lorder = 0
         self.depthwise_conv = nn.Conv1d(

@@ -92,7 +92,7 @@ def __init__(self,
             bias=bias,
         )
         self.activation = activation
-
+
     def forward(
         self,
         x: torch.Tensor,

@@ -145,13 +145,17 @@ def forward(
             size = self.lorder + decoding_chunk_size
             step = decoding_chunk_size

-            n_frames_pad = (step - ((x.size(2) - size) % step)) % step
-            x = torch.nn.functional.pad(x, (0, n_frames_pad))  # (batch, 2*channel, dim + n_frames_pad)
+            n_frames_pad = (step - ((x.size(2) - size) % step)) % step
+            # (batch, 2*channel, dim + n_frames_pad)
+            x = torch.nn.functional.pad(x, (0, n_frames_pad))

             n_chunks = ((x.size(2) - size) // step) + 1
-            x = x.unfold(-1, size=size, step=step)  # [B, C, n_chunks, size]
-            x = x.transpose(1, 2)  # [B, n_chunks, C, size]
-            x = x.reshape(-1, x.size(2), x.size(3))  # [B * n_chunks, C, size]
+            # [B, C, n_chunks, size]
+            x = x.unfold(-1, size=size, step=step)
+            # [B, n_chunks, C, size]
+            x = x.transpose(1, 2)
+            # [B * n_chunks, C, size]
+            x = x.reshape(-1, x.size(2), x.size(3))

             # pad right for dynamic conv
             x = nn.functional.pad(x, (0, self.lorder), 'constant', 0.0)

@@ -161,19 +165,22 @@ def forward(
             x = self.depthwise_conv(x)

         if self.dynamic_conv:
-            # x size: [B * n_chunk, C, decoding_chunk_size]
-            x = x.reshape(-1, n_chunks, x.size(1), x.size(2))  # [B, n_chunk, C, decoding_chunk_size]
-            x = x.transpose(1, 2)  # [B, C, n_chunks, decoding_chunk_size]
-            x = x.reshape(x.size(0), x.size(1), -1)  # [B, C, n_chunks * decoding_chunk_size]
-            x = x[..., :x.size(2) - n_frames_pad]  # (batch, channel, dim)
+            # [B, n_chunk, C, decoding_chunk_size]
+            x = x.reshape(-1, n_chunks, x.size(1), x.size(2))
+            # [B, C, n_chunks, decoding_chunk_size]
+            x = x.transpose(1, 2)
+            # [B, C, n_chunks * decoding_chunk_size]
+            x = x.reshape(x.size(0), x.size(1), -1)
+            # remove padding
+            x = x[..., :x.size(2) - n_frames_pad]

         if self.use_layer_norm:
             x = x.transpose(1, 2)
         x = self.activation(self.norm(x))
         if self.use_layer_norm:
             x = x.transpose(1, 2)
         x = self.pointwise_conv2(x)
-
+
         # mask batch padding
         if mask_pad.size(2) > 0:  # time > 0
             x.masked_fill_(~mask_pad.to(torch.bool), 0.0)

@@ -202,7 +209,7 @@ def forward_parallel_chunk(
         """
         # exchange the temporal dimension and the feature dimension
         x = x.transpose(1, 2)  # (#batch, channels, time)
-        lorder = self.kernel_size//2
+        lorder = self.kernel_size // 2
         chunk_size = x.shape[-1]
         cache_t = cache.size(-1)
         if cache_t == 0:

@@ -211,19 +218,21 @@ def forward_parallel_chunk(
         x = self.pointwise_conv1(x)  # (batch, 2*channel, dim)
         x = nn.functional.glu(x, dim=1)  # (batch, channel, dim)

-        #----------Overlapping Chunk Transformation-----------------------------------
-        x = x.transpose(0, 1).reshape( self.channels, -1)  # [C, n_chunk * T]
+        # ----------Overlapping Chunk Transformation-----------------------------------
+        x = x.transpose(0, 1).reshape(self.channels, -1)  # [C, n_chunk * T]
         x = torch.cat([cache, x], dim=-1)

-        # streaming long-form transcription is disabled if input cache is empty, only support long-form transcription and masked batch
+        # Streaming long-form transcription is disabled if input cache is empty
         if cache_t > 0:
-            new_cache = x[:, :truncated_context_size + cache.size(-1)][:, -cache.size(-1):]
+            new_cache = x[:, :truncated_context_size + cache.size(-1)]
+            new_cache = new_cache[:, -cache.size(-1):]
         else:
             new_cache = torch.zeros((0, 0))

         x = nn.functional.pad(x, (0, lorder), 'constant', 0.0)
-        x = x.unfold(-1, chunk_size + 2 * lorder, chunk_size).transpose(0, 1)  #[n_chunk +1, C, chunk_size + 2 * lorder]
-        #-----------------------------------------------------------------------------
+        x = x.unfold(-1, chunk_size + 2 * lorder, chunk_size).transpose(0, 1)
+        # [n_chunk +1, C, chunk_size + 2 * lorder]
+        # -----------------------------------------------------------------------------

         if mask_pad.size(2) > 0:  # time > 0
             x = torch.where(mask_pad, x, 0)
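
As a companion to the reflowed comments in forward(), here is a toy walk-through of the dynamic-conv chunking bookkeeping: pad the time axis so it splits evenly, unfold into [B * n_chunks, C, size] windows for the depthwise convolution, then stitch the per-chunk outputs back together and strip the padding. The numbers are illustrative and the convolution itself is stubbed out by a slice, so this is a sketch of the index arithmetic only, not the module's real computation:

import torch

B, C, T = 1, 2, 10
decoding_chunk_size, lorder = 4, 3           # e.g. kernel_size = 7 -> lorder = 3

x = torch.randn(B, C, T)
size = lorder + decoding_chunk_size          # frames visible to each chunk
step = decoding_chunk_size

# Pad so the time axis splits into overlapping chunks of `size` frames.
n_frames_pad = (step - ((x.size(2) - size) % step)) % step
x_pad = torch.nn.functional.pad(x, (0, n_frames_pad))

n_chunks = ((x_pad.size(2) - size) // step) + 1
chunks = x_pad.unfold(-1, size=size, step=step)        # [B, C, n_chunks, size]
chunks = chunks.transpose(1, 2).reshape(-1, C, size)   # [B * n_chunks, C, size]

# ... the depthwise conv would run here on each chunk independently;
# we fake its length reduction by keeping decoding_chunk_size frames per chunk.
out = chunks[..., lorder:]                              # [B * n_chunks, C, step]
out = out.reshape(B, n_chunks, C, step).transpose(1, 2).reshape(B, C, -1)
out = out[..., :out.size(2) - n_frames_pad]             # strip the right padding
print(out.shape)  # torch.Size([1, 2, 7]) with these toy numbers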

wenet/chunkformer/embedding.py

Lines changed: 13 additions & 10 deletions
@@ -4,14 +4,10 @@
 from typing import Tuple, Union

 import torch
-import torch.nn.functional as F

 class RelPositionalEncodingWithRightContext(torch.nn.Module):
     """Relative positional encoding module.

-    See : Appendix B in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
-    Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/embedding.py
-
     Args:
         d_model: Embedding dimension.
         dropout_rate: Dropout rate.

@@ -50,14 +46,19 @@ def extend_pe(self, size: int, left_context: Union[int, torch.Tensor] = 0) -> No

         # Reserve the order of positive indices and concat both positive and
         # negative indices. This is used to support the shifting trick
-        # as in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
+        # as in "Transformer-XL: Attentive Language Models Beyond a
+        # Fixed-Length Context"
         pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
         pe_negative = pe_negative[1:].unsqueeze(0)
         self.pe = torch.cat([pe_positive, pe_negative], dim=1)

-    def position_encoding(self, offset: Union[int, torch.Tensor], size: int,
-                          apply_dropout: bool = False,
-                          right_context_size: Union[int, torch.Tensor] = 0) -> torch.Tensor:
+    def position_encoding(
+        self,
+        offset: Union[int, torch.Tensor],
+        size: int,
+        apply_dropout: bool = False,
+        right_context_size: Union[int, torch.Tensor] = 0
+    ) -> torch.Tensor:

         if isinstance(offset, int):
             assert offset + size < self.max_len

@@ -102,5 +103,7 @@ def forward(

         """
         x = x * self.xscale
-        pos_emb = self.position_encoding(offset, x.size(1), False, right_context_size).to(device=x.device, dtype=x.dtype)
-        return self.dropout(x), self.dropout(pos_emb)
+        pos_emb = self.position_encoding(
+            offset, x.size(1), False,
+            right_context_size).to(device=x.device, dtype=x.dtype)
+        return self.dropout(x), self.dropout(pos_emb)
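
For context on the comment reflowed in extend_pe(), the table it refers to follows the ESPnet-style relative positional encoding cited in the (removed) docstring: sinusoidal embeddings for positive positions are flipped so indices run size-1 .. 0, the duplicate position 0 is dropped from the negative half, and the two halves are concatenated into 2*size - 1 relative positions. A rough standalone sketch with toy sizes, not the module's exact code:

import math
import torch

d_model, size = 8, 5

position = torch.arange(0, size, dtype=torch.float32).unsqueeze(1)
div_term = torch.exp(
    torch.arange(0, d_model, 2, dtype=torch.float32)
    * -(math.log(10000.0) / d_model))

# Sinusoidal tables for positive (0 .. size-1) and negative (0 .. -(size-1)) offsets.
pe_positive = torch.zeros(size, d_model)
pe_negative = torch.zeros(size, d_model)
pe_positive[:, 0::2] = torch.sin(position * div_term)
pe_positive[:, 1::2] = torch.cos(position * div_term)
pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)

# Flip the positive half so indices run size-1 .. 0, drop the duplicate
# position 0 from the negative half, then concatenate (the shifting trick).
pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
pe_negative = pe_negative[1:].unsqueeze(0)
pe = torch.cat([pe_positive, pe_negative], dim=1)
print(pe.shape)  # [1, 2 * size - 1, d_model] -> torch.Size([1, 9, 8])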
