Skip to content

Commit 874ce03

Browse files
author
dragongeng
committed
[osum-echat] fix bugs for lint
1 parent ee6e462 commit 874ce03

File tree

7 files changed

+499
-300
lines changed

7 files changed

+499
-300
lines changed

.style.yapf

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[style]
22
based_on_style = pep8
33
ALLOW_MULTILINE_LAMBDAS = True
4-
COLUMN_LIMIT = 80
4+
COLUMN_LIMIT = 50
55
SPLIT_COMPLEX_COMPREHENSION = False
66
COALESCE_BRACKETS = True

test/test_osum_echat.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,22 +8,27 @@
88
import sys
99

1010
sys.path.insert(0, '../')
11-
import west.models.osum_echat.patch4generate # make patch for generate
11+
from west.models.osum_echat.patch4generate import do_patch
1212

13+
do_patch()
1314

14-
def get_feat_from_wav_path(input_wav_path, device: torch.device = torch.device('cuda')):
15+
16+
def get_feat_from_wav_path(input_wav_path,
17+
device: torch.device = torch.device('cuda')):
1518
"""..."""
1619
waveform, sample_rate = torchaudio.load(input_wav_path)
1720
if waveform.shape[0] > 1:
1821
waveform = torch.mean(waveform, dim=0, keepdim=True)
19-
resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
22+
resampler = torchaudio.transforms.Resample(orig_freq=sample_rate,
23+
new_freq=16000)
2024
waveform = resampler(waveform)
2125
waveform = waveform.squeeze(0)
2226
sample_rate = 16000
2327
window = torch.hann_window(400)
2428
stft = torch.stft(waveform, 400, 160, window=window, return_complex=True)
25-
magnitudes = stft[..., :-1].abs() ** 2
26-
filters = torch.from_numpy(librosa.filters.mel(sr=sample_rate, n_fft=400, n_mels=80))
29+
magnitudes = stft[..., :-1].abs()**2
30+
filters = torch.from_numpy(
31+
librosa.filters.mel(sr=sample_rate, n_fft=400, n_mels=80))
2732
mel_spec = filters @ magnitudes
2833
log_spec = torch.clamp(mel_spec, min=1e-10).log10()
2934
log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
@@ -39,19 +44,23 @@ def get_feat_from_wav_path(input_wav_path, device: torch.device = torch.device('
3944
from huggingface_hub import hf_hub_download
4045

4146
# For natural language think model in west
42-
ckpt_path = hf_hub_download(repo_id="ASLP-lab/OSUM-EChat", filename="language_think_west.pt")
47+
ckpt_path = hf_hub_download(repo_id="ASLP-lab/OSUM-EChat",
48+
filename="language_think_west.pt")
4349
osum_config_path = "../examples/aishell/asr/conf/osum_echat.json"
4450
config_new = AutoConfig.from_pretrained(osum_config_path)
4551
osum_model = AutoModel.from_config(config_new)
4652
osum_model.eval()
4753
osum_model.to('cuda')
48-
missing_keys, unexpected_keys = osum_model.load_state_dict(torch.load(ckpt_path, map_location="cpu"), strict=False)
54+
missing_keys, unexpected_keys = osum_model.load_state_dict(torch.load(
55+
ckpt_path, map_location="cpu"),
56+
strict=False)
4957
for key in missing_keys:
5058
print("missing tensor: {}".format(key))
5159
for key in unexpected_keys:
5260
print("unexpected tensor: {}".format(key))
5361
print(osum_model)
5462
test_wav_path = "./data/test_wave4osumechat.wav"
5563
fake_wav, faek_wav_lens = get_feat_from_wav_path(test_wav_path)
56-
osum_output = osum_model.generate(audio_features=fake_wav, audio_features_lengths=faek_wav_lens)
64+
osum_output = osum_model.generate(audio_features=fake_wav,
65+
audio_features_lengths=faek_wav_lens)
5766
print(osum_output)

west/models/osum_echat/configuration_osum_echat.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,13 @@ class OSUMEChatConfig(PretrainedConfig):
88
model_type = "osum_echat"
99

1010
def __init__(
11-
self,
12-
llm_model_name_or_path: str = 'Qwen/Qwen2.5-3B-Instruct',
13-
no_init_llm: bool = True,
14-
wenet_model_name_or_path: str = 'whisper-medium',
15-
lora_config: Optional[Dict[str, Any]] = None,
16-
speech_token_num: int = 4097,
17-
**kwargs,
11+
self,
12+
llm_model_name_or_path: str = 'Qwen/Qwen2.5-3B-Instruct',
13+
no_init_llm: bool = True,
14+
wenet_model_name_or_path: str = 'whisper-medium',
15+
lora_config: Optional[Dict[str, Any]] = None,
16+
speech_token_num: int = 4097,
17+
**kwargs,
1818
):
1919
super().__init__(**kwargs)
2020
self.llm_model_name_or_path = llm_model_name_or_path

west/models/osum_echat/cumstom_stop_criteria.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77

88
class ASRLogitsProcessor(LogitsProcessor):
9+
910
def __init__(self, text_token_num: int):
1011
self.text_token_num = text_token_num
1112

@@ -41,7 +42,8 @@ def __init__(self, text_token_num: int, text_eos_id: int):
4142

4243
def __call__(self, input_ids, scores):
4344
print(input_ids.shape)
44-
assert input_ids.size(0) == 1, "ERROR: S2SSpeechLogitsProcessor only support bs=1 now"
45+
assert input_ids.size(
46+
0) == 1, "ERROR: S2SSpeechLogitsProcessor only support bs=1 now"
4547
if self.text_phase:
4648
scores[..., self.text_token_num:] = torch.finfo(scores.dtype).min
4749
else:
@@ -64,16 +66,19 @@ def __init__(self, text_eos_id: int, speech_eos_id: int):
6466
self.text_eos_id = text_eos_id
6567
self.speech_eos_id = speech_eos_id
6668

67-
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs):
69+
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor,
70+
**kwargs):
6871
_input_ids = input_ids.flatten().view(-1)
6972
if torch.isin(_input_ids, self.text_eos_id).any():
70-
text_eos_idx = (_input_ids == self.text_eos_id).nonzero(as_tuple=True)[0][0].item()
73+
text_eos_idx = (_input_ids == self.text_eos_id).nonzero(
74+
as_tuple=True)[0][0].item()
7175
if torch.sum(_input_ids[text_eos_idx:] == self.speech_eos_id) > 1:
7276
return True
7377
return False
7478

7579

7680
class MaxTokenStopper(StoppingCriteria):
81+
7782
def __init__(self, max_tokens):
7883
self.max_tokens = max_tokens
7984

@@ -86,11 +91,12 @@ def __call__(self, input_ids, scores, **kwargs):
8691

8792

8893
class InterruptStopper(StoppingCriteria):
94+
8995
def __init__(self):
9096
self.stop = False
9197

9298
def __call__(self, input_ids, scores, **kwargs):
93-
if self.stop == True:
99+
if self.stop:
94100
# self.stop == False # reset
95101
return True
96102
else:

west/models/osum_echat/extractor_osum_echat.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
# Copyright (c) 2025 Xuelong Geng([email protected])
22

3-
43
from west.dataset.extractor import Extractor
54

65

0 commit comments

Comments
 (0)