Skip to content

Commit 82c34c4

Browse files
committed
add missing files
1 parent a6d02e2 commit 82c34c4

File tree

5 files changed

+172
-5
lines changed

5 files changed

+172
-5
lines changed

wenet/firered/attention.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def position_encoding(self,
4949

5050
raise NotImplementedError('firedasr not support streaming pos encding')
5151

52-
def forward(self, x):
52+
def forward(self, x, offset=None):
5353
Tmax, T = self.pe.size(1), x.size(1)
5454
pos_emb = self.pe[:, Tmax // 2 - T + 1:Tmax // 2 + T].clone().detach()
5555
return self.dropout(x), self.dropout(pos_emb)
@@ -99,7 +99,7 @@ def rel_shift(self, x):
9999
x.size()[1],
100100
x.size(3) + 1, x.size(2))
101101
x = x_padded[:, :, 1:].view_as(x)
102-
x = x[:, :, :, :, x.size(-1) // 2 + 1]
102+
x = x[:, :, :, :x.size(-1) // 2 + 1]
103103

104104
return x
105105

wenet/firered/convert_FireRed_AED_L_to_wenet_config_and_ckpt.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ def convert_to_wenet_yaml(tokenizer: BaseTokenizer, dims, wenet_yaml_path: str,
9292
configs['ctc_conf'] = {}
9393
configs['ctc_conf']['ctc_blank_id'] = 0
9494

95-
configs['cmvn'] = None
95+
configs['cmvn'] = 'global_cmvn'
9696
configs['cmvn_conf'] = {}
9797
configs['cmvn_conf']['cmvn_file'] = json_cmvn_path
9898
configs['cmvn_conf']['is_json_cmvn'] = True

wenet/firered/encoder.py

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
from typing import Optional
2+
3+
import torch
4+
from wenet.firered.encoder_layer import FireRedConformerEncoderLayer
5+
from wenet.transformer.convolution import ConvolutionModule
6+
from wenet.transformer.encoder import BaseEncoder
7+
from wenet.utils.class_utils import (WENET_ACTIVATION_CLASSES,
8+
WENET_ATTENTION_CLASSES,
9+
WENET_MLP_CLASSES)
10+
11+
12+
class FireRedConformerEncoder(BaseEncoder):
13+
"""Conformer encoder module."""
14+
15+
def __init__(
16+
self,
17+
input_size: int,
18+
output_size: int = 256,
19+
attention_heads: int = 4,
20+
linear_units: int = 2048,
21+
num_blocks: int = 6,
22+
dropout_rate: float = 0.1,
23+
positional_dropout_rate: float = 0.1,
24+
attention_dropout_rate: float = 0.0,
25+
input_layer: str = "conv2d",
26+
pos_enc_layer_type: str = "rel_pos",
27+
normalize_before: bool = True,
28+
static_chunk_size: int = 0,
29+
use_dynamic_chunk: bool = False,
30+
global_cmvn: torch.nn.Module = None,
31+
use_dynamic_left_chunk: bool = False,
32+
positionwise_conv_kernel_size: int = 1,
33+
macaron_style: bool = True,
34+
selfattention_layer_type: str = "rel_selfattn",
35+
activation_type: str = "swish",
36+
use_cnn_module: bool = True,
37+
cnn_module_kernel: int = 15,
38+
causal: bool = False,
39+
cnn_module_norm: str = "batch_norm",
40+
query_bias: bool = True,
41+
key_bias: bool = True,
42+
value_bias: bool = True,
43+
conv_bias: bool = True,
44+
gradient_checkpointing: bool = False,
45+
use_sdpa: bool = False,
46+
layer_norm_type: str = 'layer_norm',
47+
norm_eps: float = 1e-5,
48+
n_kv_head: Optional[int] = None,
49+
head_dim: Optional[int] = None,
50+
mlp_type: str = 'position_wise_feed_forward',
51+
mlp_bias: bool = True,
52+
n_expert: int = 8,
53+
n_expert_activated: int = 2,
54+
conv_norm_eps: float = 1e-5,
55+
conv_inner_factor: int = 2,
56+
final_norm: bool = True,
57+
):
58+
"""ConstruConformerEncoder
59+
60+
Args:
61+
input_size to use_dynamic_chunk, see in BaseEncoder
62+
positionwise_conv_kernel_size (int): Kernel size of positionwise
63+
conv1d layer.
64+
macaron_style (bool): Whether to use macaron style for
65+
positionwise layer.
66+
selfattention_layer_type (str): Encoder attention layer type,
67+
the parameter has no effect now, it's just for configure
68+
compatibility.
69+
activation_type (str): Encoder activation function type.
70+
use_cnn_module (bool): Whether to use convolution module.
71+
cnn_module_kernel (int): Kernel size of convolution module.
72+
causal (bool): whether to use causal convolution or not.
73+
key_bias: whether use bias in attention.linear_k, False for whisper models.
74+
"""
75+
super().__init__(input_size, output_size, attention_heads,
76+
linear_units, num_blocks, dropout_rate,
77+
positional_dropout_rate, attention_dropout_rate,
78+
input_layer, pos_enc_layer_type, normalize_before,
79+
static_chunk_size, use_dynamic_chunk, global_cmvn,
80+
use_dynamic_left_chunk, gradient_checkpointing,
81+
use_sdpa, layer_norm_type, norm_eps, final_norm)
82+
activation = WENET_ACTIVATION_CLASSES[activation_type]()
83+
84+
# self-attention module definition
85+
encoder_selfattn_layer_args = (
86+
attention_heads,
87+
output_size,
88+
attention_dropout_rate,
89+
query_bias,
90+
key_bias,
91+
value_bias,
92+
use_sdpa,
93+
n_kv_head,
94+
head_dim,
95+
)
96+
# feed-forward module definition
97+
positionwise_layer_args = (
98+
output_size,
99+
linear_units,
100+
dropout_rate,
101+
activation,
102+
mlp_bias,
103+
n_expert,
104+
n_expert_activated,
105+
)
106+
# convolution module definition
107+
convolution_layer_args = (output_size, cnn_module_kernel, activation,
108+
cnn_module_norm, causal, conv_bias,
109+
conv_norm_eps, conv_inner_factor)
110+
111+
mlp_class = WENET_MLP_CLASSES[mlp_type]
112+
113+
self.encoders = torch.nn.ModuleList([
114+
FireRedConformerEncoderLayer(
115+
output_size,
116+
WENET_ATTENTION_CLASSES[selfattention_layer_type](
117+
*encoder_selfattn_layer_args),
118+
mlp_class(*positionwise_layer_args),
119+
mlp_class(*positionwise_layer_args) if macaron_style else None,
120+
ConvolutionModule(
121+
*convolution_layer_args) if use_cnn_module else None,
122+
dropout_rate,
123+
normalize_before,
124+
layer_norm_type=layer_norm_type,
125+
norm_eps=norm_eps,
126+
) for _ in range(num_blocks)
127+
])

wenet/firered/encoder_layer.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
from typing import Optional
2+
3+
import torch
4+
from torch import nn
5+
from wenet.transformer.encoder_layer import ConformerEncoderLayer
6+
7+
8+
class FireRedConformerEncoderLayer(ConformerEncoderLayer):
9+
"""Encoder layer module.
10+
Args:
11+
size (int): Input dimension.
12+
self_attn (torch.nn.Module): Self-attention module instance.
13+
`MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
14+
instance can be used as the argument.
15+
feed_forward (torch.nn.Module): Feed-forward module instance.
16+
`PositionwiseFeedForward` instance can be used as the argument.
17+
feed_forward_macaron (torch.nn.Module): Additional feed-forward module
18+
instance.
19+
`PositionwiseFeedForward` instance can be used as the argument.
20+
conv_module (torch.nn.Module): Convolution module instance.
21+
`ConvlutionModule` instance can be used as the argument.
22+
dropout_rate (float): Dropout rate.
23+
normalize_before (bool):
24+
True: use layer_norm before each sub-block.
25+
False: use layer_norm after each sub-block.
26+
"""
27+
28+
def __init__(self,
29+
size: int,
30+
self_attn: torch.nn.Module,
31+
feed_forward: Optional[nn.Module] = None,
32+
feed_forward_macaron: Optional[nn.Module] = None,
33+
conv_module: Optional[nn.Module] = None,
34+
dropout_rate: float = 0.1,
35+
normalize_before: bool = True,
36+
layer_norm_type: str = 'layer_norm',
37+
norm_eps: float = 0.00001):
38+
super().__init__(size, self_attn, feed_forward, feed_forward_macaron,
39+
conv_module, dropout_rate, normalize_before,
40+
layer_norm_type, norm_eps)
41+
del self.norm_mha
42+
self.norm_mha = torch.nn.Identity()

wenet/firered/model.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,6 @@ def __init__(
5858

5959
# fix final norm in conformer
6060
del self.encoder.after_norm
61-
# fix output bias
62-
del self.decoder.output_layer.bias
6361

6462
@torch.jit.unused
6563
def forward_encoder_chunk(

0 commit comments

Comments
 (0)