-
Notifications
You must be signed in to change notification settings - Fork 1.2k
[wenet] nn context biasing #1982
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
…tion problem due to context mismatch.
wenet/transformer/context_module.py
Outdated
| _, last_state = self.sen_rnn(pack_seq) | ||
| laste_h = last_state[0] | ||
| laste_c = last_state[1] | ||
| state = torch.cat([laste_h[-1, :, :], laste_h[0, :, :], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
hi,这里的实现是最后一层BLSTM的reverse last_h_state和第一层的forward last_h_state?
torch.nn.LSTM
**h_n**: tensor of shape :math:(D * \text{num_layers}, H_{out}) for unbatched input or :math:(D * \text{num_layers}, N, H_{out})containing the final hidden state for each element in the sequence. When ``bidirectional=True``,h_n will contain a concatenation of the final forward and reverse hidden states, respectively.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里是我写错了,0应该改成-2,感谢指正
…hunk during bias module training
| for utt_label in batch_label: | ||
| st_index_list = [] | ||
| for i in range(len(utt_label)): | ||
| if '▁' not in symbol_table: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
我想请问下,这里如果我的建模单元是中文汉字+英文bpe,这里是不是不太适用,需要改下?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
是的,我自己训练的时候都是纯中文或者纯英文,英文在热词采样的时候对下划线特殊处理了下保证不会采样出半个词的情况,如果同时有中文和英文这部分最好是改下
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
谢谢~
|
可以提供一些模型训练时候的conf.yaml参数设置吗?谢谢 |
上面的模型链接中有我用的yaml文件,可以直接下载 |
|
我想请问下,我在aishell170小时上训练了deep biasing的模型,但是在解码的时候如果设置deep biasing,会出现很多的漏字现象,这个会是什么原因呀? |
漏字的现象很严重吗,使用的热词列表大小多大?我这边也有做过aishell1的实验,结果比较正常,没有观察到漏字的现象 |
很严重,就是一段一段的漏,原始设置的热词表大小是187,modelscope上开源的热词测试集,然后是设置了context_filtering参数进行过滤,如果过滤后热词表只有【0】的话,基本上就整句话漏了,如果是有热词的情况,也会出现成片漏掉的情况,设置的deep_score=1,filter_threshold=-4。目前训练迭代了17个epoch,loss_bias在10左右 |
那确实很奇怪,总体loss的情况正常吗,正常情况下收敛到差不多的时候,bias loss应该是和ctc loss差不多,总体的loss应该会比没有训练热词模块之前更低一些,在aishell上大概是3.4左右。你用的热词相关的yaml配置是否都和我上面给出的一致 |
还有就是我在做aishell1实验的时候发现对于aishell1这种句子大部分都很短的数据集,热词采样的代码需要去掉那个判断采样热词不能交叉的逻辑,不然很容易一句话只能采样出一个热词,这样训出来热词增强的效果会差一些,不过这个问题并不会导致漏字的情况。 |
目前训练出来整体的loss还算是正常,从3.1下降到了2.5,bias loss会比ctc loss高一些。我现在的热词配置就是您给的这个哈 |
会不会是你修改的热词采样部分的代码有点问题,我这边确实没遇到过你描述的状况,也想不出是什么原因,漏字而且还和传入的热词数量有关,理论上来说热词列表只剩个0应该对于正常解码的影响是最小的 |
|
您好,我尝试复现您在librispeech的结果,但是在训练热词增强模型时,出现cv loss值不下降的情况(保持在160多),并且train loss也是下降到四五十就不太下降了。 另外,我发现每次训练几个batch时,都会花五六分钟去训练下一个batch,正常情况我的显卡每训练一个batch的时间是30s左右,下面是一小段训练日志。。。 我没修改任何代码,训练conf文件也是您提供那个train_bias, 能大概分析下出现问题的原因吗? 谢谢! 2023-10-17 18:48:14,596 DEBUG TRAIN Batch 0/300 loss 77.121582 loss_att 68.588936 loss_ctc 90.912209 loss_bias 61.188702 lr 0.00001204 rank 3 |
你是不是直接从头开始训练了,为了减少对原本asr性能的影响,我写的是从一个预训练好的asr模型开始训,除了热词模块之外的参数都给冻结了。从头开始训应该也能够收敛,但是至少得把冻结的参数先解冻。 |
没有,也是用的之前在librispeech上预训练好的asr模型,做了参数冻结 |
是不是没有用全量数据微调,我之前试过小数据量会引入通用效果损失 |
我也是在aishell170小时数据集上训练了30轮进行测试的 |
初始模型也是aishell训练的嘛,把deep biasing的权重调小看看 |
|
您好,我使用wenetspeech预训练的模型进行微调,但是wenetspeech数据集比较大我没有使用全量数据,只下载了大概40G的数据,目前训练了60轮,但是感觉热词增强的效果一般,较训练前没有什么提升,是训练的不够吗还是其他原因,下面截取了一小部分训练日志,可以帮忙看一下损失值都正常吗,cv_loss也基本在12左右波动,目前还在继续训练观察。 |
全量数据微调,轮次不需要太大,越靠前通用效果损失越小,但是热词效果越差,我一般都是取前七轮做个avg的,你可以拿靠前的轮次测测看热词的效果,测的时候尝试对比下不同热词权重的差异 |
好的,谢谢回复,我去试试 |
|
请问有没有遇到过导出整个识别+偏置模型的onnx时,onnx图中LSTM层输入的热词表大小会固定不变的问题 |
建议热词模块拆出来导,这样还能节省推理资源,lstm那块需要设置动态维度,而且建议用单向lstm,通过热词列表长度索引状态 |
好的,谢谢您的建议 :) |
|
大佬您好,我目前想基于你的工作给FireRedASR加个热词模块,我想咨询一些模型训练上的事情。首先FireRedASR有1.1B参数量,不知道这个量级的模型的热词模块是否应该也更大?此外FireRedASR的模型训练数据集不开源,如果我用其他的数据集训练例如aishell和librispeech是否可行? |
|
您好,您的邮件我已经收到,会尽快给您回复!祝您生活愉快,工作顺利!
|
|
@kaixunhuang0 凯勋您好, 我目前的操作是这样的:
想请教您两点:
如果方便的话,能否麻烦您提供一下当时的推理命令示例(包含 model、config、data 路径和解码参数),我再对照跑一次看看问题出在哪一环。谢谢! 🙏 |
|
您好,您的邮件我已经收到,会尽快给您回复!祝您生活愉快,工作顺利!
|
我测试了下,直接用git fetch origin pull/1982/head:pr-1982,然后切换到pr-1982分支上,用hf模型推理libri test other是能得到正常的结果的。你可以检查下
|
感谢回复! 1. 推理时的日志 头几行如下 可以看到是能解出 token 的,但有非常多的重复词例如TALENT。 2. 我用的字典 我用的是仓库里这个shell脚本生成的字典: 开头和结尾是这样的: 一共 5002 行。 3. 推理部分的代码 examples/librispeech/s0/run.sh 下面是我在 if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
# Test model, please specify the model you want to test by --checkpoint
cmvn_opts=
$cmvn && cmvn_opts="--cmvn ${dir}/global_cmvn"
decode_checkpoint=$dir/final.pt
# mkdir -p $dir/test
# if [ ${average_checkpoint} == true ]; then
# decode_checkpoint=$dir/avg_${average_num}.pt
# echo "do model average and final checkpoint is $decode_checkpoint"
# python wenet/bin/average_model.py \
# --dst_model $decode_checkpoint \
# --src_path $dir \
# --num ${average_num} \
# --val_best
# fi
# Specify decoding_chunk_size if it's a unified dynamic chunk trained model
# -1 for full chunk
decoding_chunk_size=
ctc_weight=0.5
# Polling GPU id begin with index 0
num_gpus=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
idx=0
for test in $recog_set; do
for mode in ${decode_modes}; do
{
{
test_dir=$dir/${test}_${mode}
mkdir -p $test_dir
gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$idx+1])
python wenet/bin/recognize.py --gpu $gpu_id \
--mode $mode \
--config $dir/train.yaml \
--data_type raw \
--dict $dict \
--bpe_model ${bpemodel}.model \
--test_data $wave_data/$test/data.list \
--checkpoint $decode_checkpoint \
--beam_size 10 \
--batch_size 1 \
--penalty 0.0 \
--result_file $test_dir/text_bpe \
--ctc_weight $ctc_weight \
${decoding_chunk_size:+--decoding_chunk_size $decoding_chunk_size} \
--context_bias_mode deep_biasing \
--context_list_path exp/models--kxhuang--Wenet_Librispeech_deep_biasing/snapshots/1025d306fe918d175fc8cccb4de18fc42bc3170b/nn_bias/bias_model/test_other_context_list \
--deep_biasing_score 1.0
cut -f2- -d " " $test_dir/text_bpe > $test_dir/text_bpe_value_tmp
cut -f1 -d " " $test_dir/text_bpe > $test_dir/text_bpe_key_tmp
tools/spm_decode --model=${bpemodel}.model --input_format=piece \
< $test_dir/text_bpe_value_tmp | sed -e "s/▁/ /g" > $test_dir/text_value_tmp
paste -d " " $test_dir/text_bpe_key_tmp $test_dir/text_value_tmp > $test_dir/text
python tools/compute-wer.py --char=1 --v=1 \
$wave_data/$test/text $test_dir/wer
} &
((idx+=1))
if [ $idx -eq $num_gpus ]; then
idx=0
fi
}
done
done
wait
fi(以下几行是参照issue #2148 增加的) 想请教的问题:
麻烦您帮忙看一下,非常感谢! 🙏 |
字典对不上,我这边5000对应的是“ZZ”,你切换成hf里的那个units.txt就可以了。train_960_unigram5000.model是分词模型,在run.sh搜索下bpemodel就可以找到在哪里设置了 |
感谢解惑!修改后使用完整的test_other热词列表,在biasing_score=1.0的情况下wer为8.60%,和您的结果一致了,非常感谢您的帮助,期待您的更多成果! 此外我想请问一下关于context_list的问题,训练的context_list是完全来自processor. context_sampling吗?也就是完全来自于batch A的随机采样,batch A 中 当前sample出现的为目标词,其他sample为干扰词,这么理解对吗?是否有试过将测试用的真实context_list加入训练,是否会有性能提升? |
是的,随机采样的目的是为了让热词模块可以见过各种各样的热词组合,有较好的泛化性,这样测试时见到没见过的热词也能有效增强。把真实context list也加入训练的话,你需要让训练集中有包含这些热词的正样例数据,但模型可能会对这些词过拟合。如果你最终使用的就是固定的这些热词的话可能是有用的。 |
|
大佬,我想请教你一个问题,我用你代码,在一样的网络结构下可以复现你的结果了。我现在在尝试在dolphin(类whisper结构)加入cpnn网络微调,微调完,cv loss,bias_loss比ctc loss平均高1到2左右,跑bench mark,热词没有起作用,有可能是模型太大,这个网络结构不适用的原因吗(我应该要加大层数?)还是我的一些超参数需要修改,不能直接复用你之前的超参数 |
这个热词模块结构应该是比较简单泛用的,在各种asr框架都能起作用,而且热词bias不算是需要特别多层网络才能完成的任务。可以试下增加层数,但我觉得层数不够应该只是效果不那么明显而不是完全失效,可以再检查下迁移代码的时候有没有什么遗漏?从loss上看如果训练有效的话,带bias训练收敛后,asr本身的loss应该要比不带bias训练更低一些,因为训练过程中从标签随机采样的热词也能让asr loss减少 |
感谢大佬回复,asr本身的loss确实也比不带bias训练的更低,微调后跑aishell测试集,非热词wer有些下降,总体wer是下降的,但是热词wer基本不变,热词部分没有生效,就比较诡异 |
确实挺诡异的,可以检查下训练过程热词采样啥的有没有正确采样并传给模型,这里可能比较容易出问题 |
大佬,你好。我仔细检查了一下热词采样部分的代码,确实发现有一点bug,修改后,现在在aishell测试集,热词wer也提升比较明显了,比只用graph,效果要好那么一点,和你实验也是比较一致的,但是现在还存在几个问题: 1.当我使用two stage过滤热词数量时候,效果反而会变差很多,很多甚至会乱码,热词数量是400时候,反而是正常的 |
应该是热词过滤代码改的有点有问题吧,解码的时候打印下看看筛选出来的热词是不是合理的。可能都不是筛的不准的问题,因为哪怕丢完全无关的一堆热词去做热词增强也不应该会出现乱码,得检查下热词筛选后的bias过程是不是按照之前的逻辑正常进行bias的,有没有哪个编码加错位置了之类的 |
大佬,我仔细check过了,没有错,filter算法选出来的token是合理的,另外我去掉filter算法,自己只输入一个热词,他识别的结果也会有问题,会带偏其他正确的结果,具体你看我下面的测试用例,只输入一个热词时候,反而不行,两个热词,一个是无关的反而可以: |
这个我又试了一下,只有一个热词时候,deep_score调到0点几时候,结果就正常了,感觉是热词模块影响过大,容易带偏原来正确的结果,但是实际用的时候,每一句到底用多少的score都是不一定的,这个有什么办法可以改善吗 |
可以调整下训练时热词采样的策略,让训练的过程中模型也见过这种只有很少数热词(可能正确也可能是干扰项)的情况,应该会有改善 |
大佬,commonvoice测试集,效果没有提升的原因,我大概找到了。现在情况是,commonvoice测试集热词wer有下降,但是非热词wer会上升,总体wer是上升的,原本正确可以识别的,被bias_encoder带偏的有点严重。听了一下wer变差的语音,发现多少都带点噪音,aishell训练集比较干净,没有见过,导致bias_encoder直接带偏了原来的模型,加数据应该可以解决。如果有几万小时训练数据,bias encoder的层数和其他超参数,大佬,你觉得需要修改吗 |
那确实挺正常的,aishell1训的模型泛化性肯定不行,加数据应该会好不少。层数和超参数这些感觉没啥好改的,多一两层也许会好一点,但估计是不会有多少区别的 |
一加数据,确实一下就好了,感谢大佬指导 |
The Deep biasing method comes from: https://arxiv.org/abs/2305.12493
The pre-trained ASR model is fine-tuned to achieve biasing. During the training process, the original parameters of the ASR model are frozen, and only the parameters related to deep biasing are trained. use_dynamic_chunk cannot be enabled during fine-tuning (the biasing effect will decrease), but the biasing effects of streaming and non-streaming inference are basically the same.
RESULT:
Model link: https://huggingface.co/kxhuang/Wenet_Librispeech_deep_biasing/tree/main
(I used the BLSTM forward state incorrectly when training this model, so to test this model you need to change the -2 to 0 in the forward function of the BLSTM class in wenet/transformer/context_module.py)
Using the Wenet Librispeech pre-trained AED model, after fine-tuning for 30 epochs, the final model was obtained with an average of 3 epochs. The following are the test results of the Librispeech test other.
The context list for the test set is sourced from: https://github.com/facebookresearch/fbai-speech/tree/main/is21_deep_bias
Non-streaming inference:
+ deep biasing
+ deep biasing
Streaming inference (chunk 16):
+ deep biasing