|
| 1 | +# Copyright 2025 Binbin Zhang([email protected]) |
| 2 | + |
| 3 | +[ ! -s west ] && ln -s ../../../west |
| 4 | +[ ! -s tools ] && ln -s ../../../tools |
| 5 | +export PYTHONPATH=$PYTHONPATH:$PWD |
| 6 | + |
| 7 | +export CUDA_VISIBLE_DEVICES="1" # Change this to all your available gpus, such as "0,1,2,3" |
| 8 | +num_gpus=$(echo $CUDA_VISIBLE_DEVICES | awk -F ',' '{print NF}') |
| 9 | + |
| 10 | +model_config_or_dir=pretrain_qwen1.7b_aishell_asr |
| 11 | + |
| 12 | +stage=decode # data/train/decode |
| 13 | +data=data |
| 14 | + |
| 15 | +steps=10000 # training steps |
| 16 | +pack_size=8192 |
| 17 | +lr_rate=5e-5 |
| 18 | +dir=exp/Qwe3-1.7B-Instruct-firered-${pack_size}-${lr_rate} |
| 19 | + |
| 20 | +# Note: Change your model settings in `conf/touch_asu_config.json` |
| 21 | + |
| 22 | + |
| 23 | +if [ $stage == "data" ] || [ $stage == "all" ]; then |
| 24 | + echo "Prepare required data" |
| 25 | + # TODO: |
| 26 | + mkdir $data |
| 27 | + cp -r /jfs-hdfs/user/Archive/AQA/qa_test/chinese_qa.jsonl $data |
| 28 | +fi |
| 29 | + |
| 30 | + |
| 31 | +if [ $stage == "train" ] || [ $stage == "all" ]; then |
| 32 | + torchrun --standalone --nnodes=1 --nproc_per_node=$num_gpus west/bin/train.py \ |
| 33 | + --model_config_or_dir $model_config_or_dir \ |
| 34 | + --data_path $data/data_tn_cn_messages_aishell.list \ |
| 35 | + --output_dir $dir \ |
| 36 | + --pack_size $pack_size \ |
| 37 | + --bf16 True \ |
| 38 | + --max_steps $steps \ |
| 39 | + --num_data_cycles 100 \ |
| 40 | + --per_device_train_batch_size 1 \ |
| 41 | + --per_device_eval_batch_size 1 \ |
| 42 | + --gradient_accumulation_steps 1 \ |
| 43 | + --save_strategy "steps" \ |
| 44 | + --save_steps 100 \ |
| 45 | + --save_total_limit 100 \ |
| 46 | + --learning_rate $lr_rate \ |
| 47 | + --weight_decay 0.01 \ |
| 48 | + --adam_beta2 0.95 \ |
| 49 | + --warmup_ratio 0.5 \ |
| 50 | + --lr_scheduler_type "cosine" \ |
| 51 | + --logging_steps 1 \ |
| 52 | + --report_to "tensorboard" \ |
| 53 | + --gradient_checkpointing \ |
| 54 | + --dataloader_num_workers 2 \ |
| 55 | + --dataloader_prefetch_factor 10 \ |
| 56 | + --save_total_limit 10000 \ |
| 57 | + --deepspeed conf/ds_config_zero2.json \ |
| 58 | + --accelerator_config conf/accelerator_config.json |
| 59 | +fi |
| 60 | + |
| 61 | + |
| 62 | +if [ $stage == "decode" ] || [ $stage == "all" ]; then |
| 63 | + mdir=$dir/checkpoint-${steps} |
| 64 | + python west/bin/decode.py \ |
| 65 | + --data_path $data/chinese_qa.jsonl \ |
| 66 | + --model_dir $mdir \ |
| 67 | + --result_path $mdir/result.txt |
| 68 | + python tools/get_qa_hyp_ref_text.py $data/chinese_qa_messages.jsonl \ |
| 69 | + $mdir/result.txt $mdir/result.json |
| 70 | + python tools/compute-acc-of-contain.py $mdir/result.json |
| 71 | +fi |
0 commit comments