
Commit cb413cc

Authored by github-actions[bot] and github-actions

[format] applied code formatting on changed files in pull request 3300 (#3302)
Co-authored-by: github-actions <[email protected]>
1 parent 31c78f2 commit cb413cc

File tree

1 file changed (+36 -29 lines changed)


applications/Chat/examples/train_prompts.py

Lines changed: 36 additions & 29 deletions
@@ -1,19 +1,20 @@
 import argparse
+
 import pandas as pd
 import torch
 import torch.distributed as dist
-from coati.models.bloom import BLOOMActor, BLOOMRM, BLOOMCritic
-from coati.models.gpt import GPTActor, GPTRM, GPTCritic
-from coati.models.opt import OPTActor, OPTRM, OPTCritic
-from coati.models.llama import LlamaActor, LlamaRM, LlamaCritic
+from coati.dataset import DataCollatorForSupervisedDataset, PromptDataset, SupervisedDataset
+from coati.models.bloom import BLOOMRM, BLOOMActor, BLOOMCritic
+from coati.models.gpt import GPTRM, GPTActor, GPTCritic
+from coati.models.llama import LlamaActor, LlamaCritic, LlamaRM
+from coati.models.opt import OPTRM, OPTActor, OPTCritic
 from coati.trainer import PPOTrainer
 from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
+from coati.utils import prepare_llama_tokenizer_and_embedding
 from torch.optim import Adam
 from torch.utils.data import DataLoader
 from torch.utils.data.distributed import DistributedSampler
-from transformers import AutoTokenizer, BloomTokenizerFast, LlamaTokenizer, GPT2Tokenizer
-from coati.dataset import SupervisedDataset, DataCollatorForSupervisedDataset, PromptDataset
-from coati.utils import prepare_llama_tokenizer_and_embedding
+from transformers import AutoTokenizer, BloomTokenizerFast, GPT2Tokenizer, LlamaTokenizer
 
 from colossalai.nn.optimizer import HybridAdam
 
@@ -45,12 +46,12 @@ def main(args):
         initial_model = LlamaActor(pretrained=args.pretrain)
     else:
         raise ValueError(f'Unsupported actor model "{args.model}"')
-
+
     if args.rm_model == None:
         rm_model_name = args.model
     else:
         rm_model_name = args.rm_model
-
+
     if rm_model_name == 'gpt2':
         reward_model = GPTRM(pretrained=args.rm_pretrain)
     elif rm_model_name == 'bloom':
@@ -61,15 +62,14 @@ def main(args):
         reward_model = LlamaRM(pretrained=args.rm_pretrain)
     else:
         raise ValueError(f'Unsupported reward model "{rm_model_name}"')
-
-
+
     if args.rm_path is not None:
         reward_model.load_state_dict(state_dict)
-
+
     if args.strategy != 'colossalai_gemini':
-        initial_model.to(torch.float16).to(torch.cuda.current_device())
-        reward_model.to(torch.float16).to(torch.cuda.current_device())
-
+        initial_model.to(torch.float16).to(torch.cuda.current_device())
+        reward_model.to(torch.float16).to(torch.cuda.current_device())
+
     with strategy.model_init_context():
         if args.model == 'gpt2':
             actor = GPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank)
@@ -81,7 +81,7 @@ def main(args):
             actor = LlamaActor(pretrained=args.pretrain, lora_rank=args.lora_rank)
         else:
             raise ValueError(f'Unsupported actor model "{args.model}"')
-
+
         if rm_model_name == 'gpt2':
             critic = GPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
         elif rm_model_name == 'bloom':
@@ -92,11 +92,11 @@ def main(args):
             critic = LlamaCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
         else:
             raise ValueError(f'Unsupported reward model "{rm_model_name}"')
-
+
         if args.rm_path is not None:
             critic.load_state_dict(state_dict)
             del state_dict
-
+
     if args.strategy != 'colossalai_gemini':
         critic.to(torch.float16).to(torch.cuda.current_device())
         actor.to(torch.float16).to(torch.cuda.current_device())
@@ -121,32 +121,38 @@ def main(args):
         tokenizer.eos_token = '<\s>'
     else:
         raise ValueError(f'Unsupported model "{args.model}"')
-
+
     if args.model == 'llama':
         tokenizer = prepare_llama_tokenizer_and_embedding(tokenizer, actor)
     else:
         tokenizer.pad_token = tokenizer.eos_token
-
+
     data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
-
+
     prompt_dataset = PromptDataset(tokenizer=tokenizer, data_path=args.prompt_path, max_datasets_size=16384)
     if dist.is_initialized() and dist.get_world_size() > 1:
         prompt_sampler = DistributedSampler(prompt_dataset, shuffle=True, seed=42, drop_last=True)
-    prompt_dataloader = DataLoader(prompt_dataset, shuffle=(prompt_sampler is None), sampler=prompt_sampler, batch_size=args.train_batch_size)
-
+    prompt_dataloader = DataLoader(prompt_dataset,
+                                   shuffle=(prompt_sampler is None),
+                                   sampler=prompt_sampler,
+                                   batch_size=args.train_batch_size)
+
     pretrain_dataset = SupervisedDataset(tokenizer=tokenizer, data_path=args.pretrain_dataset, max_datasets_size=16384)
     if dist.is_initialized() and dist.get_world_size() > 1:
         pretrain_sampler = DistributedSampler(pretrain_dataset, shuffle=True, seed=42, drop_last=True)
-    pretrain_dataloader = DataLoader(pretrain_dataset, shuffle=(pretrain_sampler is None), sampler=pretrain_sampler, batch_size=args.ptx_batch_size, collate_fn=data_collator)
-
+    pretrain_dataloader = DataLoader(pretrain_dataset,
+                                     shuffle=(pretrain_sampler is None),
+                                     sampler=pretrain_sampler,
+                                     batch_size=args.ptx_batch_size,
+                                     collate_fn=data_collator)
+
     def tokenize_fn(texts):
         # MUST padding to max length to ensure inputs of all ranks have the same length
         # Different length may lead to hang when using gemini, as different generation steps
         batch = tokenizer(texts, return_tensors='pt', max_length=96, padding='max_length', truncation=True)
         return {k: v.to(torch.cuda.current_device()) for k, v in batch.items()}
-
-    (actor, actor_optim), (critic, critic_optim) = strategy.prepare(
-        (actor, actor_optim), (critic, critic_optim))
+
+    (actor, actor_optim), (critic, critic_optim) = strategy.prepare((actor, actor_optim), (critic, critic_optim))
 
     # configure trainer
     trainer = PPOTrainer(
@@ -192,7 +198,8 @@ def tokenize_fn(texts):
     parser.add_argument('--pretrain_dataset', type=str, default=None, help='path to the pretrained dataset')
     parser.add_argument('--strategy',
                         choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
-                        default='naive', help='strategy to use')
+                        default='naive',
+                        help='strategy to use')
     parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama'])
     parser.add_argument('--pretrain', type=str, default=None)
     parser.add_argument('--rm_model', default=None, choices=['gpt2', 'bloom', 'opt', 'llama'])

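The tokenize_fn kept in this diff pads every batch to a fixed max_length of 96 so that all ranks feed identically shaped tensors into generation; as its comment notes, differing lengths can cause different generation step counts and hang the Gemini strategy. A sketch of that fixed-length tokenization is below; the gpt2 tokenizer and the final assertion are illustrative, and the script's move of the tensors to the current CUDA device is omitted to keep the example CPU-runnable.

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('gpt2')    # illustrative choice of tokenizer
tokenizer.pad_token = tokenizer.eos_token            # gpt2 has no pad token by default


def tokenize_fn(texts, max_length=96):
    # padding='max_length' (rather than padding=True) forces every sequence
    # to exactly max_length tokens, so every call returns the same shape.
    return tokenizer(texts,
                     return_tensors='pt',
                     max_length=max_length,
                     padding='max_length',
                     truncation=True)


short = tokenize_fn(['hello world'])
longer = tokenize_fn(['a much longer prompt ' * 10])
assert short['input_ids'].shape == longer['input_ids'].shape == (1, 96)
```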