Commit 5caa735

[example] add qa recipe
1 parent 7f1a7aa commit 5caa735

File tree

8 files changed: +174 -11 lines changed
examples/belle_1.4M_qa/conf/accelerator_config.json

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
{
    "dispatch_batches": false
}
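
run.sh passes this file to the trainer via --accelerator_config. A minimal sketch of the setting it controls, assuming the usual HuggingFace/accelerate plumbing behind that flag (the DataLoaderConfiguration class below comes from accelerate, not from this repo):

# Sketch only: dispatch_batches=False means every rank iterates its own
# dataloader, instead of rank 0 fetching batches and broadcasting the slices.
# That is the usual choice for sharded/iterable datasets such as the one in
# west/dataset/dataset.py.
from accelerate.utils import DataLoaderConfiguration

dl_config = DataLoaderConfiguration(dispatch_batches=False)
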
examples/belle_1.4M_qa/conf/ds_config_zero2.json

Lines changed: 54 additions & 0 deletions
@@ -0,0 +1,54 @@
{
    "fp16": {
        "enabled": "auto",
        "loss_scale": 0,
        "loss_scale_window": 1000,
        "initial_scale_power": 16,
        "hysteresis": 2,
        "min_loss_scale": 1
    },
    "bf16": {
        "enabled": "auto"
    },
    "optimizer": {
        "type": "AdamW",
        "params": {
            "lr": "auto",
            "betas": "auto",
            "eps": "auto",
            "weight_decay": "auto"
        }
    },

    "scheduler": {
        "type": "WarmupLR",
        "params": {
            "warmup_min_lr": "auto",
            "warmup_max_lr": "auto",
            "warmup_num_steps": "auto"
        }
    },

    "zero_optimization": {
        "stage": 2,
        "offload_optimizer": {
            "device": "none",
            "pin_memory": true
        },
        "offload_param": {
            "device": "none",
            "pin_memory": true
        },
        "overlap_comm": true,
        "contiguous_gradients": true,
        "sub_group_size": 1e9,
        "reduce_bucket_size": "auto"
    },

    "gradient_accumulation_steps": "auto",
    "gradient_clipping": "auto",
    "steps_per_print": 100,
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "wall_clock_breakdown": false
}
examples/belle_1.4M_qa/conf/touch_asu_config.json

Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
{
    "encoder_ds_rate": 4,
    "encoder_projector_ds_rate": 2,
    "llm_model_name_or_path": "/bucket/output/jfs-hdfs/user/binbin.zhang/huggingface/hub/Qwen2-1.5B-Instruct",
    "lora_config": null,
    "model_type": "touch_asu",
    "projector_hidden_size": 2048,
    "transformers_version": "4.52.3",
    "wenet_model_name_or_path": "/bucket/output/jfs-hdfs/user/binbin.zhang/models/wenet/wenetspeech/u2pp_conformer/"
}

examples/belle_1.4M_qa/run.sh

Lines changed: 71 additions & 0 deletions
@@ -0,0 +1,71 @@
# Copyright 2025 Binbin Zhang([email protected])

[ ! -s west ] && ln -s ../../west
[ ! -s tools ] && ln -s ../../tools
export PYTHONPATH=$PYTHONPATH:$PWD

export CUDA_VISIBLE_DEVICES="1"  # Change this to all your available GPUs, such as "0,1,2,3"
num_gpus=$(echo $CUDA_VISIBLE_DEVICES | awk -F ',' '{print NF}')

model_config_or_dir=pretrain_qwen1.7b_aishell_asr

stage=decode  # data/train/decode
data=data

steps=10000  # training steps
pack_size=8192
lr_rate=5e-5
dir=exp/Qwen3-1.7B-Instruct-firered-${pack_size}-${lr_rate}

# Note: Change your model settings in `conf/touch_asu_config.json`


if [ $stage == "data" ] || [ $stage == "all" ]; then
  echo "Prepare required data"
  # TODO:
  mkdir -p $data
  cp -r /jfs-hdfs/user/Archive/AQA/qa_test/chinese_qa.jsonl $data
fi


if [ $stage == "train" ] || [ $stage == "all" ]; then
  torchrun --standalone --nnodes=1 --nproc_per_node=$num_gpus west/bin/train.py \
    --model_config_or_dir $model_config_or_dir \
    --data_path $data/data_tn_cn_messages_aishell.list \
    --output_dir $dir \
    --pack_size $pack_size \
    --bf16 True \
    --max_steps $steps \
    --num_data_cycles 100 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 1 \
    --save_strategy "steps" \
    --save_steps 100 \
    --save_total_limit 10000 \
    --learning_rate $lr_rate \
    --weight_decay 0.01 \
    --adam_beta2 0.95 \
    --warmup_ratio 0.5 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --report_to "tensorboard" \
    --gradient_checkpointing \
    --dataloader_num_workers 2 \
    --dataloader_prefetch_factor 10 \
    --deepspeed conf/ds_config_zero2.json \
    --accelerator_config conf/accelerator_config.json
fi


if [ $stage == "decode" ] || [ $stage == "all" ]; then
  mdir=$dir/checkpoint-${steps}
  python west/bin/decode.py \
    --data_path $data/chinese_qa.jsonl \
    --model_dir $mdir \
    --result_path $mdir/result.txt
  python tools/get_qa_hyp_ref_text.py $data/chinese_qa_messages.jsonl \
    $mdir/result.txt $mdir/result.json
  python tools/compute-acc-of-contain.py $mdir/result.json
fi
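
The final step scores the decoded answers with tools/compute-acc-of-contain.py. That script is not part of this diff; the sketch below only illustrates the metric its name suggests, under the assumption that get_qa_hyp_ref_text.py writes a JSON list of items with ref and hyp fields: an answer counts as correct when the reference string is contained in the hypothesis.

import json

def contain_accuracy(result_json_path):
    # Hypothetical format: [{"ref": "北京", "hyp": "中国的首都是北京。"}, ...]
    with open(result_json_path, encoding='utf8') as f:
        items = json.load(f)
    hits = sum(1 for it in items if it['ref'] in it['hyp'])
    return hits / max(len(items), 1)
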

examples/belle_1.4M_qa/tools

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
../../tools

examples/belle_1.4M_qa/west

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
../../west

west/dataset/dataset.py

Lines changed: 6 additions & 3 deletions
@@ -131,6 +131,9 @@ def _read_one(self):
             try:
                 x['txt'] = x['txt'].decode('utf8')
                 x['wav'] = io.BytesIO(x['wav'])
+                if "messages" in x.keys():
+                    x['messages'] = json.loads(
+                        x['messages'].decode('utf8'))
                 yield x
             except Exception:
                 logging.info(f'Dataset decode error, {line}')

@@ -244,9 +247,9 @@ def __iter__(self) -> Dict[str, torch.Tensor]:
     print(tokenizer.bos_token_id)
     data_args = DataArguments
     data_args.data_path = 'data/train.jsonl'
-    data_args.extractor_type = 'tts_codec'
-    dataset = SpeechDataset(tokenizer, data_args)
+    data_args.extractor_type = 'touch_asu'
+    extractor = Extractor.get_class(data_args.extractor_type)(tokenizer)
+    dataset = SpeechDataset(extractor, data_args)
     for i, x in enumerate(dataset):
-        print(x)
         if i > 0:
             break
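
For reference, a self-contained illustration of what the new branch handles (the field values are made up; only the decode/json.loads behaviour mirrors the diff): a shard record carrying an optional messages field as UTF-8 encoded JSON comes out of _read_one as a plain Python list of chat turns.

import io
import json

# Hypothetical raw record as it would arrive from a shard/tar entry.
raw = {
    'txt': '北京'.encode('utf8'),
    'wav': b'<wav bytes>',
    'messages': json.dumps([
        {'role': 'user', 'content': 'What is the capital of China?'},
        {'role': 'assistant', 'content': 'The capital of China is Beijing.'},
    ]).encode('utf8'),
}

# Same transformations as the patched _read_one().
raw['txt'] = raw['txt'].decode('utf8')
raw['wav'] = io.BytesIO(raw['wav'])
if 'messages' in raw.keys():
    raw['messages'] = json.loads(raw['messages'].decode('utf8'))
print(raw['messages'][0]['role'])  # -> user
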

west/models/touch_asu/extractor_touch_asu.py

Lines changed: 28 additions & 8 deletions
@@ -14,14 +14,31 @@ class ExtractorTouchASU(Extractor):
     fields_pack_offset = {'audio_offsets'}
 
     def extract(self, item):
+        """
+        1. speech pretraining data (asr):
+            messages = [
+                {'role': 'user', 'content': [{
+                    'type': 'text', 'text': 'Transcribe the Speech'}]},
+                {'role': 'assistant', 'content': item['txt']},
+            ]
+        2. QA: SFT data (multi-turn)
+            messages = [
+                {'role': 'system', 'content': 'You are a helpful assistant.'},  # optional # noqa
+                {'role': 'user', 'content': 'What is the capital of China?'},  # optional # noqa
+                {'role': 'assistant', 'content': 'The capital of China is Beijing.'},  # optional # noqa
+                {'role': 'user', 'content': {'type': 'audio', 'audio': item['wav']}},  # last turn # noqa
+                {'role': 'assistant', 'content': item['txt']},
+            ]
+        """
         IGNORE_TOKEN_ID = LabelSmoother.ignore_index
-        if 'messages' in item:  # OpenAI role-content based SFT data
+        # OpenAI role-content based SFT data
+        # At least one pair of "user" and "assistant"
+        if 'messages' in item and len(item["messages"]) >= 2:
             messages = item['messages']
         else:  # Speech pretraining data
             messages = [
                 {
-                    'role':
-                    'user',
+                    'role': 'user',
                     'content': [{
                         'type': 'text',
                         'text': 'Transcribe the Speech'

@@ -36,13 +53,16 @@ def extract(self, item):
                 },
             ]
 
-        t0 = '<|im_start|>user\n'
+        t0 = ''
         t1 = '<|audio_eos|><|im_end|>\n' + '<|im_start|>assistant\n'
         t2 = ''
-        for msg in messages:
-            if msg['role'] == 'system':
-                t0 += msg['content']
-            elif msg['role'] == 'user':
+        # multi-turn
+        for msg in messages[:-2]:
+            t0 += '<|im_start|>' + msg['role'] + '\n' + \
+                msg['content'] + '<|im_end|>\n'
+        for msg in messages[-2:]:
+            if msg['role'] == 'user':
+                t0 += '<|im_start|>user\n'
                 if isinstance(msg['content'], dict):
                     assert msg['content']['type'] == 'audio'
                     t0 += '<|audio_bos|>'