
Commit fed5804

chwma0 and changwei.ma-halo authored
[touchtts] fix bug and support chunk-mask for touch_flow (#77)
Co-authored-by: changwei.ma-halo <[email protected]>
1 parent 018eede commit fed5804

4 files changed, +172 -4 lines changed


west/models/touch_flow/configuration_touch_flow.py

Lines changed: 19 additions & 0 deletions
@@ -21,19 +21,38 @@ def __init__(
         hidden_size: int = 0,
         inference_cfg_rate: float = 0.7,
         n_timesteps: int = 5,
+        max_speech_duration: float = 30,
+        min_speech_duration: float = 0.2,
+        decoding_chunk_size: int = 0,
+        enable_full_context: bool = True,
+        max_chunk_size: int = 86,
+        num_decoding_left_chunks: int = 0,
+        static_chunk_size: int = -1,
+        use_dynamic_chunk: bool = False,
+        use_dynamic_left_chunk: bool = False,
         **kwargs,
     ):
         super().__init__(**kwargs)
         self.llm_model_name_or_path = llm_model_name_or_path
         self.s3tokenizer_model_name_or_path = s3tokenizer_model_name_or_path
         self.speaker_model_path = speaker_model_path
+        self.text_tokenizer_path = text_tokenizer_path
         self.num_speech_tokens = num_speech_tokens
         self.t_scheduler = t_scheduler
         self.sigma_min = sigma_min
         self.training_cfg_rate = training_cfg_rate
         self.hidden_size = hidden_size
         self.inference_cfg_rate = inference_cfg_rate
         self.n_timesteps = n_timesteps
+        self.max_speech_duration = max_speech_duration
+        self.min_speech_duration = min_speech_duration
+        self.use_dynamic_chunk = use_dynamic_chunk
+        self.use_dynamic_left_chunk = use_dynamic_left_chunk
+        self.decoding_chunk_size = decoding_chunk_size
+        self.static_chunk_size = static_chunk_size
+        self.num_decoding_left_chunks = num_decoding_left_chunks
+        self.enable_full_context = enable_full_context
+        self.max_chunk_size = max_chunk_size
 
 
 __all__ = ["TouchFlowConfig"]
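For reference, the new chunk-mask options are plain constructor arguments. A minimal usage sketch, assuming west is importable as a package and that the constructor arguments not shown in this hunk (model paths, tokenizer paths, etc.) have defaults or are supplied elsewhere; the specific values below are illustrative only:

```python
# Usage sketch (not part of the commit): enable a fixed-size chunk mask.
# Arguments not listed keep the defaults shown in the diff above.
from west.models.touch_flow.configuration_touch_flow import TouchFlowConfig

config = TouchFlowConfig(
    min_speech_duration=0.2,      # drop training clips shorter than 0.2 s
    max_speech_duration=30,       # drop training clips longer than 30 s
    use_dynamic_chunk=False,      # no random chunk sizes during training
    static_chunk_size=25,         # use a fixed 25-frame chunk instead
    num_decoding_left_chunks=-1,  # each chunk attends to all left chunks
)
```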

west/models/touch_flow/extractor_touch_flow.py

Lines changed: 6 additions & 0 deletions
@@ -21,6 +21,12 @@ def __init__(self, tokenizer, model_config, inference=False):
     def extract(self, item):
         import s3tokenizer
         waveform, sample_rate = torchaudio.load(item['wav'])
+        duration = waveform.size(1) / sample_rate
+        if not self.inference and (
+                duration < self.model_config.min_speech_duration
+                or duration > self.model_config.max_speech_duration):
+            return None
+
         audio = torchaudio.transforms.Resample(sample_rate, 16000)(waveform)
         audio_22k = torchaudio.transforms.Resample(sample_rate, 22050)(waveform)
         mel_vocoder = mel_spectrogram(audio_22k,
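The new duration check drops out-of-range utterances at training time only; returning None lets the data pipeline skip the item, while inference is unaffected. A standalone sketch of the same check, using a hypothetical helper name (within_duration) and the default thresholds from this commit:

```python
import torch


def within_duration(waveform: torch.Tensor, sample_rate: int,
                    min_s: float = 0.2, max_s: float = 30.0) -> bool:
    """True if a (channels, samples) waveform is between min_s and max_s seconds."""
    duration = waveform.size(1) / sample_rate
    return min_s <= duration <= max_s


# 1 s of audio at 16 kHz passes; 0.1 s is filtered out, as in extract().
assert within_duration(torch.zeros(1, 16000), 16000)
assert not within_duration(torch.zeros(1, 1600), 16000)
```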

west/models/touch_flow/modeling_touch_flow.py

Lines changed: 26 additions & 4 deletions
@@ -12,7 +12,8 @@
 from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
                           PreTrainedModel)
 
-from west.utils.mask import make_pad_mask, non_causal_mask
+from west.utils.mask import (add_optional_chunk_mask, make_pad_mask,
+                             mask_to_bias)
 from west.utils.utils import freeze_module
 
 from .configuration_touch_flow import TouchFlowConfig
@@ -148,7 +149,17 @@ def forward(
                            dim=-1)  # (B, T, 5*M)
         inputs = self.input_projector(inputs)  # (B, T, D)
         mask = ~make_pad_mask(mel_vocoder_lengths).to(device)  # (B, T)
-        att_mask = non_causal_mask(mel_vocoder_lengths).to(device)  # (B, T, T)
+        att_mask = add_optional_chunk_mask(
+            xs=token_cond, masks=mask.unsqueeze(1),
+            use_dynamic_chunk=self.config.use_dynamic_chunk,
+            use_dynamic_left_chunk=self.config.use_dynamic_left_chunk,
+            decoding_chunk_size=self.config.decoding_chunk_size,
+            static_chunk_size=self.config.static_chunk_size,
+            num_decoding_left_chunks=self.config.num_decoding_left_chunks,
+            enable_full_context=self.config.enable_full_context,
+            max_chunk_size=self.config.max_chunk_size)  # (B, T, T)
+        if self.llm.config._attn_implementation == "sdpa":
+            att_mask = mask_to_bias(att_mask, token_cond.dtype)
         att_mask = att_mask.unsqueeze(1).float()  # (B, 1, T, T)
         result = self.llm.model(inputs_embeds=inputs,
                                 attention_mask=att_mask,
@@ -213,7 +224,18 @@ def inference(
         x_in[0:1, :, 3 * M:4 * M] = spk_cond
         x_in[0:1, :, 4 * M:5 * M] = mel_cond
         vocoder_lengths = torch.tensor([T], dtype=torch.long, device=device)
-        att_mask = non_causal_mask(vocoder_lengths).to(device)  # (B, T, T)
+        mask = ~make_pad_mask(vocoder_lengths).to(device)  # (B, T)
+        att_mask = add_optional_chunk_mask(
+            xs=token_cond, masks=mask.unsqueeze(1),
+            use_dynamic_chunk=self.config.use_dynamic_chunk,
+            use_dynamic_left_chunk=self.config.use_dynamic_left_chunk,
+            decoding_chunk_size=self.config.decoding_chunk_size,
+            static_chunk_size=self.config.static_chunk_size,
+            num_decoding_left_chunks=self.config.num_decoding_left_chunks,
+            enable_full_context=self.config.enable_full_context,
+            max_chunk_size=self.config.max_chunk_size)  # (B, T, T)
+        if self.llm.config._attn_implementation == "sdpa":
+            att_mask = mask_to_bias(att_mask, token_cond.dtype)
         att_mask = att_mask.unsqueeze(1).float()  # (B, 1, T, T)
         for step in range(1, len(t_span)):
             x_in[:, :, 0:M] = pt
@@ -239,6 +261,6 @@ def inference(
 
     def init_tokenizer(self):
         tokenizer = AutoTokenizer.from_pretrained(
-            self.config.llm_model_name_or_path)
+            self.config.text_tokenizer_path)
         tokenizer.bos_token = "<|im_start|>"
         return tokenizer
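After this change, forward and inference build the attention mask the same way: a padding mask is combined with an optional chunk mask, and, when the backbone uses the SDPA attention implementation, the boolean mask is converted to an additive bias. A minimal sketch of that pipeline in isolation, with made-up lengths and shapes, assuming the helpers are importable from west.utils.mask:

```python
import torch

# Assumes the helpers landed by this commit are importable from west.utils.mask.
from west.utils.mask import add_optional_chunk_mask, make_pad_mask, mask_to_bias

lengths = torch.tensor([6, 4], dtype=torch.long)   # two sequences, padded to T=6
xs = torch.zeros(2, 6, 8)                          # (B, T, D) dummy features

mask = ~make_pad_mask(lengths)                     # (B, T), True on real frames
att_mask = add_optional_chunk_mask(
    xs=xs, masks=mask.unsqueeze(1),
    use_dynamic_chunk=False,
    use_dynamic_left_chunk=False,
    decoding_chunk_size=0,
    static_chunk_size=2,          # chunk-causal attention with 2-frame chunks
    num_decoding_left_chunks=-1,  # every chunk sees all chunks to its left
    enable_full_context=True,
    max_chunk_size=86)            # (B, T, T) boolean mask

# For an SDPA backbone, turn the boolean mask into an additive bias.
att_bias = mask_to_bias(att_mask, torch.float32).unsqueeze(1)  # (B, 1, T, T)
```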

west/utils/mask.py

Lines changed: 121 additions & 0 deletions
@@ -48,4 +48,125 @@ def non_causal_mask(lengths):
     return mask
 
 
+def subsequent_chunk_mask(
+        size: int,
+        chunk_size: int,
+        num_left_chunks: int = -1,
+        device: torch.device = torch.device("cpu"),
+) -> torch.Tensor:
+    """Create mask for subsequent steps (size, size) with chunk size,
+       this is for streaming encoder
+
+    Args:
+        size (int): size of mask
+        chunk_size (int): size of chunk
+        num_left_chunks (int): number of left chunks
+            <0: use full chunk
+            >=0: use num_left_chunks
+        device (torch.device): "cpu" or "cuda" or torch.Tensor.device
+
+    Returns:
+        torch.Tensor: mask
+
+    Examples:
+        >>> subsequent_chunk_mask(4, 2)
+        [[1, 1, 0, 0],
+         [1, 1, 0, 0],
+         [1, 1, 1, 1],
+         [1, 1, 1, 1]]
+    """
+    ret = torch.zeros(size, size, device=device, dtype=torch.bool)
+    for i in range(size):
+        if num_left_chunks < 0:
+            start = 0
+        else:
+            start = max((i // chunk_size - num_left_chunks) * chunk_size, 0)
+        ending = min((i // chunk_size + 1) * chunk_size, size)
+        ret[i, start:ending] = True
+    return ret
+
+
+def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
+    assert mask.dtype == torch.bool
+    assert dtype in [torch.float32, torch.bfloat16, torch.float16]
+    mask = mask.to(dtype)
+    # attention mask bias
+    # NOTE(Mddct): torch.finfo jit issues
+    # chunk_masks = (1.0 - chunk_masks) * torch.finfo(dtype).min
+    mask = (1.0 - mask) * -1.0e+10
+    return mask
+
+
+def add_optional_chunk_mask(xs: torch.Tensor,
+                            masks: torch.Tensor,
+                            use_dynamic_chunk: bool,
+                            use_dynamic_left_chunk: bool,
+                            decoding_chunk_size: int,
+                            static_chunk_size: int,
+                            num_decoding_left_chunks: int,
+                            enable_full_context: bool = True,
+                            max_chunk_size: int = 25):
+    """ Apply optional mask for encoder.
+
+    Args:
+        xs (torch.Tensor): padded input, (B, L, D), L for max length
+        mask (torch.Tensor): mask for xs, (B, 1, L)
+        use_dynamic_chunk (bool): whether to use dynamic chunk or not
+        use_dynamic_left_chunk (bool): whether to use dynamic left chunk for
+            training.
+        decoding_chunk_size (int): decoding chunk size for dynamic chunk, it's
+            0: default for training, use random dynamic chunk.
+            <0: for decoding, use full chunk.
+            >0: for decoding, use fixed chunk size as set.
+        static_chunk_size (int): chunk size for static chunk training/decoding
+            if it's greater than 0, if use_dynamic_chunk is true,
+            this parameter will be ignored
+        num_decoding_left_chunks: number of left chunks, this is for decoding,
+            the chunk size is decoding_chunk_size.
+            >=0: use num_decoding_left_chunks
+            <0: use all left chunks
+        enable_full_context (bool):
+            True: chunk size is [1, max_chunk_size] or full context(max_len)
+            False: chunk size ~ U[1, max_chunk_size]
+
+    Returns:
+        torch.Tensor: chunk mask of the input xs.
+    """
+    # Whether to use chunk mask or not
+    if use_dynamic_chunk:
+        max_len = xs.size(1)
+        if decoding_chunk_size < 0:
+            chunk_size = max_len
+            num_left_chunks = -1
+        elif decoding_chunk_size > 0:
+            chunk_size = decoding_chunk_size
+            num_left_chunks = num_decoding_left_chunks
+        else:
+            # chunk_size maybe [1, max_chunk_size] or max_len if full context.
+            chunk_size = torch.randint(1, max_len, (1, )).item()
+            num_left_chunks = -1
+            if chunk_size > max_len // 2 and enable_full_context:
+                chunk_size = max_len
+            else:
+                chunk_size = chunk_size % max_chunk_size + 1
+                if use_dynamic_left_chunk:
+                    max_left_chunks = (max_len - 1) // chunk_size
+                    num_left_chunks = torch.randint(0, max_left_chunks,
+                                                    (1, )).item()
+        chunk_masks = subsequent_chunk_mask(xs.size(1), chunk_size,
+                                            num_left_chunks,
+                                            xs.device)  # (L, L)
+        chunk_masks = chunk_masks.unsqueeze(0)  # (1, L, L)
+        chunk_masks = masks & chunk_masks  # (B, L, L)
+    elif static_chunk_size > 0:
+        num_left_chunks = num_decoding_left_chunks
+        chunk_masks = subsequent_chunk_mask(xs.size(1), static_chunk_size,
+                                            num_left_chunks,
+                                            xs.device)  # (L, L)
+        chunk_masks = chunk_masks.unsqueeze(0)  # (1, L, L)
+        chunk_masks = masks & chunk_masks  # (B, L, L)
+    else:
+        chunk_masks = masks
+    return chunk_masks
+
 # print(non_causal_mask(torch.tensor([2, 3, 4], dtype=torch.long)))
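To make the chunk and left-chunk semantics concrete, a quick check of subsequent_chunk_mask (assuming west.utils.mask is importable): the first call reproduces the docstring example; the second keeps only one left chunk, so later frames stop attending to chunk 0.

```python
import torch

from west.utils.mask import subsequent_chunk_mask  # assumes west is on PYTHONPATH

# Full left context (num_left_chunks < 0): matches the docstring example.
print(subsequent_chunk_mask(4, 2).int())
# tensor([[1, 1, 0, 0],
#         [1, 1, 0, 0],
#         [1, 1, 1, 1],
#         [1, 1, 1, 1]])

# Only one left chunk kept: frames in the last chunk no longer see chunk 0.
print(subsequent_chunk_mask(6, 2, num_left_chunks=1).int())
# tensor([[1, 1, 0, 0, 0, 0],
#         [1, 1, 0, 0, 0, 0],
#         [1, 1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 0, 0],
#         [0, 0, 1, 1, 1, 1],
#         [0, 0, 1, 1, 1, 1]])
```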
