File tree Expand file tree Collapse file tree 3 files changed +11
-3
lines changed Expand file tree Collapse file tree 3 files changed +11
-3
lines changed Original file line number Diff line number Diff line change @@ -86,6 +86,7 @@ class Hub(object):
8686 assets = {
8787 "wenetspeech" : "wenetspeech_u2pp_conformer_exp.tar.gz" ,
8888 "paraformer" : "paraformer.tar.gz" ,
89+ "firered" : "firered.tar.gz" ,
8990 "punc" : "punc.tar.gz"
9091 }
9192
Original file line number Diff line number Diff line change 2424
2525class FireRedModel (ASRModel ):
2626
27+ # FireRedModel only support autogressive decoding
28+ default_decode_method = "attention"
29+
2730 def __init__ (
2831 self ,
2932 vocab_size : int ,
Original file line number Diff line number Diff line change 3535class ASRModel (torch .nn .Module ):
3636 """CTC-attention hybrid Encoder-Decoder model"""
3737
38+ # default decoding method for cli
39+ default_decode_method = "attention_rescoring"
40+
3841 def __init__ (
3942 self ,
4043 vocab_size : int ,
@@ -338,14 +341,15 @@ def decode(
338341 return results
339342
340343 def transcribe (self , wav : str ):
341- """ We use attention_rescoring for transcribe """
344+ """Transcribe for cli """
342345 assert hasattr (self , 'compute_feature' ) # Dynamic inject in cli
343346 assert hasattr (self , 'tokenizer' ) # Dynamic inject in cli
347+ self .eval ()
344348 speech = self .compute_feature (wav )
345349 speech_lengths = torch .tensor ([speech .size (0 )], device = speech .device )
346350 speech = speech .unsqueeze (0 )
347- results = self .decode (['attention_rescoring' ], speech , speech_lengths )
348- result = results ['attention_rescoring' ][0 ]
351+ results = self .decode ([self . default_decode_method ], speech , speech_lengths )
352+ result = results [self . default_decode_method ][0 ]
349353 result .text = self .tokenizer .detokenize (result .tokens )[0 ]
350354 return result
351355
You can’t perform that action at this time.
0 commit comments