25 changes: 14 additions & 11 deletions west/dataset/dataset.py
@@ -78,18 +78,21 @@ def set_epoch(self, epoch)

     def filter(self, data):
         for item in data:
-            waveform, sample_rate = torchaudio.load(item['wav'])
-            item['wav'] = waveform
-            item['sample_rate'] = sample_rate
-            if self.inference:
-                yield item
-            else:
-                duration = waveform.shape[1] / sample_rate
-                if duration > self.data_args.max_speech_seconds:
-                    continue
-                if duration < self.data_args.min_speech_seconds:
-                    continue
+            if 'messages' in item:  # OpenAI role-content based SFT data
                 yield item
+            else:  # Speech pretraining data
+                waveform, sample_rate = torchaudio.load(item['wav'])
+                item['wav'] = waveform
+                item['sample_rate'] = sample_rate
+                if self.inference:
+                    yield item
+                else:
+                    duration = waveform.shape[1] / sample_rate
+                    if duration > self.data_args.max_speech_seconds:
+                        continue
+                    if duration < self.data_args.min_speech_seconds:
+                        continue
+                    yield item

     def _read_one(self):
         try:
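For orientation, here is a minimal sketch of the two item shapes the updated filter() now distinguishes. The key names 'wav', 'txt', and 'messages' come from the diff above; the concrete paths and values are illustrative assumptions.

```python
# Hypothetical examples only; values are made up for illustration.

# Speech pretraining item: filter() loads the audio here and applies the
# min/max duration checks before yielding.
pretrain_item = {
    'wav': '/data/audio/utt001.wav',   # path loaded via torchaudio.load
    'txt': 'hello world',              # transcript used as the target
}

# OpenAI role-content based SFT item: yielded as-is; audio loading is
# deferred to the extractor, which reads it from the message content.
sft_item = {
    'messages': [
        {'role': 'user', 'content': [
            {'type': 'text', 'text': 'Transcribe the Speech'},
            {'type': 'audio', 'audio': '/data/audio/utt001.wav'},
        ]},
        {'role': 'assistant', 'content': 'hello world'},
    ],
}
```

The pretraining shape is exactly what the extractor's default messages list is built from, so both kinds of items converge on the same role-content structure downstream.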
78 changes: 58 additions & 20 deletions west/models/touch_asu/extractor_touch_asu.py
@@ -15,32 +15,70 @@ class ExtractorTouchASU(Extractor):

     def extract(self, item):
         IGNORE_TOKEN_ID = LabelSmoother.ignore_index
-        audio = torchaudio.transforms.Resample(item['sample_rate'],
-                                               16000)(item['wav'])
-        audio = audio * (1 << 15)
-        # mel: (T, 80)
-        mel = torchaudio.compliance.kaldi.fbank(audio,
-                                                num_mel_bins=80,
-                                                frame_length=25,
-                                                frame_shift=10,
-                                                dither=0.0,
-                                                energy_floor=0.0,
-                                                sample_frequency=16000)
-        # TODO(Binbin Zhang): Refine to instruction + <AUDIO>
-        ids_audio = [0] * (mel.size(0) // 8)  # 8 is the final subsampling rate
-        tgt_audio = [IGNORE_TOKEN_ID] * len(ids_audio)
-        instruction = 'Transcribe the speech'
-        content = item['txt']
-        t0 = '<|im_start|>system\n' + \
-             'You are a helpful assistant<|im_end|>\n' + \
-             '<|im_start|>user\n' + instruction + '<|audio_bos|>'
+        if 'messages' in item:  # OpenAI role-content based SFT data
+            messages = item['messages']
+        else:  # Speech pretraining data
+            messages = [
+                {
+                    'role':
+                    'user',
+                    'content': [{
+                        'type': 'text',
+                        'text': 'Transcribe the Speech'
+                    }, {
+                        'type': 'audio',
+                        'audio': item['wav']
+                    }]
+                },
+                {
+                    'role': 'assistant',
+                    'content': item['txt']
+                },
+            ]
+
+        t0 = '<|im_start|>user\n'
         t1 = '<|audio_eos|><|im_end|>\n' + '<|im_start|>assistant\n'
+        t2 = ''
+        for msg in messages:
+            if msg['role'] == 'system':
+                t0 += msg['content']
+            elif msg['role'] == 'user':
+                if isinstance(msg['content'], dict):
+                    assert msg['content']['type'] == 'audio'
+                    t0 += '<|audio_bos|>'
+                    audio = msg['content']['audio']
+                elif isinstance(msg['content'], list):
+                    # Here we assume the 1st is text, 2nd is audio
+                    assert len(msg['content']) == 2
+                    t0 += msg['content'][0]['text']
+                    t0 += '<|audio_bos|>'
+                    audio = msg['content'][1]['audio']
+                # Feature extraction
+                if isinstance(audio, str):  # path
+                    wav, sample_rate = torchaudio.load(audio)
+                else:
+                    wav, sample_rate = item['wav'], item['sample_rate']
+                wav = torchaudio.transforms.Resample(sample_rate, 16000)(wav)
+                wav = wav * (1 << 15)
+                mel = torchaudio.compliance.kaldi.fbank(wav,
+                                                        num_mel_bins=80,
+                                                        frame_length=25,
+                                                        frame_shift=10,
+                                                        dither=0.0,
+                                                        energy_floor=0.0,
+                                                        sample_frequency=16000)
+                # Here 8 is the final subsampling rate
+                ids_audio = [0] * (mel.size(0) // 8)
+                tgt_audio = [IGNORE_TOKEN_ID] * len(ids_audio)
+
+            elif msg['role'] == 'assistant':
+                t2 = msg['content'] + '<|im_end|>\n'
+        # TODO(Binbin Zhang): Multi-turn support
         ids0 = self.tokenizer.encode(t0)
         ids1 = self.tokenizer.encode(t1)
         ids = [self.tokenizer.bos_token_id] + ids0 + ids_audio + ids1
         tgt = [self.tokenizer.bos_token_id] + ids0 + tgt_audio + ids1
         if not self.inference:
-            t2 = content + '<|im_end|>\n'
             ids2 = self.tokenizer.encode(t2)
             ids = ids + ids2 + [self.tokenizer.eos_token_id]
             tgt = tgt + ids2 + [self.tokenizer.eos_token_id]
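To make the new template assembly concrete, the sketch below walks through what extract() builds for the default single-turn pretraining message. The special-token strings follow the diff; the clip length, frame count, and the ignore-index value (-100, transformers' LabelSmoother.ignore_index) are illustrative assumptions rather than values taken from the PR.

```python
# Sketch of the sequences assembled for {'txt': 'hello world'} with a ~2 s clip.
# Assumption: a 10 ms frame shift gives roughly 200 mel frames, so the number
# of audio placeholder positions is about 200 // 8 = 25.
IGNORE_TOKEN_ID = -100                      # LabelSmoother.ignore_index

t0 = '<|im_start|>user\nTranscribe the Speech<|audio_bos|>'
t1 = '<|audio_eos|><|im_end|>\n<|im_start|>assistant\n'
t2 = 'hello world<|im_end|>\n'              # only appended when training

num_audio = 200 // 8                        # mel.size(0) // 8 in the diff
ids_audio = [0] * num_audio                 # placeholders filled by the speech encoder
tgt_audio = [IGNORE_TOKEN_ID] * num_audio   # audio positions excluded from the loss

# Conceptual layout of the final sequences:
#   ids = [bos] + tokenize(t0) + ids_audio + tokenize(t1) (+ tokenize(t2) + [eos])
#   tgt mirrors ids, but the audio span carries IGNORE_TOKEN_ID labels.
```

Compared with the old code, the hard-coded system prompt and the fixed 'Transcribe the speech' instruction no longer live in t0; the instruction now travels with the message content, and any system message is prepended only when the data provides one.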