Commit 7f1a7aa

[train] add training init from model dir, so we can do further tuning (#36)
1 parent: c37f62f

2 files changed: 8 additions & 4 deletions


examples/aishell/asr/run.sh

Lines changed: 1 addition & 1 deletion
@@ -25,7 +25,7 @@ fi
 
 if [ $stage == "train" ] || [ $stage == "all" ]; then
   torchrun --standalone --nnodes=1 --nproc_per_node=$num_gpus west/bin/train.py \
-    --model_config_path conf/touch_asu_config.json \
+    --model_config_or_dir conf/touch_asu_config.json \
     --data_path $data/train.jsonl \
     --output_dir $dir \
     --pack_size 8192 \
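
Because --model_config_or_dir now also accepts a model directory, a follow-up tuning pass can start from the weights of an earlier run instead of a fresh initialization from the JSON config. A minimal sketch of such a run using the same flags as run.sh; the model directory and stage-2 data paths are hypothetical, and the directory is assumed to hold a model saved in Hugging Face from_pretrained layout:

# Hypothetical further-tuning run: --model_config_or_dir points at a directory
# containing an already-trained model, so train.py loads its weights instead of
# initializing a new model from a JSON config file.
torchrun --standalone --nnodes=1 --nproc_per_node=$num_gpus west/bin/train.py \
  --model_config_or_dir exp/touch_asu/stage1 \
  --data_path $data/train_stage2.jsonl \
  --output_dir exp/touch_asu/stage2 \
  --pack_size 8192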

west/bin/train.py

Lines changed: 7 additions & 3 deletions
@@ -3,6 +3,7 @@
 # https://github.com/QwenLM/Qwen2/blob/main/examples/sft/finetune.py
 
 import logging
+import os
 import pathlib
 from dataclasses import dataclass, field
 from typing import Any, Union
@@ -19,7 +20,7 @@
 @dataclass
 class TrainingArguments(TrainingArguments):
     optim: str = field(default="adafactor")
-    model_config_path: str = field(default='')
+    model_config_or_dir: str = field(default='')
 
 
 class MyTrainer(Trainer):
@@ -105,8 +106,11 @@ def main():
     )
     parser = HfArgumentParser((DataArguments, TrainingArguments))
     data_args, training_args = parser.parse_args_into_dataclasses()
-    config = AutoConfig.from_pretrained(training_args.model_config_path)
-    model = AutoModel.from_config(config)
+    if os.path.isfile(training_args.model_config_or_dir):  # init from config
+        config = AutoConfig.from_pretrained(training_args.model_config_or_dir)
+        model = AutoModel.from_config(config)
+    else:  # load from pretrained
+        model = AutoModel.from_pretrained(training_args.model_config_or_dir)
     tokenizer = model.init_tokenizer()
     extractor = Extractor.get_class(model.model_type)(tokenizer)
