Skip to content

Commit 7a63d60

Browse files
committed
[touchtts] support touchtts training from config
1 parent 988b8da commit 7a63d60

File tree

3 files changed

+10
-8
lines changed

3 files changed

+10
-8
lines changed

examples/libritts/tts/conf/touch_tts_config.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
{
22
"llm_model_name_or_path": "Qwen/Qwen2.5-0.5B-Audio",
3+
"llm_model_tokenizer_dir": "Qwen/Qwen2.5-0.5B-Audio",
34
"model_type": "touch_tts",
45
"transformers_version": "4.52.3",
56
"num_speech_tokens": 4096,

west/models/touch_tts/configuration_touch_tts.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ class TouchTTSConfig(PretrainedConfig):
99
def __init__(
1010
self,
1111
llm_model_name_or_path: str = 'Qwen/Qwen2-7B',
12+
llm_model_tokenizer_dir: str = 'Qwen/Qwen2-7B',
1213
s3tokenizer_model_name_or_path: str = '',
1314
num_speech_tokens: int = 4096,
1415
hidden_size: int = 0,
@@ -23,6 +24,7 @@ def __init__(
2324
self.hidden_size = hidden_size
2425
self.max_speech_duration = max_speech_duration
2526
self.min_speech_duration = min_speech_duration
27+
self.llm_model_tokenizer_dir = llm_model_tokenizer_dir
2628

2729

2830
__all__ = ["TouchTTSConfig"]

west/models/touch_tts/modeling_touch_tts.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,7 @@ class TouchTTS(PreTrainedModel, GenerationMixin):
2020
def __init__(self, config: TouchTTSConfig):
2121
super().__init__(config)
2222
llm_config = AutoConfig.from_pretrained(config.llm_model_name_or_path)
23-
self.llm = AutoModelForCausalLM.from_pretrained(
24-
config.llm_model_name_or_path,
25-
config=llm_config,
26-
torch_dtype='auto',
27-
attn_implementation="flash_attention_2", # or "flex_attention"
28-
)
23+
self.llm = AutoModelForCausalLM.from_config(config=llm_config)
2924
config.hidden_size = llm_config.hidden_size # for DeepSpeed training
3025
self.speech_tokenizer = s3tokenizer.load_model(
3126
'speech_tokenizer_v1_25hz', config.s3tokenizer_model_name_or_path)
@@ -126,6 +121,10 @@ def generate(
126121

127122
def init_tokenizer(self):
128123
tokenizer = AutoTokenizer.from_pretrained(
129-
self.config.llm_model_name_or_path)
130-
tokenizer.bos_token = "<|im_start|>"
124+
self.config.llm_model_tokenizer_dir, trust_remote_code=True)
125+
if 'Qwen' in self.config.llm_model_tokenizer_dir:
126+
tokenizer.bos_token = tokenizer.eos_token
127+
# Set pad_token if not already set
128+
if tokenizer.pad_token is None:
129+
tokenizer.pad_token = tokenizer.eos_token
131130
return tokenizer

0 commit comments

Comments
 (0)