@@ -39,15 +39,7 @@ class Speaker:
3939
4040 def __init__ (self , model_dir : str ):
4141 set_seed ()
42-
43- config_path = os .path .join (model_dir , 'config.yaml' )
44- model_path = os .path .join (model_dir , 'avg_model.pt' )
45- with open (config_path , 'r' ) as fin :
46- configs = yaml .load (fin , Loader = yaml .FullLoader )
47- self .model = get_speaker_model (
48- configs ['model' ])(** configs ['model_args' ])
49- load_checkpoint (self .model , model_path )
50- self .model .eval ()
42+ self .model = load_model_pt (model_dir )
5143 self .vad = load_silero_vad ()
5244 self .table = {}
5345 self .resample_rate = 16000
@@ -293,13 +285,33 @@ def make_rttm(self, merged_segment_to_labels, outfile):
293285 float (end ) - float (begin ), label ))
294286
295287
296- def load_model (language : str ) -> Speaker :
297- model_path = Hub .get_model (language )
298- return Speaker (model_path )
288+ def load_model (
289+ model_id : str = None ,
290+ model_dir : str = None ,
291+ ) -> Speaker :
292+ if model_dir is None :
293+ model_dir = Hub .get_model (model_id )
294+ return Speaker (model_dir )
299295
300296
301- def load_model_local (model_dir : str ) -> Speaker :
302- return Speaker (model_dir )
297+ # Load the pytorch pt model which contains all the details.
298+ # And we can use the pt model as a third party pytorch nn.Module for training
299+ def load_model_pt (model_dir : str ):
300+ """There are the following files in the `model_dir`:
301+ - config.yaml: the model config file
302+ - avg_model.pt: the pytorch model file
303+ """
304+ required_files = ['config.yaml' , 'avg_model.pt' ]
305+ for file in required_files :
306+ if not os .path .exists (os .path .join (model_dir , file )):
307+ raise FileNotFoundError (f"{ file } not found in { model_dir } " )
308+ # Read config file
309+ with open (os .path .join (model_dir , 'config.yaml' ), 'r' ) as f :
310+ config = yaml .load (f , Loader = yaml .FullLoader )
311+ model = get_speaker_model (config ['model' ])(** config ['model_args' ])
312+ load_checkpoint (model , os .path .join (model_dir , 'avg_model.pt' ))
313+ model .eval ()
314+ return model
303315
304316
305317def main ():
@@ -318,7 +330,7 @@ def main():
318330 else :
319331 model = load_model (args .language )
320332 else :
321- model = load_model_local ( args .pretrain )
333+ model = load_model ( model_dir = args .pretrain )
322334 model .set_resample_rate (args .resample_rate )
323335 model .set_vad (args .vad )
324336 model .set_device (args .device )
0 commit comments