Skip to content

Commit 1a8ee65

Browse files
mlxu995menglong.xu
andauthored
[fix] fix hey_snips training script(#195) (#204)
Co-authored-by: menglong.xu <[email protected]>
1 parent 0eeec75 commit 1a8ee65

File tree

4 files changed

+35
-17
lines changed

4 files changed

+35
-17
lines changed

examples/hey_snips/s0/conf/ds_tcn.yaml

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,17 @@ dataset_conf:
22
filter_conf:
33
max_length: 2048
44
min_length: 0
5+
token_max_length: 200
6+
token_min_length: 1
7+
max_output_input_ratio: 1
8+
min_output_input_ratio: 0.0005
59
resample_conf:
610
resample_rate: 16000
711
speed_perturb: false
812
reverb_prob: 0.2
913
noise_prob: 0.3
10-
feature_extraction_conf:
11-
feature_type: 'fbank'
14+
feats_type: 'fbank'
15+
fbank_conf:
1216
num_mel_bins: 40
1317
frame_shift: 10
1418
frame_length: 25
@@ -22,6 +26,7 @@ dataset_conf:
2226
shuffle: true
2327
shuffle_conf:
2428
shuffle_size: 1500
29+
sort: false
2530
batch_conf:
2631
batch_size: 256
2732

examples/hey_snips/s0/local/prepare_data.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,24 +19,31 @@ def main():
1919
type=str,
2020
help='dir containing all the wav files')
2121
parser.add_argument('path', type=str, help='path to the json file')
22+
parser.add_argument('dict', type=str, help='path to the dict file')
2223
parser.add_argument('out_dir', type=str, help='out dir')
2324
args = parser.parse_args()
2425

26+
id2token = {}
27+
with open(args.dict, 'r', encoding='utf-8') as f:
28+
for line in f:
29+
token, idx = line.strip().split()
30+
id2token[int(idx)] = token
31+
2532
with open(args.path, 'r', encoding='utf-8') as f:
2633
data = json.load(f)
27-
utt_id, label = [], []
34+
utt_id, text = [], []
2835
for entry in data:
2936
if entry['duration'] > 0:
3037
utt_id.append(entry['id'])
3138
keyword_id = 0 if entry['is_hotword'] == 1 else -1
32-
label.append(keyword_id)
39+
text.append(id2token[keyword_id])
3340

3441
abs_dir = os.path.abspath(args.wav_dir)
3542
wav_path = os.path.join(args.out_dir, 'wav.scp')
3643
text_path = os.path.join(args.out_dir, 'text')
3744
with open(wav_path, 'w', encoding='utf-8') as f_wav, \
3845
open(text_path, 'w', encoding='utf-8') as f_text:
39-
for utt, l in zip(utt_id, label):
46+
for utt, l in zip(utt_id, text):
4047
f_wav.write('{} {}\n'.format(utt,
4148
os.path.join(abs_dir, utt + ".wav")))
4249
f_text.write('{} {}\n'.format(utt, l))

examples/hey_snips/s0/run.sh

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,10 @@
44

55
. ./path.sh
66

7-
stage=0
8-
stop_stage=4
7+
set -euo pipefail
8+
9+
stage=$1
10+
stop_stage=$2
911
num_keywords=1
1012

1113
config=conf/ds_tcn.yaml
@@ -24,8 +26,7 @@ noise_lmdb=
2426
reverb_lmdb=
2527

2628
. tools/parse_options.sh || exit 1;
27-
28-
set -euo pipefail
29+
window_shift=50
2930

3031
if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
3132
echo "Extracte all datasets"
@@ -36,14 +37,15 @@ fi
3637
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
3738
echo "Preparing datasets..."
3839
mkdir -p dict
39-
echo "<filler> -1" > dict/words.txt
40-
echo "Hey_Snips 0" >> dict/words.txt
40+
echo "<FILLER> -1" > dict/dict.txt
41+
echo "<HEY_SNIPS> 0" >> dict/dict.txt
42+
awk '{print $1}' dict/dict.txt > dict/words.txt
4143

4244
for folder in train dev test; do
4345
mkdir -p data/$folder
4446
json_path=$download_dir/hey_snips_research_6k_en_train_eval_clean_ter/$folder.json
4547
local/prepare_data.py $download_dir/hey_snips_research_6k_en_train_eval_clean_ter/audio_files $json_path \
46-
data/$folder
48+
dict/dict.txt data/$folder
4749
done
4850
fi
4951

@@ -78,7 +80,8 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
7880
--num_workers 8 \
7981
--num_keywords $num_keywords \
8082
--min_duration 50 \
81-
--seed 777 \
83+
--seed 666 \
84+
--dict ./dict \
8285
$cmvn_opts \
8386
${reverb_lmdb:+--reverb_lmdb $reverb_lmdb} \
8487
${noise_lmdb:+--noise_lmdb $noise_lmdb} \
@@ -97,21 +100,23 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
97100
python wekws/bin/score.py \
98101
--config $dir/config.yaml \
99102
--test_data data/test/data.list \
103+
--gpu 0 \
100104
--batch_size 256 \
101105
--checkpoint $score_checkpoint \
102106
--score_file $result_dir/score.txt \
107+
--dict ./dict \
103108
--num_workers 8
104-
first_keyword=0
105-
last_keyword=$(($num_keywords+$first_keyword-1))
106-
for keyword in $(seq $first_keyword $last_keyword); do
109+
110+
for keyword in `tail -n +2 dict/words.txt`; do
107111
python wekws/bin/compute_det.py \
108112
--keyword $keyword \
109113
--test_data data/test/data.list \
114+
--window_shift $window_shift \
110115
--score_file $result_dir/score.txt \
111116
--stats_file $result_dir/stats.${keyword}.txt
112117
done
113118
python wekws/bin/plot_det_curve.py \
114-
--keywords_dict dict/words.txt \
119+
--keywords_dict dict/dict.txt \
115120
--stats_dir $result_dir \
116121
--figure_file $result_dir/det.png \
117122
--xlim 10 \

wekws/bin/score.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ def main():
8080
test_conf = copy.deepcopy(configs['dataset_conf'])
8181
test_conf['filter_conf']['max_length'] = 102400
8282
test_conf['filter_conf']['min_length'] = 0
83+
test_conf['filter_conf']['min_output_input_ratio'] = 0
8384
test_conf['speed_perturb'] = False
8485
test_conf['spec_aug'] = False
8586
test_conf['shuffle'] = False

0 commit comments

Comments
 (0)