@@ -1194,6 +1194,7 @@ def generate(
1194
1194
text : Optional [List [str ]] = None ,
1195
1195
text_embeds = None ,
1196
1196
prime_wave = None ,
1197
+ prime_wave_input_sample_hz = None ,
1197
1198
prime_ids = None ,
1198
1199
batch_size = 1 ,
1199
1200
cond_scale = 3 ,
@@ -1209,7 +1210,11 @@ def generate(
1209
1210
if exists (prime_wave ):
1210
1211
assert not exists (prime_ids )
1211
1212
assert exists (self .wav2vec )
1212
- ids = self .wav2vec (prime_wave , flatten = False )
1213
+ ids = self .wav2vec (
1214
+ prime_wave ,
1215
+ flatten = False ,
1216
+ input_sample_hz = prime_wave_input_sample_hz
1217
+ )
1213
1218
elif exists (prime_ids ):
1214
1219
ids = prime_ids
1215
1220
else :
@@ -1375,6 +1380,7 @@ def generate(
1375
1380
* ,
1376
1381
semantic_token_ids ,
1377
1382
prime_wave : Optional [Tensor ] = None ,
1383
+ prime_wave_input_sample_hz = None ,
1378
1384
prime_coarse_token_ids : Optional [Tensor ] = None ,
1379
1385
text : Optional [List [str ]] = None ,
1380
1386
text_embeds = None ,
@@ -1400,7 +1406,13 @@ def generate(
1400
1406
assert exists (self .codec )
1401
1407
with torch .inference_mode ():
1402
1408
self .codec .eval ()
1403
- _ , indices , _ = self .codec (prime_wave , return_encoded = True )
1409
+
1410
+ _ , indices , _ = self .codec (
1411
+ prime_wave ,
1412
+ return_encoded = True ,
1413
+ input_sample_hz = prime_wave_input_sample_hz
1414
+ )
1415
+
1404
1416
coarse_token_ids = indices [..., :self .num_coarse_quantizers ]
1405
1417
coarse_token_ids = rearrange (coarse_token_ids , 'b ... -> b (...)' )
1406
1418
else :
@@ -1621,6 +1633,7 @@ def generate(
1621
1633
* ,
1622
1634
coarse_token_ids ,
1623
1635
prime_wave : Optional [Tensor ] = None ,
1636
+ prime_wave_input_sample_hz = None ,
1624
1637
prime_fine_token_ids : Optional [Tensor ] = None ,
1625
1638
text : Optional [List [str ]] = None ,
1626
1639
text_embeds = None ,
@@ -1657,7 +1670,11 @@ def generate(
1657
1670
assert exists (self .codec )
1658
1671
with torch .inference_mode ():
1659
1672
self .codec .eval ()
1660
- _ , token_ids , _ = self .codec (prime_wave , return_encoded = True )
1673
+ _ , token_ids , _ = self .codec (
1674
+ prime_wave ,
1675
+ return_encoded = True ,
1676
+ input_sample_hz = prime_wave_input_sample_hz
1677
+ )
1661
1678
1662
1679
fine_token_ids = token_ids [..., self .num_coarse_quantizers :]
1663
1680
fine_token_ids = rearrange (fine_token_ids , 'b ... -> b (...)' )
@@ -1883,6 +1900,7 @@ def forward(
1883
1900
text : Optional [List [str ]] = None ,
1884
1901
text_embeds : Optional [Tensor ] = None ,
1885
1902
prime_wave = None ,
1903
+ prime_wave_input_sample_hz = None ,
1886
1904
max_length = 2048 ,
1887
1905
return_coarse_generated_wave = False ,
1888
1906
mask_out_generated_fine_tokens = False
@@ -1900,13 +1918,15 @@ def forward(
1900
1918
text_embeds = text_embeds if self .semantic_has_condition else None ,
1901
1919
batch_size = batch_size ,
1902
1920
prime_wave = prime_wave ,
1921
+ prime_wave_input_sample_hz = prime_wave_input_sample_hz ,
1903
1922
max_length = max_length
1904
1923
)
1905
1924
1906
1925
coarse_token_ids_or_recon_wave = self .coarse .generate (
1907
1926
text_embeds = text_embeds if self .coarse_has_condition else None ,
1908
1927
semantic_token_ids = semantic_token_ids ,
1909
1928
prime_wave = prime_wave ,
1929
+ prime_wave_input_sample_hz = prime_wave_input_sample_hz ,
1910
1930
reconstruct_wave = return_coarse_generated_wave
1911
1931
)
1912
1932
@@ -1917,6 +1937,7 @@ def forward(
1917
1937
text_embeds = text_embeds if self .fine_has_condition else None ,
1918
1938
coarse_token_ids = coarse_token_ids_or_recon_wave ,
1919
1939
prime_wave = prime_wave ,
1940
+ prime_wave_input_sample_hz = prime_wave_input_sample_hz ,
1920
1941
reconstruct_wave = True ,
1921
1942
mask_out_generated_fine_tokens = mask_out_generated_fine_tokens
1922
1943
)
0 commit comments