Skip to content

Commit aa3fe1a

Browse files
committed
python runtime
1 parent 03911c8 commit aa3fe1a

File tree

7 files changed

+253
-70
lines changed

7 files changed

+253
-70
lines changed
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/FunAudioLLM/SenseVoice). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)

# Demo: transcribe the bundled English sample with the ONNX SenseVoiceSmall
# runtime and print the emoji-decorated transcription of every input.

from pathlib import Path

from funasr_onnx import SenseVoiceSmall
from funasr_onnx.utils.postprocess_utils import rich_transcription_postprocess

# Passed to SenseVoiceSmall as-is and also used below to locate the sample
# audio under the modelscope cache directory.
model_dir = "iic/SenseVoiceSmall"

model = SenseVoiceSmall(model_dir, batch_size=10, quantize=False)

# inference
wav_or_scp = [f"{Path.home()}/.cache/modelscope/hub/{model_dir}/example/en.mp3"]

res = model(wav_or_scp, language="auto", use_itn=True)
print(list(map(rich_transcription_postprocess, res)))

runtime/python/onnxruntime/demo_sencevoicesmall.py

Lines changed: 0 additions & 38 deletions
This file was deleted.

runtime/python/onnxruntime/funasr_onnx/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,4 @@
44
from .vad_bin import Fsmn_vad_online
55
from .punc_bin import CT_Transformer
66
from .punc_bin import CT_Transformer_VadRealtime
7-
from .sensevoice_bin import SenseVoiceSmallONNX
7+
from .sensevoice_bin import SenseVoiceSmall

runtime/python/onnxruntime/funasr_onnx/sensevoice_bin.py

Lines changed: 58 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,13 @@
2020
get_logger,
2121
read_yaml,
2222
)
23+
from .utils.sentencepiece_tokenizer import SentencepiecesTokenizer
2324
from .utils.frontend import WavFrontend
2425

2526
logging = get_logger()
2627

2728

28-
class SenseVoiceSmallONNX:
29+
class SenseVoiceSmall:
2930
"""
3031
Author: Speech Lab of DAMO Academy, Alibaba Group
3132
Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
@@ -43,45 +44,72 @@ def __init__(
4344
cache_dir: str = None,
4445
**kwargs,
4546
):
47+
48+
if not Path(model_dir).exists():
49+
try:
50+
from modelscope.hub.snapshot_download import snapshot_download
51+
except:
52+
raise "You are exporting model from modelscope, please install modelscope and try it again. To install modelscope, you could:\n" "\npip3 install -U modelscope\n" "For the users in China, you could install with the command:\n" "\npip3 install -U modelscope -i https://mirror.sjtu.edu.cn/pypi/web/simple"
53+
try:
54+
model_dir = snapshot_download(model_dir, cache_dir=cache_dir)
55+
except:
56+
raise "model_dir must be model_name in modelscope or local path downloaded from modelscope, but is {}".format(
57+
model_dir
58+
)
59+
60+
model_file = os.path.join(model_dir, "model.onnx")
4661
if quantize:
4762
model_file = os.path.join(model_dir, "model_quant.onnx")
48-
else:
49-
model_file = os.path.join(model_dir, "model.onnx")
63+
if not os.path.exists(model_file):
64+
print(".onnx does not exist, begin to export onnx")
65+
try:
66+
from funasr import AutoModel
67+
except:
68+
raise "You are exporting onnx, please install funasr and try it again. To install funasr, you could:\n" "\npip3 install -U funasr\n" "For the users in China, you could install with the command:\n" "\npip3 install -U funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple"
69+
70+
model = AutoModel(model=model_dir)
71+
model_dir = model.export(type="onnx", quantize=quantize, **kwargs)
5072

5173
config_file = os.path.join(model_dir, "config.yaml")
5274
cmvn_file = os.path.join(model_dir, "am.mvn")
5375
config = read_yaml(config_file)
54-
# token_list = os.path.join(model_dir, "tokens.json")
55-
# with open(token_list, "r", encoding="utf-8") as f:
56-
# token_list = json.load(f)
5776

58-
# self.converter = TokenIDConverter(token_list)
59-
self.tokenizer = CharTokenizer()
60-
config["frontend_conf"]['cmvn_file'] = cmvn_file
77+
self.tokenizer = SentencepiecesTokenizer(
78+
bpemodel=os.path.join(model_dir, "chn_jpn_yue_eng_ko_spectok.bpe.model")
79+
)
80+
config["frontend_conf"]["cmvn_file"] = cmvn_file
6181
self.frontend = WavFrontend(**config["frontend_conf"])
6282
self.ort_infer = OrtInferSession(
6383
model_file, device_id, intra_op_num_threads=intra_op_num_threads
6484
)
6585
self.batch_size = batch_size
6686
self.blank_id = 0
87+
self.lid_dict = {"auto": 0, "zh": 3, "en": 4, "yue": 7, "ja": 11, "ko": 12, "nospeech": 13}
88+
self.lid_int_dict = {24884: 3, 24885: 4, 24888: 7, 24892: 11, 24896: 12, 24992: 13}
89+
self.textnorm_dict = {"withitn": 14, "woitn": 15}
90+
self.textnorm_int_dict = {25016: 14, 25017: 15}
91+
92+
def __call__(self, wav_content: Union[str, np.ndarray, List[str]], **kwargs):
93+
94+
language = self.lid_dict[kwargs.get("language", "auto")]
95+
use_itn = kwargs.get("use_itn", False)
96+
textnorm = kwargs.get("text_norm", None)
97+
if textnorm is None:
98+
textnorm = "withitn" if use_itn else "woitn"
99+
textnorm = self.textnorm_dict[textnorm]
67100

68-
def __call__(self,
69-
wav_content: Union[str, np.ndarray, List[str]],
70-
language: List,
71-
textnorm: List,
72-
tokenizer=None,
73-
**kwargs) -> List:
74101
waveform_list = self.load_data(wav_content, self.frontend.opts.frame_opts.samp_freq)
75102
waveform_nums = len(waveform_list)
76103
asr_res = []
77104
for beg_idx in range(0, waveform_nums, self.batch_size):
78105
end_idx = min(waveform_nums, beg_idx + self.batch_size)
79106
feats, feats_len = self.extract_feat(waveform_list[beg_idx:end_idx])
80-
ctc_logits, encoder_out_lens = self.infer(feats,
81-
feats_len,
82-
np.array(language, dtype=np.int32),
83-
np.array(textnorm, dtype=np.int32)
84-
)
107+
ctc_logits, encoder_out_lens = self.infer(
108+
feats,
109+
feats_len,
110+
np.array(language, dtype=np.int32),
111+
np.array(textnorm, dtype=np.int32),
112+
)
85113
# back to torch.Tensor
86114
ctc_logits = torch.from_numpy(ctc_logits).float()
87115
# support batch_size=1 only currently
@@ -91,11 +119,9 @@ def __call__(self,
91119

92120
mask = yseq != self.blank_id
93121
token_int = yseq[mask].tolist()
94-
95-
if tokenizer is not None:
96-
asr_res.append(tokenizer.tokens2text(token_int))
97-
else:
98-
asr_res.append(token_int)
122+
123+
asr_res.append(self.tokenizer.encode(token_int))
124+
99125
return asr_res
100126

101127
def load_data(self, wav_content: Union[str, np.ndarray, List[str]], fs: int = None) -> List:
@@ -136,10 +162,12 @@ def pad_feat(feat: np.ndarray, cur_len: int) -> np.ndarray:
136162
feats = np.array(feat_res).astype(np.float32)
137163
return feats
138164

139-
def infer(self,
140-
feats: np.ndarray,
141-
feats_len: np.ndarray,
142-
language: np.ndarray,
143-
textnorm: np.ndarray,) -> Tuple[np.ndarray, np.ndarray]:
165+
def infer(
    self,
    feats: np.ndarray,
    feats_len: np.ndarray,
    language: np.ndarray,
    textnorm: np.ndarray,
) -> Tuple[np.ndarray, np.ndarray]:
    """Run the ONNX session once on a prepared batch.

    The four arguments are handed to ``self.ort_infer`` as a single list, in
    the order ``[feats, feats_len, language, textnorm]``, and whatever the
    session returns is passed back to the caller untouched.

    Args:
        feats: Batched acoustic features.
        feats_len: Lengths matching ``feats``.
        language: Language id array.
        textnorm: Text-normalisation id array.

    Returns:
        The outputs of ``self.ort_infer`` (annotated as a pair of arrays).
    """
    session_inputs = [feats, feats_len, language, textnorm]
    return self.ort_infer(session_inputs)

runtime/python/onnxruntime/funasr_onnx/utils/postprocess_utils.py

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -296,3 +296,123 @@ def sentence_postprocess_sentencepiece(words):
296296
real_word_lists.append(ch)
297297
sentence = "".join(word_lists)
298298
return sentence, real_word_lists
299+
300+
301+
# Emotion tag -> emoji appended at the END of a formatted segment.
emo_dict = {
    "<|HAPPY|>": "😊",
    "<|SAD|>": "😔",
    "<|ANGRY|>": "😡",
    "<|NEUTRAL|>": "",
    "<|FEARFUL|>": "😰",
    "<|DISGUSTED|>": "🤢",
    "<|SURPRISED|>": "😮",
}

# Event tag -> emoji prepended to the START of a formatted segment.
# NOTE(review): "<|Cough|>" maps to "🤧" here but to "😷" in emoji_dict (and
# "😷" is listed in event_set) -- confirm which emoji is intended.
event_dict = {
    "<|BGM|>": "🎼",
    "<|Speech|>": "",
    "<|Applause|>": "👏",
    "<|Laughter|>": "😀",
    "<|Cry|>": "😭",
    "<|Sneeze|>": "🤧",
    "<|Breath|>": "",
    "<|Cough|>": "🤧",
}

# Language tags; all of them are collapsed to the "<|lang|>" segment separator.
lang_dict = {
    "<|zh|>": "<|lang|>",
    "<|en|>": "<|lang|>",
    "<|yue|>": "<|lang|>",
    "<|ja|>": "<|lang|>",
    "<|ko|>": "<|lang|>",
    "<|nospeech|>": "<|lang|>",
}

# Every special tag that format_str_v2 counts and strips from the text.
# Order matters: the composite "<|nospeech|><|Event_UNK|>" key must come
# before its two component tags.
emoji_dict = {
    "<|nospeech|><|Event_UNK|>": "❓",
    "<|zh|>": "",
    "<|en|>": "",
    "<|yue|>": "",
    "<|ja|>": "",
    "<|ko|>": "",
    "<|nospeech|>": "",
    "<|HAPPY|>": "😊",
    "<|SAD|>": "😔",
    "<|ANGRY|>": "😡",
    "<|NEUTRAL|>": "",
    "<|BGM|>": "🎼",
    "<|Speech|>": "",
    "<|Applause|>": "👏",
    "<|Laughter|>": "😀",
    "<|FEARFUL|>": "😰",
    "<|DISGUSTED|>": "🤢",
    "<|SURPRISED|>": "😮",
    "<|Cry|>": "😭",
    "<|EMO_UNKNOWN|>": "",
    "<|Sneeze|>": "🤧",
    "<|Breath|>": "",
    "<|Cough|>": "😷",
    "<|Sing|>": "",
    "<|Speech_Noise|>": "",
    "<|withitn|>": "",
    "<|woitn|>": "",
    "<|GBG|>": "",
    "<|Event_UNK|>": "",
}

# Emoji recognised as a trailing emotion / leading event marker.
emo_set = {"😊", "😔", "😡", "😰", "🤢", "😮"}
event_set = {
    "🎼",
    "👏",
    "😀",
    "😭",
    "🤧",
    "😷",
}


def format_str_v2(s):
    """Replace the special tags of one segment with emoji decorations.

    Every key of ``emoji_dict`` is counted and removed from ``s``. One emoji
    for the dominant emotion is appended, and one emoji per detected event is
    prepended. Spaces directly next to those emoji are removed.

    Args:
        s: Raw segment text that may contain ``<|...|>`` tags.

    Returns:
        The decorated segment with surrounding whitespace stripped.
    """
    sptk_dict = {}
    # Count and strip sequentially: removing an earlier (composite) tag
    # changes what later keys can still match.
    for sptk in emoji_dict:
        sptk_dict[sptk] = s.count(sptk)
        s = s.replace(sptk, "")

    # An emotion replaces the current pick only when it is strictly more
    # frequent, so NEUTRAL is kept on ties.
    emo = "<|NEUTRAL|>"
    for e in emo_dict:
        if sptk_dict[e] > sptk_dict[emo]:
            emo = e

    for e in event_dict:
        if sptk_dict[e] > 0:
            s = event_dict[e] + s
    s = s + emo_dict[emo]

    for emoji in emo_set.union(event_set):
        s = s.replace(" " + emoji, emoji)
        s = s.replace(emoji + " ", emoji)
    return s.strip()


def rich_transcription_postprocess(s):
    """Turn a tagged transcription into readable text with emoji.

    The text is split at language tags, each piece is formatted with
    ``format_str_v2``, and the pieces are concatenated. A leading event emoji
    equal to the previous piece's event is dropped, and a trailing emotion
    emoji equal to the next piece's emotion is dropped, so runs of identical
    markers are not repeated.

    Args:
        s: Raw transcription containing ``<|...|>`` tags.

    Returns:
        The merged, stripped transcription.
    """

    def get_emo(text):
        # Slicing (not indexing) keeps this safe for empty strings.
        tail = text[-1:]
        return tail if tail in emo_set else None

    def get_event(text):
        head = text[:1]
        return head if head in event_set else None

    s = s.replace("<|nospeech|><|Event_UNK|>", "❓")
    for lang in lang_dict:
        s = s.replace(lang, "<|lang|>")
    s_list = [format_str_v2(s_i).strip(" ") for s_i in s.split("<|lang|>")]

    new_s = " " + s_list[0]
    # NOTE(review): new_s starts with a space, so this is always None here;
    # confirm whether get_event(s_list[0]) was intended.
    cur_ent_event = get_event(new_s)
    for seg in s_list[1:]:
        if not seg:
            continue
        seg_event = get_event(seg)
        if seg_event is not None and seg_event == cur_ent_event:
            # Same event as the previous piece: drop the repeated emoji.
            # This may leave `seg` empty, which the helpers now tolerate
            # (previously this raised IndexError).
            seg = seg[1:]
        cur_ent_event = get_event(seg)
        seg_emo = get_emo(seg)
        if seg_emo is not None and seg_emo == get_emo(new_s):
            # Same emotion continues: keep only the final emoji.
            new_s = new_s[:-1]
        new_s += seg.strip()
    new_s = new_s.replace("The.", " ")
    return new_s.strip()
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
from pathlib import Path
2+
from typing import Iterable
3+
from typing import List
4+
from typing import Union
5+
6+
import sentencepiece as spm
7+
8+
9+
class SentencepiecesTokenizer:
    """Wrapper around a SentencePiece model file.

    Args:
        bpemodel: Path to the SentencePiece model file.
        **kwargs: Forwarded to the base-class initialiser.
    """

    def __init__(self, bpemodel: Union[Path, str], **kwargs):
        super().__init__(**kwargs)
        self.bpemodel = str(bpemodel)
        # The processor is loaded right away. Every accessor still goes
        # through _build_sentence_piece_processor(), so the tokenizer keeps
        # working if `sp` is reset to None (e.g. to drop the non-picklable
        # SWIG object before handing the instance to another process).
        self.sp = None
        self._build_sentence_piece_processor()

    def __repr__(self):
        return f'{self.__class__.__name__}(model="{self.bpemodel}")'

    def _build_sentence_piece_processor(self):
        """Load the SentencePiece processor if it is not loaded yet."""
        if self.sp is None:
            self.sp = spm.SentencePieceProcessor()
            self.sp.load(self.bpemodel)

    def text2tokens(self, line: str) -> List[str]:
        """Split ``line`` into SentencePiece pieces."""
        self._build_sentence_piece_processor()
        return self.sp.EncodeAsPieces(line)

    def tokens2text(self, tokens: Iterable[str]) -> str:
        """Join SentencePiece pieces back into text."""
        self._build_sentence_piece_processor()
        return self.sp.DecodePieces(list(tokens))

    def encode(self, line: str, **kwargs) -> List[int]:
        """Convert text into token ids."""
        self._build_sentence_piece_processor()
        return self.sp.EncodeAsIds(line)

    def decode(self, line: List[int], **kwargs):
        """Convert token ids into text."""
        self._build_sentence_piece_processor()
        return self.sp.DecodeIds(line)

    def get_vocab_size(self):
        """Return the number of pieces in the loaded model."""
        # Same guard as the other accessors; without it this raised
        # AttributeError whenever `sp` was None.
        self._build_sentence_piece_processor()
        return self.sp.GetPieceSize()

    def ids2tokens(self, *args, **kwargs):
        """Alias of :meth:`decode`."""
        return self.decode(*args, **kwargs)

    def tokens2ids(self, *args, **kwargs):
        """Alias of :meth:`encode`."""
        return self.encode(*args, **kwargs)

runtime/python/onnxruntime/setup.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,11 @@ def get_readme():
3131
"librosa",
3232
"onnxruntime>=1.7.0",
3333
"scipy",
34-
"numpy>=1.19.3",
34+
"numpy<=1.26.4",
3535
"kaldi-native-fbank",
3636
"PyYAML>=5.1.2",
3737
"onnx",
38+
"sentencepiece",
3839
],
3940
packages=[MODULE_NAME, f"{MODULE_NAME}.utils"],
4041
keywords=["funasr,asr"],

0 commit comments

Comments
 (0)