88import sys
99
1010sys .path .insert (0 , '../' )
11- import west .models .osum_echat .patch4generate # make patch for generate
11+ from west .models .osum_echat .patch4generate import do_patch
1212
13+ do_patch ()
1314
14- def get_feat_from_wav_path (input_wav_path , device : torch .device = torch .device ('cuda' )):
15+
16+ def get_feat_from_wav_path (input_wav_path ,
17+ device : torch .device = torch .device ('cuda' )):
1518 """..."""
1619 waveform , sample_rate = torchaudio .load (input_wav_path )
1720 if waveform .shape [0 ] > 1 :
1821 waveform = torch .mean (waveform , dim = 0 , keepdim = True )
19- resampler = torchaudio .transforms .Resample (orig_freq = sample_rate , new_freq = 16000 )
22+ resampler = torchaudio .transforms .Resample (orig_freq = sample_rate ,
23+ new_freq = 16000 )
2024 waveform = resampler (waveform )
2125 waveform = waveform .squeeze (0 )
2226 sample_rate = 16000
2327 window = torch .hann_window (400 )
2428 stft = torch .stft (waveform , 400 , 160 , window = window , return_complex = True )
25- magnitudes = stft [..., :- 1 ].abs () ** 2
26- filters = torch .from_numpy (librosa .filters .mel (sr = sample_rate , n_fft = 400 , n_mels = 80 ))
29+ magnitudes = stft [..., :- 1 ].abs ()** 2
30+ filters = torch .from_numpy (
31+ librosa .filters .mel (sr = sample_rate , n_fft = 400 , n_mels = 80 ))
2732 mel_spec = filters @ magnitudes
2833 log_spec = torch .clamp (mel_spec , min = 1e-10 ).log10 ()
2934 log_spec = torch .maximum (log_spec , log_spec .max () - 8.0 )
@@ -39,19 +44,23 @@ def get_feat_from_wav_path(input_wav_path, device: torch.device = torch.device('
3944 from huggingface_hub import hf_hub_download
4045
4146 # For natural language think model in west
42- ckpt_path = hf_hub_download (repo_id = "ASLP-lab/OSUM-EChat" , filename = "language_think_west.pt" )
47+ ckpt_path = hf_hub_download (repo_id = "ASLP-lab/OSUM-EChat" ,
48+ filename = "language_think_west.pt" )
4349 osum_config_path = "../examples/aishell/asr/conf/osum_echat.json"
4450 config_new = AutoConfig .from_pretrained (osum_config_path )
4551 osum_model = AutoModel .from_config (config_new )
4652 osum_model .eval ()
4753 osum_model .to ('cuda' )
48- missing_keys , unexpected_keys = osum_model .load_state_dict (torch .load (ckpt_path , map_location = "cpu" ), strict = False )
54+ missing_keys , unexpected_keys = osum_model .load_state_dict (torch .load (
55+ ckpt_path , map_location = "cpu" ),
56+ strict = False )
4957 for key in missing_keys :
5058 print ("missing tensor: {}" .format (key ))
5159 for key in unexpected_keys :
5260 print ("unexpected tensor: {}" .format (key ))
5361 print (osum_model )
5462 test_wav_path = "./data/test_wave4osumechat.wav"
5563 fake_wav , faek_wav_lens = get_feat_from_wav_path (test_wav_path )
56- osum_output = osum_model .generate (audio_features = fake_wav , audio_features_lengths = faek_wav_lens )
64+ osum_output = osum_model .generate (audio_features = fake_wav ,
65+ audio_features_lengths = faek_wav_lens )
5766 print (osum_output )
0 commit comments