
Commit 4b31cb1

Merge pull request #35 from xingchensong/longaudio
[feat] support long audio
2 parents bef31a6 + 605a04b

File tree

10 files changed: +1161, -79 lines


.flake8

Lines changed: 28 additions & 0 deletions

@@ -0,0 +1,28 @@
+[flake8]
+# Suggested config from pytorch that we can adapt
+select = B,C,E,F,N,P,T4,W,B9,TOR0,TOR1,TOR2
+max-line-length = 120
+# C408 ignored because we like the dict keyword argument syntax
+# E501 is not flexible enough, we're using B950 instead
+# N812 ignored because import torch.nn.functional as F is PyTorch convention
+# N817 ignored because importing using acronyms is convention (DistributedDataParallel as DDP)
+# E731 allow usage of assigning lambda expressions
+# N803,N806 allow caps and mixed case in function params. This is to work with Triton kernel coding style.
+ignore =
+    E203,E305,E402,E501,E721,E741,F405,F821,F841,F999,W503,W504,C408,E302,W291,E303,N812,N817,E731,N803,N806,
+    # shebang has extra meaning in fbcode lints, so I think it's not worth trying
+    # to line this up with executable bit
+    EXE001,
+    # these ignores are from flake8-bugbear; please fix!
+    B007,B008,
+optional-ascii-coding = True
+exclude =
+    ./.git,
+    ./docs,
+    ./build,
+    ./scripts,
+    ./venv,
+    *.pyi,
+    .pre-commit-config.yaml,
+    *.md,
+    .flake8
(GitHub Actions workflow; filename not captured in this view)

Lines changed: 47 additions & 0 deletions

@@ -0,0 +1,47 @@
+name: CPU Unit Test
+
+on:
+  push:
+    branches: [ main ]
+  pull_request:
+
+concurrency:
+  group: unit-test${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_number || github.ref }}
+  cancel-in-progress: true
+
+jobs:
+  unit-test:
+    runs-on: ${{ matrix.os }}
+    strategy:
+      max-parallel: 20
+      matrix:
+        os: [ubuntu-22.04]
+        python-version: [3.10.16]
+    steps:
+      - name: Cache Python Packages
+        uses: actions/cache@v4
+        with:
+          path: ~/.cache/pip
+          key: ${{ runner.os }}-pip-${{ matrix.python-version }}-${{ hashFiles('**/pyproject.toml') }}
+      - name: Setup Python
+        uses: actions/setup-python@v4
+        with:
+          python-version: ${{ matrix.python-version }}
+          architecture: x64
+      - name: Fetch S3Tokenizer
+        uses: actions/checkout@v4
+        with:
+          fetch-depth: 0
+          ref: ${{ github.event.pull_request.head.ref || github.ref }}
+      - name: Install S3Tokenizer Dependencies
+        run: |
+          set -eux
+          sudo apt update && sudo apt install -y ffmpeg libsox-dev libsndfile1
+          pip install -e .
+      - name: Run Pytest
+        run: |
+          set -eux
+          pip install pytest onnxruntime
+          pytest --version
+          PYTHONPATH="${PYTHONPATH:-}:$(pwd)" pytest test/ -q
+          if [ $? != 0 ]; then exit 1; fi

README.md

Lines changed: 11 additions & 10 deletions

@@ -14,10 +14,11 @@ This repository undertakes a reverse engineering of the S3Tokenizer, offering:
 2. High-throughput (distributed) batch inference, achieving a ~790x speedup compared to the original inference pipeline in [[cosyvoice/tools/extract_speech_token.py]](https://github.com/FunAudioLLM/CosyVoice/blob/main/tools/extract_speech_token.py).
 3. The capability to perform online speech code extraction during SpeechLLM training.

-## Supported Models 🔥
-- [x] [S3Tokenizer V1 50hz](https://modelscope.cn/models/iic/CosyVoice-300M)
-- [x] [S3Tokenizer V1 25hz](https://modelscope.cn/models/iic/CosyVoice-300M-25Hz)
-- [x] [S3Tokenizer V2 25hz](https://modelscope.cn/models/iic/CosyVoice2-0.5B)
+## Supported Models 🔥 && New Features 🎉
+- [x] Model: [S3Tokenizer V1 50hz](https://modelscope.cn/models/iic/CosyVoice-300M)
+- [x] Model: [S3Tokenizer V1 25hz](https://modelscope.cn/models/iic/CosyVoice-300M-25Hz)
+- [x] Model: [S3Tokenizer V2 25hz](https://modelscope.cn/models/iic/CosyVoice2-0.5B)
+- [x] Feature: S3Tokenizer now has built-in **long audio processing**; no extra steps are required from users!


 # Setup
@@ -39,7 +40,7 @@ for wav_path in wav_paths:
     audio = s3tokenizer.load_audio(wav_path)
     mels.append(s3tokenizer.log_mel_spectrogram(audio))
 mels, mels_lens = s3tokenizer.padding(mels)
-codes, codes_lens = tokenizer.quantize(mels.cuda(), mels_lens.cuda())
+codes, codes_lens = tokenizer.quantize(mels.cuda(), mels_lens.cuda())  # Automatically handles long audio internally!

 for i in range(len(wav_paths)):
     print(codes[i, :codes_lens[i].item()])
@@ -139,9 +140,9 @@ class SpeechLLM(nn.Module):
     </tr>
 </table>

+# Usage-4: Long Audio Processing (Built-in Automatic Processing)

-# TODO
-
-- [x] Usage-1: Offline batch inference
-- [x] Usage-2: Distributed offline batch inference via command-line tools
-- [x] Usage-3: Online speech code extraction
+- **Automatic Detection**: the model checks input length itself; anything longer than 30 seconds takes the long-audio path
+- **Sliding Window**: a 30-second window with a 4-second overlap segments long audio automatically
+- **Batch Processing**: all segments are batched internally for efficiency
+- **Complete Transparency**: the call site is identical to the short-audio case (see the sketch below)
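To make Usage-4 concrete, here is a minimal long-audio sketch in the style of the README's Usage-1 snippet. The wav path is a placeholder, and the model name assumes the `load_model` helper documented elsewhere in the README:

```python
import s3tokenizer

tokenizer = s3tokenizer.load_model("speech_tokenizer_v2_25hz").cuda()

# "long_audio.wav" stands in for any recording longer than 30 seconds.
audio = s3tokenizer.load_audio("long_audio.wav")
mels, mels_lens = s3tokenizer.padding([s3tokenizer.log_mel_spectrogram(audio)])

# quantize() detects the >30 s input, windows it (30 s window, 4 s overlap),
# and returns one merged token sequence for the whole utterance.
codes, codes_lens = tokenizer.quantize(mels.cuda(), mels_lens.cuda())
print(codes[0, :codes_lens[0].item()])
```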

s3tokenizer/__init__.py

Lines changed: 2 additions & 2 deletions

@@ -28,11 +28,11 @@

 from .model import S3Tokenizer
 from .utils import (load_audio, log_mel_spectrogram, make_non_pad_mask,
-                    mask_to_bias, onnx2torch, padding)
+                    mask_to_bias, onnx2torch, padding, merge_tokenized_segments)

 __all__ = [
     'load_audio', 'log_mel_spectrogram', 'make_non_pad_mask', 'mask_to_bias',
-    'onnx2torch', 'padding'
+    'onnx2torch', 'padding', 'merge_tokenized_segments'
 ]
 _MODELS = {
     "speech_tokenizer_v1":

s3tokenizer/cli.py

Lines changed: 1 addition & 7 deletions

@@ -61,13 +61,7 @@ def __getitem__(self, idx):
         file_path = self.data[idx]
         key = self.keys[idx]
         audio = s3tokenizer.load_audio(file_path)
-        if audio.shape[0] / 16000 > 30:
-            print(
-                f'do not support extract speech token for audio longer than 30s, file_path: {file_path}'  # noqa
-            )
-            mel = torch.zeros(128, 0)
-        else:
-            mel = s3tokenizer.log_mel_spectrogram(audio)
+        mel = s3tokenizer.log_mel_spectrogram(audio)
         return key, mel
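With the 30-second guard gone, the CLI dataset always returns the full mel and leaves windowing to `quantize`. A quick, hedged sanity check of the frame arithmetic (silence stands in for a real recording; `log_mel_spectrogram` is assumed to accept a raw waveform tensor, as it does in the diff above):

```python
import torch
import s3tokenizer

# 35 s of silence at 16 kHz stands in for a real long recording.
audio = torch.zeros(35 * 16000)
mel = s3tokenizer.log_mel_spectrogram(audio)  # shape: (128, ~3500)

# hop_length=160 -> 35 * 16000 / 160 = 3500 frames, which exceeds the
# 3000-frame (30 s) threshold, so quantize() would take the long-audio path.
assert mel.shape[1] > 3000
```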

s3tokenizer/model.py

Lines changed: 207 additions & 6 deletions

@@ -25,7 +25,7 @@
 from einops import rearrange
 from torch import Tensor, nn

-from .utils import make_non_pad_mask, mask_to_bias, onnx2torch
+from .utils import make_non_pad_mask, mask_to_bias, onnx2torch, merge_tokenized_segments

@@ -236,7 +236,7 @@ def preprocess(self, x: Tensor) -> Tensor:

     @torch.inference_mode()
     def quantize(self, x: Tensor) -> Tensor:
-        embed = self.embed.t()
+        embed = self.embed.t().to(x.dtype)
         dist = -(x.pow(2).sum(1, keepdim=True) - 2 * x @ embed +
                  embed.pow(2).sum(0, keepdim=True))
         embed_ind = dist.max(dim=-1).indices

@@ -287,7 +287,7 @@ def codebook(self):

     @torch.inference_mode()
     def encode(self, x: Tensor) -> Tensor:
-        x = F.normalize(x, p=2, dim=-1)
+        x = F.normalize(x.float(), p=2, dim=-1)
         embed_in = self._codebook.encode(x)
         return embed_in

@@ -306,6 +306,7 @@ class S3Tokenizer(nn.Module):

     def __init__(self, name: str, config: ModelConfig = ModelConfig()):
         super().__init__()
+        self.name = name  # Store model name for token_rate determination
         self.config = config
         self.encoder = AudioEncoder(
             self.config.n_mels,

@@ -324,9 +325,209 @@ def forward(self, mel: Tensor, mel_len: Tensor) -> Tuple[Tensor, Tensor]:

     @torch.inference_mode()
     def quantize(self, mel: Tensor, mel_len: Tensor) -> Tuple[Tensor, Tensor]:
-        hidden, code_len = self.encoder(mel, mel_len)
-        code = self.quantizer.encode(hidden)
-        return code, code_len
+        """
+        Quantize mel spectrogram to tokens, with automatic long-audio handling.
+
+        Args:
+            mel: mel spectrogram tensor, shape (batch_size, n_mels, T)
+            mel_len: mel length tensor, shape (batch_size,)
+
+        Returns:
+            code: quantized tokens, shape (batch_size, T')
+            code_len: token length, shape (batch_size,)
+        """
+        # Check if any audio in the batch exceeds 30 seconds.
+        # At a 16 kHz sample rate with hop_length=160, 30 s = 30 * 16000 / 160 = 3000 frames.
+        max_frames = 3000
+
+        # Check which samples are long audio
+        long_audio_mask = mel_len > max_frames
+
+        if long_audio_mask.any():
+            # The batch contains long audio and needs special processing
+            return self._quantize_mixed_batch(mel, mel_len, long_audio_mask,
+                                              max_frames)
+        else:
+            # All short audio: use the original method
+            hidden, code_len = self.encoder(mel, mel_len)
+            code = self.quantizer.encode(hidden)
+            return code, code_len
+
+    @torch.inference_mode()
+    def _quantize_mixed_batch(self, mel: Tensor, mel_len: Tensor,
+                              long_audio_mask: Tensor,
+                              max_frames: int) -> Tuple[Tensor, Tensor]:
+        """
+        Handle a mixed batch of short and long audio with unified batch processing.
+
+        Args:
+            mel: mel spectrogram tensor, shape (batch_size, n_mels, T)
+            mel_len: mel length tensor, shape (batch_size,)
+            long_audio_mask: boolean mask for long audio, shape (batch_size,)
+            max_frames: maximum frames for short audio
+
+        Returns:
+            code: quantized tokens, shape (batch_size, T')
+            code_len: token length, shape (batch_size,)
+        """
+        batch_size = mel.size(0)
+
+        # Parameters for the sliding window
+        sample_rate = 16000
+        hop_length = 160  # Default hop length for mel spectrogram
+        window_size = 30  # seconds
+        overlap = 4  # seconds
+
+        # Calculate frame-based parameters
+        frames_per_window = window_size * sample_rate // hop_length  # 3000 frames
+        frames_per_overlap = overlap * sample_rate // hop_length  # 400 frames
+        frames_per_stride = frames_per_window - frames_per_overlap  # 2600 frames
+
+        # Collect all segments to process (short audio and long-audio windows alike)
+        all_segments = []
+        all_segments_len = []
+        segment_info = [
+        ]  # Records which audio each segment belongs to and whether it is long audio
+
+        # Process all audio in the batch
+        for batch_idx in range(batch_size):
+            audio_mel = mel[batch_idx]
+            audio_mel_len = mel_len[batch_idx]
+            is_long_audio = long_audio_mask[batch_idx].item()
+
+            if not is_long_audio:
+                # Short audio: process directly as a single segment
+                segment = audio_mel[:, :audio_mel_len]
+                seg_len = audio_mel_len.item()
+
+                # Pad to the window size if necessary
+                if seg_len < frames_per_window:
+                    pad_size = frames_per_window - seg_len
+                    segment = F.pad(segment, (0, pad_size))
+
+                all_segments.append(segment)
+                all_segments_len.append(
+                    torch.tensor(seg_len, device=mel.device))
+                segment_info.append({
+                    'batch_idx': batch_idx,
+                    'is_long_audio': False,
+                    'segment_idx': 0,
+                    'total_segments': 1
+                })
+            else:
+                # Long audio: split into multiple segments
+                start = 0
+                segment_idx = 0
+                while start < audio_mel_len:
+                    end = min(start + frames_per_window, audio_mel_len)
+                    segment = audio_mel[:, start:end]
+
+                    seg_len = segment.size(1)
+                    # Pad if necessary
+                    if seg_len < frames_per_window:
+                        pad_size = frames_per_window - seg_len
+                        segment = F.pad(segment, (0, pad_size))
+
+                    all_segments.append(segment)
+                    all_segments_len.append(
+                        torch.tensor(seg_len, device=mel.device))
+                    segment_info.append({
+                        'batch_idx': batch_idx,
+                        'is_long_audio': True,
+                        'segment_idx': segment_idx,
+                        'total_segments': None  # Will be filled later
+                    })
+
+                    segment_idx += 1
+                    start += frames_per_stride
+
+                # Update total_segments info
+                total_segments = segment_idx
+                for info in segment_info:
+                    if info['batch_idx'] == batch_idx and info['is_long_audio']:
+                        info['total_segments'] = total_segments
+
+        if not all_segments:
+            # Fallback if there are no segments
+            return torch.zeros(batch_size,
+                               0,
+                               dtype=torch.long,
+                               device=mel.device), torch.zeros(
+                                   batch_size,
+                                   dtype=torch.long,
+                                   device=mel.device)
+
+        # Unified batch processing for all segments
+        unified_batch_mel = torch.stack(all_segments)
+        unified_batch_lens = torch.stack(all_segments_len)
+
+        # Process all segments at once
+        hidden, code_len = self.encoder(unified_batch_mel, unified_batch_lens)
+        codes = self.quantizer.encode(hidden)
+
+        # Reorganize results based on segment_info
+        results = {}  # batch_idx -> (code_tensor, code_len)
+
+        for seg_idx, info in enumerate(segment_info):
+            batch_idx = info['batch_idx']
+            is_long_audio = info['is_long_audio']
+            segment_idx = info['segment_idx']
+
+            # Get codes for the current segment
+            segment_code = codes[
+                seg_idx, :code_len[seg_idx].item()].cpu().numpy().tolist()
+
+            if not is_long_audio:
+                # Short audio: use directly
+                code_tensor = torch.tensor(segment_code,
+                                           dtype=torch.long,
+                                           device=mel.device)
+                results[batch_idx] = (code_tensor, len(segment_code))
+            else:
+                # Long audio: collect all segments first
+                if batch_idx not in results:
+                    results[batch_idx] = []
+                results[batch_idx].append(segment_code)
+
+        # Merge the segments of each long audio
+        for batch_idx in range(batch_size):
+            if long_audio_mask[batch_idx].item():
+                audio_codes = results[batch_idx]
+
+                # Determine token rate based on the model name
+                if hasattr(self,
+                           'name') and self.name == "speech_tokenizer_v1":
+                    token_rate = 50
+                else:
+                    token_rate = 25
+
+                merged_codes = merge_tokenized_segments(audio_codes,
+                                                        overlap=overlap,
+                                                        token_rate=token_rate)
+
+                # Convert to tensor
+                merged_codes_tensor = torch.tensor(merged_codes,
+                                                   dtype=torch.long,
+                                                   device=mel.device)
+                results[batch_idx] = (merged_codes_tensor, len(merged_codes))
+
+        # Construct the final output
+        max_code_len = max(code_info[1] for code_info in results.values())
+
+        output_codes = torch.zeros(batch_size,
+                                   max_code_len,
+                                   dtype=torch.long,
+                                   device=mel.device)
+        output_codes_len = torch.zeros(batch_size,
+                                       dtype=torch.long,
+                                       device=mel.device)
+
+        for batch_idx, (code_tensor, code_len) in results.items():
+            output_codes[batch_idx, :code_len] = code_tensor
+            output_codes_len[batch_idx] = code_len
+
+        return output_codes, output_codes_len

     @property
     def device(self):
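A note on the stride arithmetic: with a 3000-frame window and a 2600-frame stride, a 100 s clip (10000 frames) produces windows starting at frames 0, 2600, 5200, and 7800, i.e. four segments with the last one padded. The `merge_tokenized_segments` helper that stitches the per-window codes back together lives in `s3tokenizer/utils.py`, whose diff is not shown in this view. Below is a minimal sketch of one plausible merge strategy, assuming each window after the first re-encodes exactly `overlap` seconds of audio already covered by its predecessor and that those duplicated tokens are simply dropped; the actual helper may resolve the overlap differently (for example, splitting it between neighboring windows):

```python
from typing import List


def merge_tokenized_segments(segments: List[List[int]], overlap: int,
                             token_rate: int) -> List[int]:
    """Concatenate windowed token lists, dropping re-encoded overlap tokens.

    Sketch only: assumes a fixed overlap of `overlap` seconds, i.e.
    overlap * token_rate tokens (4 s * 25 Hz = 100, or 4 s * 50 Hz = 200).
    """
    overlap_tokens = overlap * token_rate
    merged: List[int] = list(segments[0])
    for seg in segments[1:]:
        # Skip the tokens that cover audio already encoded by the previous window.
        merged.extend(seg[overlap_tokens:])
    return merged
```

Whatever the exact policy, the `token_rate` branch in `_quantize_mixed_batch` (50 Hz for `speech_tokenizer_v1`, 25 Hz otherwise) is what converts the 4-second overlap into a token count.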
