diff --git a/west/dataset/dataset.py b/west/dataset/dataset.py
index a2abdee..2ea9958 100644
--- a/west/dataset/dataset.py
+++ b/west/dataset/dataset.py
@@ -78,18 +78,21 @@ def set_epoch(self, epoch):
 
     def filter(self, data):
         for item in data:
-            waveform, sample_rate = torchaudio.load(item['wav'])
-            item['wav'] = waveform
-            item['sample_rate'] = sample_rate
-            if self.inference:
-                yield item
-            else:
-                duration = waveform.shape[1] / sample_rate
-                if duration > self.data_args.max_speech_seconds:
-                    continue
-                if duration < self.data_args.min_speech_seconds:
-                    continue
+            if 'messages' in item:  # OpenAI role-content based SFT data
                 yield item
+            else:  # Speech pretraining data
+                waveform, sample_rate = torchaudio.load(item['wav'])
+                item['wav'] = waveform
+                item['sample_rate'] = sample_rate
+                if self.inference:
+                    yield item
+                else:
+                    duration = waveform.shape[1] / sample_rate
+                    if duration > self.data_args.max_speech_seconds:
+                        continue
+                    if duration < self.data_args.min_speech_seconds:
+                        continue
+                    yield item
 
     def _read_one(self):
         try:
diff --git a/west/models/touch_asu/extractor_touch_asu.py b/west/models/touch_asu/extractor_touch_asu.py
index 6bd1fdd..57561eb 100644
--- a/west/models/touch_asu/extractor_touch_asu.py
+++ b/west/models/touch_asu/extractor_touch_asu.py
@@ -15,32 +15,70 @@ class ExtractorTouchASU(Extractor):
 
     def extract(self, item):
         IGNORE_TOKEN_ID = LabelSmoother.ignore_index
-        audio = torchaudio.transforms.Resample(item['sample_rate'],
-                                               16000)(item['wav'])
-        audio = audio * (1 << 15)
-        # mel: (T, 80)
-        mel = torchaudio.compliance.kaldi.fbank(audio,
-                                                num_mel_bins=80,
-                                                frame_length=25,
-                                                frame_shift=10,
-                                                dither=0.0,
-                                                energy_floor=0.0,
-                                                sample_frequency=16000)
-        # TODO(Binbin Zhang): Refine to instruction
+
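
For context on the first hunk: filter() now routes two item shapes. A minimal sketch of what each might look like, assuming JSONL-style dicts (every key except 'messages' and 'wav' is an illustrative assumption, not taken from this patch):

# Hypothetical items of the two shapes filter() distinguishes.
sft_item = {
    # OpenAI role-content based SFT data: yielded as-is, no audio loading
    'messages': [
        {'role': 'user', 'content': 'What does the speaker say?'},
        {'role': 'assistant', 'content': 'hello world'},
    ],
}
pretrain_item = {
    # Speech pretraining data: 'wav' is loaded and, outside inference,
    # dropped unless min_speech_seconds <= duration <= max_speech_seconds
    'wav': 'data/utt0001.wav',
}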
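For the second hunk, the removed lines are the feature frontend that the rewritten extract() presumably reorganizes; the replacement is truncated here. A self-contained sketch of that pipeline, with an illustrative file path (only the Resample and fbank calls and their arguments come from the patch):

import torchaudio

waveform, sample_rate = torchaudio.load('example.wav')  # hypothetical input
# Resample to the 16 kHz rate the frontend expects
audio = torchaudio.transforms.Resample(sample_rate, 16000)(waveform)
# torchaudio.load returns floats in [-1, 1]; kaldi-compatible fbank
# expects 16-bit PCM magnitudes, hence the 2**15 rescale
audio = audio * (1 << 15)
# 80-dim log-mel features, 25 ms window, 10 ms hop -> shape (T, 80)
mel = torchaudio.compliance.kaldi.fbank(audio,
                                        num_mel_bins=80,
                                        frame_length=25,
                                        frame_shift=10,
                                        dither=0.0,
                                        energy_floor=0.0,
                                        sample_frequency=16000)
print(mel.shape)  # e.g. torch.Size([num_frames, 80])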