@@ -20,12 +20,7 @@ class TouchTTS(PreTrainedModel, GenerationMixin):
2020 def __init__ (self , config : TouchTTSConfig ):
2121 super ().__init__ (config )
2222 llm_config = AutoConfig .from_pretrained (config .llm_model_name_or_path )
23- self .llm = AutoModelForCausalLM .from_pretrained (
24- config .llm_model_name_or_path ,
25- config = llm_config ,
26- torch_dtype = 'auto' ,
27- attn_implementation = "flash_attention_2" , # or "flex_attention"
28- )
23+ self .llm = AutoModelForCausalLM .from_config (config = llm_config )
2924 config .hidden_size = llm_config .hidden_size # for deepseed training
3025 self .speech_tokenizer = s3tokenizer .load_model (
3126 'speech_tokenizer_v1_25hz' , config .s3tokenizer_model_name_or_path )
@@ -126,6 +121,10 @@ def generate(
126121
def init_tokenizer(self):
    """Load and normalize the text tokenizer for the underlying LLM.

    Loads the tokenizer from ``config.llm_model_tokenizer_dir`` and fixes
    up special tokens so downstream code can rely on them:

    * For Qwen checkpoints (detected by ``'Qwen'`` in the tokenizer path)
      ``bos_token`` is set to ``eos_token`` — presumably because Qwen
      tokenizers define no BOS token (TODO confirm against the checkpoint).
    * ``pad_token`` falls back to ``eos_token`` when unset, so batched
      padding does not fail.

    Returns:
        The configured Hugging Face tokenizer instance.
    """
    # SECURITY NOTE(review): trust_remote_code=True executes arbitrary
    # Python shipped with the tokenizer repo. Only point
    # llm_model_tokenizer_dir at trusted, pinned checkpoints.
    tokenizer = AutoTokenizer.from_pretrained(
        self.config.llm_model_tokenizer_dir, trust_remote_code=True)
    if 'Qwen' in self.config.llm_model_tokenizer_dir:
        tokenizer.bos_token = tokenizer.eos_token
    # Set pad_token if not already set; applies to every tokenizer, not
    # just Qwen, so padded batching always has a valid pad id.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    return tokenizer
0 commit comments