
Commit c9d573b

[cli] support firered model (#2772)
1 parent: 6521c82

File tree: 3 files changed, +11 −3 lines


wenet/cli/hub.py

Lines changed: 1 addition & 0 deletions

@@ -86,6 +86,7 @@ class Hub(object):
     assets = {
         "wenetspeech": "wenetspeech_u2pp_conformer_exp.tar.gz",
         "paraformer": "paraformer.tar.gz",
+        "firered": "firered.tar.gz",
         "punc": "punc.tar.gz"
     }

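For context, the new "firered" key in Hub.assets is a plain model-name-to-tarball mapping. Below is a minimal sketch of how such a lookup can be resolved into a download target; the resolve_asset helper and the base URL are illustrative assumptions, not wenet's actual Hub API.

# Minimal sketch only: illustrates a name -> tarball lookup like Hub.assets.
# `resolve_asset` and the base URL are hypothetical, not wenet's real API.
ASSETS = {
    "wenetspeech": "wenetspeech_u2pp_conformer_exp.tar.gz",
    "paraformer": "paraformer.tar.gz",
    "firered": "firered.tar.gz",  # entry added by this commit
    "punc": "punc.tar.gz",
}


def resolve_asset(name: str, base_url: str = "https://example.org/models") -> str:
    """Map a cli model name to the tarball it should fetch (illustrative)."""
    try:
        return f"{base_url}/{ASSETS[name]}"
    except KeyError:
        raise ValueError(f"unknown model {name!r}, expected one of {sorted(ASSETS)}")


print(resolve_asset("firered"))  # https://example.org/models/firered.tar.gz
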
wenet/firered/model.py

Lines changed: 3 additions & 0 deletions

@@ -24,6 +24,9 @@

 class FireRedModel(ASRModel):

+    # FireRedModel only supports autoregressive decoding
+    default_decode_method = "attention"
+
     def __init__(
         self,
         vocab_size: int,

wenet/transformer/asr_model.py

Lines changed: 7 additions & 3 deletions

@@ -35,6 +35,9 @@
 class ASRModel(torch.nn.Module):
     """CTC-attention hybrid Encoder-Decoder model"""

+    # default decoding method for cli
+    default_decode_method = "attention_rescoring"
+
     def __init__(
         self,
         vocab_size: int,

@@ -338,14 +341,15 @@ def decode(
         return results

     def transcribe(self, wav: str):
-        """ We use attention_rescoring for transcribe"""
+        """Transcribe for cli"""
         assert hasattr(self, 'compute_feature')  # Dynamic inject in cli
         assert hasattr(self, 'tokenizer')  # Dynamic inject in cli
+        self.eval()
         speech = self.compute_feature(wav)
         speech_lengths = torch.tensor([speech.size(0)], device=speech.device)
         speech = speech.unsqueeze(0)
-        results = self.decode(['attention_rescoring'], speech, speech_lengths)
-        result = results['attention_rescoring'][0]
+        results = self.decode([self.default_decode_method], speech, speech_lengths)
+        result = results[self.default_decode_method][0]
         result.text = self.tokenizer.detokenize(result.tokens)[0]
         return result

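The net effect is that ASRModel.transcribe() now dispatches on the class attribute default_decode_method instead of a hard-coded 'attention_rescoring', so a subclass such as FireRedModel (which only supports autoregressive decoding) can steer the cli to attention decoding without overriding transcribe(). The added self.eval() call also puts the model into inference mode (disabling dropout) before feature extraction and decoding. A standalone sketch of the dispatch pattern, using toy classes rather than the real wenet models:

# Standalone sketch of the class-attribute dispatch used by transcribe();
# BaseASR/FireRedLike are toy stand-ins, not the real wenet model classes.
from typing import Dict, List


class BaseASR:
    # default decoding method for the cli, mirroring ASRModel
    default_decode_method = "attention_rescoring"

    def decode(self, methods: List[str]) -> Dict[str, str]:
        # Pretend each requested method produced one hypothesis.
        return {m: f"hypothesis from {m}" for m in methods}

    def transcribe(self) -> str:
        # Same pattern as the patched ASRModel.transcribe: the decoding
        # method comes from the class attribute, not a hard-coded string.
        results = self.decode([self.default_decode_method])
        return results[self.default_decode_method]


class FireRedLike(BaseASR):
    # Autoregressive-only model: override the default, keep transcribe().
    default_decode_method = "attention"


print(BaseASR().transcribe())      # hypothesis from attention_rescoring
print(FireRedLike().transcribe())  # hypothesis from attention

Keeping the cli entry point in the base class and letting each model family override only the attribute avoids duplicating the transcribe() boilerplate per model.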