Skip to content

Commit 6308d54

Browse files
authored
[cli] support on-the-fly training by loading pt model as nn.Module (#428)
1 parent 1d4164b commit 6308d54

File tree

2 files changed

+28
-16
lines changed

2 files changed

+28
-16
lines changed

wespeaker/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
from wespeaker.cli.speaker import load_model # noqa
2-
from wespeaker.cli.speaker import load_model_local # noqa
2+
from wespeaker.cli.speaker import load_model_pt # noqa

wespeaker/cli/speaker.py

Lines changed: 27 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -39,15 +39,7 @@ class Speaker:
3939

4040
def __init__(self, model_dir: str):
4141
set_seed()
42-
43-
config_path = os.path.join(model_dir, 'config.yaml')
44-
model_path = os.path.join(model_dir, 'avg_model.pt')
45-
with open(config_path, 'r') as fin:
46-
configs = yaml.load(fin, Loader=yaml.FullLoader)
47-
self.model = get_speaker_model(
48-
configs['model'])(**configs['model_args'])
49-
load_checkpoint(self.model, model_path)
50-
self.model.eval()
42+
self.model = load_model_pt(model_dir)
5143
self.vad = load_silero_vad()
5244
self.table = {}
5345
self.resample_rate = 16000
@@ -293,13 +285,33 @@ def make_rttm(self, merged_segment_to_labels, outfile):
293285
float(end) - float(begin), label))
294286

295287

296-
def load_model(language: str) -> Speaker:
297-
model_path = Hub.get_model(language)
298-
return Speaker(model_path)
288+
def load_model(
    model_id: str = None,
    model_dir: str = None,
) -> "Speaker":
    """Create a `Speaker` from a local model directory or a hub model id.

    Args:
        model_id: Identifier passed to `Hub.get_model` to resolve a model
            directory. Only consulted when `model_dir` is None.
        model_dir: Model directory handed to `Speaker`. Takes precedence
            over `model_id` when both are given.

    Returns:
        A `Speaker` built from the resolved model directory.

    Raises:
        ValueError: If neither `model_id` nor `model_dir` is given.
    """
    if model_dir is None:
        if model_id is None:
            # Fail early with a clear message instead of forwarding None
            # to the hub lookup.
            raise ValueError(
                "load_model() requires either model_id or model_dir")
        model_dir = Hub.get_model(model_id)
    return Speaker(model_dir)
299295

300296

301-
def load_model_local(model_dir: str) -> Speaker:
302-
return Speaker(model_dir)
297+
# Build the model from config.yaml and load its weights from avg_model.pt.
298+
# The returned model can be used as a third-party PyTorch nn.Module, e.g. for training.
299+
def load_model_pt(model_dir: str):
    """Build the speaker model described in `model_dir` and load its weights.

    The following files must exist in `model_dir`:
    - config.yaml: the model config file
    - avg_model.pt: the pytorch model file

    Args:
        model_dir: Directory containing the two files listed above.

    Returns:
        The model built by `get_speaker_model` from the config's `model`
        and `model_args` entries, with the checkpoint loaded. `eval()` has
        been called on it, so callers that want to train it presumably
        need to switch it back with `.train()` -- TODO confirm.

    Raises:
        FileNotFoundError: If either required file is missing.
    """
    required_files = ['config.yaml', 'avg_model.pt']
    # Build each path once and reuse it for both the check and the load.
    paths = {name: os.path.join(model_dir, name) for name in required_files}
    for name, path in paths.items():
        if not os.path.exists(path):
            raise FileNotFoundError(f"{name} not found in {model_dir}")
    # Read config file. Decode explicitly as UTF-8 so the result does not
    # depend on the platform's locale encoding.
    # NOTE(review): yaml.FullLoader is not meant for untrusted input;
    # consider yaml.safe_load if model_dir can come from an untrusted source.
    with open(paths['config.yaml'], 'r', encoding='utf-8') as f:
        config = yaml.load(f, Loader=yaml.FullLoader)
    model = get_speaker_model(config['model'])(**config['model_args'])
    load_checkpoint(model, paths['avg_model.pt'])
    model.eval()
    return model
303315

304316

305317
def main():
@@ -318,7 +330,7 @@ def main():
318330
else:
319331
model = load_model(args.language)
320332
else:
321-
model = load_model_local(args.pretrain)
333+
model = load_model(model_dir=args.pretrain)
322334
model.set_resample_rate(args.resample_rate)
323335
model.set_vad(args.vad)
324336
model.set_device(args.device)

0 commit comments

Comments
 (0)