Skip to content

Commit 9fd9e45

Browse files
committed
add ability to indicate the input sample frequency for the prime wav being offered to AudioLM. may need to make this required, unless the prompt audio is offered as an audio path
1 parent 053bfe8 commit 9fd9e45

File tree

3 files changed

+27
-4
lines changed

3 files changed

+27
-4
lines changed

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -345,6 +345,8 @@ $ accelerate launch train.py
345345

346346
- [ ] design a hierarchical coarse and fine transformer
347347
- [ ] investigate <a href="https://openreview.net/forum?id=H-VlwsYvVi">spec decoding</a>, first test in x-transformers, then port over if applicable
348+
- [ ] accept prime wave in `AudioLM` as a path to an audio file, and auto resample for semantic vs acoustic
349+
348350
- [ ] redo the positional embeddings in the presence of groups in residual vq
349351
- [ ] test with speech synthesis for starters
350352
- [ ] cli tool, something like `audiolm generate <wav.file | text>` and save generated wav file to local directory

audiolm_pytorch/audiolm_pytorch.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1194,6 +1194,7 @@ def generate(
11941194
text: Optional[List[str]] = None,
11951195
text_embeds = None,
11961196
prime_wave = None,
1197+
prime_wave_input_sample_hz = None,
11971198
prime_ids = None,
11981199
batch_size = 1,
11991200
cond_scale = 3,
@@ -1209,7 +1210,11 @@ def generate(
12091210
if exists(prime_wave):
12101211
assert not exists(prime_ids)
12111212
assert exists(self.wav2vec)
1212-
ids = self.wav2vec(prime_wave, flatten = False)
1213+
ids = self.wav2vec(
1214+
prime_wave,
1215+
flatten = False,
1216+
input_sample_hz = prime_wave_input_sample_hz
1217+
)
12131218
elif exists(prime_ids):
12141219
ids = prime_ids
12151220
else:
@@ -1375,6 +1380,7 @@ def generate(
13751380
*,
13761381
semantic_token_ids,
13771382
prime_wave: Optional[Tensor] = None,
1383+
prime_wave_input_sample_hz = None,
13781384
prime_coarse_token_ids: Optional[Tensor] = None,
13791385
text: Optional[List[str]] = None,
13801386
text_embeds = None,
@@ -1400,7 +1406,13 @@ def generate(
14001406
assert exists(self.codec)
14011407
with torch.inference_mode():
14021408
self.codec.eval()
1403-
_, indices, _ = self.codec(prime_wave, return_encoded = True)
1409+
1410+
_, indices, _ = self.codec(
1411+
prime_wave,
1412+
return_encoded = True,
1413+
input_sample_hz = prime_wave_input_sample_hz
1414+
)
1415+
14041416
coarse_token_ids = indices[..., :self.num_coarse_quantizers]
14051417
coarse_token_ids = rearrange(coarse_token_ids, 'b ... -> b (...)')
14061418
else:
@@ -1621,6 +1633,7 @@ def generate(
16211633
*,
16221634
coarse_token_ids,
16231635
prime_wave: Optional[Tensor] = None,
1636+
prime_wave_input_sample_hz = None,
16241637
prime_fine_token_ids: Optional[Tensor] = None,
16251638
text: Optional[List[str]] = None,
16261639
text_embeds = None,
@@ -1657,7 +1670,11 @@ def generate(
16571670
assert exists(self.codec)
16581671
with torch.inference_mode():
16591672
self.codec.eval()
1660-
_, token_ids, _ = self.codec(prime_wave, return_encoded = True)
1673+
_, token_ids, _ = self.codec(
1674+
prime_wave,
1675+
return_encoded = True,
1676+
input_sample_hz = prime_wave_input_sample_hz
1677+
)
16611678

16621679
fine_token_ids = token_ids[..., self.num_coarse_quantizers:]
16631680
fine_token_ids = rearrange(fine_token_ids, 'b ... -> b (...)')
@@ -1883,6 +1900,7 @@ def forward(
18831900
text: Optional[List[str]] = None,
18841901
text_embeds: Optional[Tensor] = None,
18851902
prime_wave = None,
1903+
prime_wave_input_sample_hz = None,
18861904
max_length = 2048,
18871905
return_coarse_generated_wave = False,
18881906
mask_out_generated_fine_tokens = False
@@ -1900,13 +1918,15 @@ def forward(
19001918
text_embeds = text_embeds if self.semantic_has_condition else None,
19011919
batch_size = batch_size,
19021920
prime_wave = prime_wave,
1921+
prime_wave_input_sample_hz = prime_wave_input_sample_hz,
19031922
max_length = max_length
19041923
)
19051924

19061925
coarse_token_ids_or_recon_wave = self.coarse.generate(
19071926
text_embeds = text_embeds if self.coarse_has_condition else None,
19081927
semantic_token_ids = semantic_token_ids,
19091928
prime_wave = prime_wave,
1929+
prime_wave_input_sample_hz = prime_wave_input_sample_hz,
19101930
reconstruct_wave = return_coarse_generated_wave
19111931
)
19121932

@@ -1917,6 +1937,7 @@ def forward(
19171937
text_embeds = text_embeds if self.fine_has_condition else None,
19181938
coarse_token_ids = coarse_token_ids_or_recon_wave,
19191939
prime_wave = prime_wave,
1940+
prime_wave_input_sample_hz = prime_wave_input_sample_hz,
19201941
reconstruct_wave = True,
19211942
mask_out_generated_fine_tokens = mask_out_generated_fine_tokens
19221943
)

audiolm_pytorch/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = '1.2.22'
1+
__version__ = '1.2.23'

0 commit comments

Comments
 (0)