Commit 22951ab

allow for prompt audio to be passed in as prime_wave_path to AudioLM
1 parent 9fd9e45 commit 22951ab

File tree

3 files changed: 14 additions, 2 deletions

  README.md
  audiolm_pytorch/audiolm_pytorch.py
  audiolm_pytorch/version.py


README.md

Lines changed: 1 addition & 1 deletion

@@ -342,10 +342,10 @@ $ accelerate launch train.py
 - [x] allow for specialized relative positional embeddings in fine transformer based on absolute matching positions of quantizers between coarse and fine
 - [x] allow for grouped residual vq in soundstream (use `GroupedResidualVQ` from vector-quantize-pytorch lib), from <a href="https://arxiv.org/abs/2305.02765">hifi-codec</a>
 - [x] add flash attention with <a href="https://arxiv.org/abs/2305.19466">NoPE</a>
+- [x] accept prime wave in `AudioLM` as a path to an audio file, and auto resample for semantic vs acoustic

 - [ ] design a hierarchical coarse and fine transformer
 - [ ] investigate <a href="https://openreview.net/forum?id=H-VlwsYvVi">spec decoding</a>, first test in x-transformers, then port over if applicable
-- [ ] accept prime wave in `AudioLM` as a path to an audio file, and auto resample for semantic vs acoustic

 - [ ] redo the positional embeddings in the presence of groups in residual vq
 - [ ] test with speech synthesis for starters
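The README item moved to the done list above refers to resampling the same prompt wave to two different rates, one for the semantic stage and one for the acoustic codec. A rough sketch of that idea, assuming torchaudio and purely hypothetical target rates (`semantic_sample_hz`, `codec_sample_hz`); this is not the library's internal code:

import torchaudio
import torchaudio.functional as AF

# load a hypothetical prompt file; torchaudio.load returns (waveform, sample_rate)
prime_wave, prime_wave_input_sample_hz = torchaudio.load('./prompt.wav')

semantic_sample_hz = 16_000  # assumed rate expected by the semantic (wav2vec-style) model
codec_sample_hz = 24_000     # assumed rate expected by the neural codec

# resample the same prompt once per stage
prime_wave_semantic = AF.resample(prime_wave, prime_wave_input_sample_hz, semantic_sample_hz)
prime_wave_acoustic = AF.resample(prime_wave, prime_wave_input_sample_hz, codec_sample_hz)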

audiolm_pytorch/audiolm_pytorch.py

Lines changed: 12 additions & 0 deletions

@@ -10,6 +10,8 @@
 import torch.nn.functional as F
 from torch.nn.utils.rnn import pad_sequence

+import torchaudio
+
 from einops import rearrange, repeat, reduce
 from einops.layers.torch import Rearrange

@@ -1901,6 +1903,7 @@ def forward(
         text_embeds: Optional[Tensor] = None,
         prime_wave = None,
         prime_wave_input_sample_hz = None,
+        prime_wave_path = None,
         max_length = 2048,
         return_coarse_generated_wave = False,
         mask_out_generated_fine_tokens = False
@@ -1911,7 +1914,16 @@ def forward(
         if exists(text):
             text_embeds = self.semantic.embed_text(text)

+        assert not (exists(prime_wave) and exists(prime_wave_path)), 'prompt audio must be given as either `prime_wave: Tensor` or `prime_wave_path: str`'
+
         if exists(prime_wave):
+            assert exists(prime_wave_input_sample_hz), 'the input sample frequency for the prompt audio must be given as `prime_wave_input_sample_hz: int`'
+            prime_wave = prime_wave.to(self.device)
+        elif exists(prime_wave_path):
+            prime_wave_path = Path(prime_wave_path)
+            assert exists(prime_wave_path), f'file does not exist at {str(prime_wave_path)}'
+
+            prime_wave, prime_wave_input_sample_hz = torchaudio.load(str(prime_wave_path))
             prime_wave = prime_wave.to(self.device)

         semantic_token_ids = self.semantic.generate(
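The diff above adds a second way to supply the audio prompt. A minimal usage sketch, assuming `audiolm` is an already-assembled `AudioLM` instance (wav2vec, codec, and the three transformers wired up as in the README) and that './prompt.wav' is a hypothetical prompt file:

# previously: pass the prompt as a tensor plus its sample rate
generated_wav = audiolm(
    prime_wave = prime_wave,
    prime_wave_input_sample_hz = 44100
)

# with this commit: pass a file path instead; the wave is loaded via
# torchaudio.load and its sample rate is taken from the file
generated_wav = audiolm(prime_wave_path = './prompt.wav')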

audiolm_pytorch/version.py

Lines changed: 1 addition & 1 deletion

@@ -1 +1 @@
-__version__ = '1.2.23'
+__version__ = '1.2.24'
