
Commit 1a530bb

[touch_asu] save in json format and compute wer in json to avoid tric… (#47)
* [touch_asu] save in json format and compute wer in json to avoid tricky problems
* fix lint
* fix lint in decode.py
1 parent d8fab31 commit 1a530bb

4 files changed: +47 -43 lines changed

Lines changed: 1 addition & 3 deletions
@@ -1,7 +1,5 @@
 {
-  "bos_token_id": 151643,
-  "do_sample": true,
-  "eos_token_id": 151643,
+  "do_sample": false,
   "max_new_tokens": 50,
   "transformers_version": "4.37.0"
 }
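
The updated config keeps only do_sample, max_new_tokens, and the transformers_version stamp, so sampling is disabled during generation. A quick sanity check of the file might look like the following (a hypothetical snippet, not part of the commit; the path is the one run.sh copies from, relative to examples/aishell/asr):

import json

# Load the updated generation config and confirm sampling is off and the
# bos/eos overrides are gone. Hypothetical check, not shipped with the commit.
with open('conf/generation_config.json') as f:
    cfg = json.load(f)

assert cfg['do_sample'] is False
assert 'bos_token_id' not in cfg and 'eos_token_id' not in cfg
print(cfg)  # {'do_sample': False, 'max_new_tokens': 50, 'transformers_version': '4.37.0'}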

examples/aishell/asr/run.sh

Lines changed: 4 additions & 5 deletions
@@ -57,9 +57,8 @@ if [ $stage == "decode" ] || [ $stage == "all" ]; then
   cp conf/generation_config.json $mdir
   python west/bin/decode.py \
     --data_path $data/test.jsonl \
-    --model_dir $PWD/$mdir \
-    --result_path $mdir/result.txt
-  paste <(awk '{print $1}' $data/test.text) $mdir/result.txt > $mdir/result.hyp
-  python tools/compute-wer.py --char=1 --v=1 \
-    $data/test.text $mdir/result.hyp > $mdir/result.wer
+    --model_dir $mdir \
+    --result_path $mdir/result.jsonl
+  python tools/compute_wer.py --char=1 --v=1 \
+    $data/test.jsonl $mdir/result.jsonl > $mdir/result.wer
 fi
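
The paste/awk step that re-attached utterance ids to result.txt is gone: decode.py now writes result.jsonl and compute_wer.py scores it directly against test.jsonl. Judging from the keys read and written in the diffs below, each file holds one JSON object per line, roughly like this (a sketch with made-up contents; only the 'wav' and 'txt' keys come from the code):

import json

# Illustrative line shapes only; real contents come from the AISHELL data
# preparation and the decode step.
ref_line = {'wav': 'BAC009S0764W0121.wav', 'txt': '今天天气不错'}  # one line of $data/test.jsonl
hyp_line = {'txt': '今天天气不错'}                                 # one line of $mdir/result.jsonl

print(json.dumps(ref_line, ensure_ascii=False))
print(json.dumps(hyp_line, ensure_ascii=False))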

tools/compute-wer.py renamed to tools/compute_wer.py

Lines changed: 39 additions & 34 deletions
@@ -1,8 +1,10 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
 
-import re, sys, unicodedata
 import codecs
+import json
+import sys
+import unicodedata
 
 remove_tag = True
 spacelist = [' ', '\t', '\r', '\n']
@@ -21,17 +23,20 @@ def characterize(string):
             i += 1
             continue
         cat1 = unicodedata.category(char)
-        #https://unicodebook.readthedocs.io/unicode.html#unicode-categories
-        if cat1 == 'Zs' or cat1 == 'Cn' or char in spacelist: # space or not assigned
+        # https://unicodebook.readthedocs.io/unicode.html#unicode-categories
+        if cat1 == 'Zs' or cat1 == 'Cn' or char in spacelist:
+            # space or not assigned
             i += 1
             continue
         if cat1 == 'Lo': # letter-other
             res.append(char)
             i += 1
         else:
-            # some input looks like: <unk><noise>, we want to separate it to two words.
+            # some input looks like: <unk><noise>, we want to separate it to
+            # two words.
             sep = ' '
-            if char == '<': sep = '>'
+            if char == '<':
+                sep = '>'
             j = i + 1
             while j < len(string):
                 c = string[j]
@@ -46,7 +51,8 @@ def characterize(string):
 
 
 def stripoff_tags(x):
-    if not x: return ''
+    if not x:
+        return ''
     chars = []
     i = 0
     T = len(x)
@@ -210,9 +216,9 @@ def calculate(self, lab, rec):
             elif self.space[i][j]['error'] == 'non': # starting point
                 break
             else: # shouldn't reach here
-                print(
-                    'this should not happen , i = {i} , j = {j} , error = {error}'
-                    .format(i=i, j=j, error=self.space[i][j]['error']))
+                print('this should not happen '
+                      'i = {i} , j = {j} , error = {error}'.format(
+                          i=i, j=j, error=self.space[i][j]['error']))
         return result
 
     def overall(self):
@@ -286,10 +292,10 @@ def default_cluster(word):
 
 def usage():
     print(
-        "compute-wer.py : compute word error rate (WER) and align recognition results and references."
+        "compute-wer.py : compute word error rate (WER) and align recognition results and references."  # noqa
     )
     print(
-        " usage : python compute-wer.py [--cs={0,1}] [--cluster=foo] [--ig=ignore_file] [--char={0,1}] [--v={0,1}] [--padding-symbol={space,underline}] test.ref test.hyp > test.wer"
+        " usage : python compute-wer.py [--cs={0,1}] [--cluster=foo] [--ig=ignore_file] [--char={0,1}] [--v={0,1}] [--padding-symbol={space,underline}] test.ref test.hyp > test.wer"  # noqa
     )
 
 
@@ -364,7 +370,7 @@ def usage():
         verbose = 0
         try:
             verbose = int(b)
-        except:
+        except Exception:
             if b == 'true' or b != '0':
                 verbose = 1
         continue
@@ -378,7 +384,7 @@ def usage():
         padding_symbol = '_'
         continue
     if True or sys.argv[1].startswith('-'):
-        #ignore invalid switch
+        # ignore invalid switch
         del sys.argv[1]
         continue
 
@@ -391,7 +397,7 @@ def usage():
 
 ref_file = sys.argv[1]
 hyp_file = sys.argv[2]
-rec_set = {}
+rec_list = []
 if split and not case_sensitive:
     newsplit = dict()
     for w in split:
@@ -401,29 +407,30 @@ def usage():
         newsplit[w.upper()] = words
     split = newsplit
 
-with codecs.open(hyp_file, 'r', 'utf-8') as fh:
+with open(hyp_file) as fh:
     for line in fh:
+        item = json.loads(line)
+        assert 'txt' in item
+        line = item['txt']
         if tochar:
             array = characterize(line)
         else:
            array = line.strip().split()
-        if len(array) == 0: continue
-        fid = array[0]
-        rec_set[fid] = normalize(array[1:], ignore_words, case_sensitive,
-                                 split)
+        rec_list.append(
+            normalize(array, ignore_words, case_sensitive, split))
 
 # compute error rate on the interaction of reference file and hyp file
-for line in open(ref_file, 'r', encoding='utf-8'):
+for i, line in enumerate(open(ref_file, 'r', encoding='utf-8')):
+    item = json.loads(line)
+    assert 'txt' in item
+    line = item['txt']
+    fid = item['wav']
     if tochar:
         array = characterize(line)
     else:
         array = line.rstrip('\n').split()
-    if len(array) == 0: continue
-    fid = array[0]
-    if fid not in rec_set:
-        continue
-    lab = normalize(array[1:], ignore_words, case_sensitive, split)
-    rec = rec_set[fid]
+    lab = normalize(line, ignore_words, case_sensitive, split)
+    rec = rec_list[i]
     if verbose:
         print('\nutt: %s' % fid)
 
@@ -489,8 +496,7 @@ def usage():
 
 if verbose:
     print(
-        '==========================================================================='
-    )
+        '================================================================')
     print()
 
 result = calculator.overall()
@@ -525,8 +531,8 @@ def usage():
     for line in open(cluster_file, 'r', encoding='utf-8'):
         for token in line.decode('utf-8').rstrip('\n').split():
             # end of cluster reached, like </Keyword>
-            if token[0:2] == '</' and token[len(token)-1] == '>' and \
-                    token.lstrip('</').rstrip('>') == cluster_id :
+            if (token[0:2] == '</' and token[len(token) - 1] == '>'
+                    and token.lstrip('</').rstrip('>') == cluster_id):
                 result = calculator.cluster(cluster)
                 if result['all'] != 0:
                     wer = float(result['ins'] + result['sub'] +
@@ -540,14 +546,13 @@ def usage():
                 cluster_id = ''
                 cluster = []
             # begin of cluster reached, like <Keyword>
-            elif token[0] == '<' and token[len(token)-1] == '>' and \
-                    cluster_id == '' :
+            elif (token[0] == '<' and token[len(token) - 1] == '>'
+                  and cluster_id == ''):
                 cluster_id = token.lstrip('<').rstrip('>')
                 cluster = []
             # general terms, like WEATHER / CAR / ...
             else:
                 cluster.append(token)
 print()
 print(
-    '==========================================================================='
-)
+    '================================================================')
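
The main behavioural change in compute_wer.py: hypotheses are no longer looked up by an utterance id in the first column. They are collected into rec_list in file order and matched to references by line index, and the reference's 'wav' field now supplies the utterance id shown in the verbose report. A minimal sketch of that pairing (assuming, as the script does, that both files list the same utterances in the same order; the real script then runs each pair through characterize()/normalize() and the Calculator):

import json

def load_jsonl(path):
    # One JSON object per line, each carrying at least a 'txt' field.
    with open(path, encoding='utf-8') as f:
        return [json.loads(line) for line in f]

refs = load_jsonl('test.jsonl')    # $data/test.jsonl in run.sh: 'wav' + 'txt'
hyps = load_jsonl('result.jsonl')  # $mdir/result.jsonl: 'txt' only

# Paired purely by position, so both files must have the same length and order.
assert len(refs) == len(hyps)
for i, ref in enumerate(refs):
    print(ref['wav'], '|', ref['txt'], '->', hyps[i]['txt'])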

west/bin/decode.py

Lines changed: 3 additions & 1 deletion
@@ -1,4 +1,5 @@
 # Copyright (c) 2025 Binbin Zhang([email protected])
+import json
 import sys
 from dataclasses import dataclass, field
 
@@ -42,7 +43,8 @@ def main():
         print(text)
         for t in text:
             t = t.replace('\n', ' ')
-            fid.write(t + '\n')
+            item = {'txt': t}
+            fid.write(json.dumps(item, ensure_ascii=False) + '\n')
         sys.stdout.flush()
     fid.close()
 
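Each hypothesis from decode.py is now serialized as a one-key JSON object, with newlines flattened first so the one-object-per-line format stays intact, and ensure_ascii=False so Chinese text is written verbatim rather than as \uXXXX escapes. A tiny illustration with made-up text:

import json

t = '今天\n天气不错'.replace('\n', ' ')              # flatten newlines as decode.py does
print(json.dumps({'txt': t}, ensure_ascii=False))    # {"txt": "今天 天气不错"}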