import argparse
+
import pandas as pd
import torch
import torch.distributed as dist
-from coati.models.bloom import BLOOMActor, BLOOMRM, BLOOMCritic
-from coati.models.gpt import GPTActor, GPTRM, GPTCritic
-from coati.models.opt import OPTActor, OPTRM, OPTCritic
-from coati.models.llama import LlamaActor, LlamaRM, LlamaCritic
+from coati.dataset import DataCollatorForSupervisedDataset, PromptDataset, SupervisedDataset
+from coati.models.bloom import BLOOMRM, BLOOMActor, BLOOMCritic
+from coati.models.gpt import GPTRM, GPTActor, GPTCritic
+from coati.models.llama import LlamaActor, LlamaCritic, LlamaRM
+from coati.models.opt import OPTRM, OPTActor, OPTCritic
from coati.trainer import PPOTrainer
from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
+from coati.utils import prepare_llama_tokenizer_and_embedding
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
-from transformers import AutoTokenizer, BloomTokenizerFast, LlamaTokenizer, GPT2Tokenizer
-from coati.dataset import SupervisedDataset, DataCollatorForSupervisedDataset, PromptDataset
-from coati.utils import prepare_llama_tokenizer_and_embedding
+from transformers import AutoTokenizer, BloomTokenizerFast, GPT2Tokenizer, LlamaTokenizer

from colossalai.nn.optimizer import HybridAdam

@@ -45,12 +46,12 @@ def main(args):
        initial_model = LlamaActor(pretrained=args.pretrain)
    else:
        raise ValueError(f'Unsupported actor model "{args.model}"')
-
+
    if args.rm_model == None:
        rm_model_name = args.model
    else:
        rm_model_name = args.rm_model
-
+
    if rm_model_name == 'gpt2':
        reward_model = GPTRM(pretrained=args.rm_pretrain)
    elif rm_model_name == 'bloom':
@@ -61,15 +62,14 @@ def main(args):
        reward_model = LlamaRM(pretrained=args.rm_pretrain)
    else:
        raise ValueError(f'Unsupported reward model "{rm_model_name}"')
-
-
+
    if args.rm_path is not None:
        reward_model.load_state_dict(state_dict)
-
+
    if args.strategy != 'colossalai_gemini':
-        initial_model.to(torch.float16).to(torch.cuda.current_device())
-        reward_model.to(torch.float16).to(torch.cuda.current_device())
-
+        initial_model.to(torch.float16).to(torch.cuda.current_device())
+        reward_model.to(torch.float16).to(torch.cuda.current_device())
+
    with strategy.model_init_context():
        if args.model == 'gpt2':
            actor = GPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank)
@@ -81,7 +81,7 @@ def main(args):
            actor = LlamaActor(pretrained=args.pretrain, lora_rank=args.lora_rank)
        else:
            raise ValueError(f'Unsupported actor model "{args.model}"')
-
+
        if rm_model_name == 'gpt2':
            critic = GPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
        elif rm_model_name == 'bloom':
@@ -92,11 +92,11 @@ def main(args):
            critic = LlamaCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
        else:
            raise ValueError(f'Unsupported reward model "{rm_model_name}"')
-
+
        if args.rm_path is not None:
            critic.load_state_dict(state_dict)
            del state_dict
-
+
    if args.strategy != 'colossalai_gemini':
        critic.to(torch.float16).to(torch.cuda.current_device())
        actor.to(torch.float16).to(torch.cuda.current_device())
@@ -121,32 +121,38 @@ def main(args):
        tokenizer.eos_token = '<\s>'
    else:
        raise ValueError(f'Unsupported model "{args.model}"')
-
+
    if args.model == 'llama':
        tokenizer = prepare_llama_tokenizer_and_embedding(tokenizer, actor)
    else:
        tokenizer.pad_token = tokenizer.eos_token
-
+
    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
-
+
    prompt_dataset = PromptDataset(tokenizer=tokenizer, data_path=args.prompt_path, max_datasets_size=16384)
    if dist.is_initialized() and dist.get_world_size() > 1:
        prompt_sampler = DistributedSampler(prompt_dataset, shuffle=True, seed=42, drop_last=True)
-    prompt_dataloader = DataLoader(prompt_dataset, shuffle=(prompt_sampler is None), sampler=prompt_sampler, batch_size=args.train_batch_size)
-
+    prompt_dataloader = DataLoader(prompt_dataset,
+                                   shuffle=(prompt_sampler is None),
+                                   sampler=prompt_sampler,
+                                   batch_size=args.train_batch_size)
+
    pretrain_dataset = SupervisedDataset(tokenizer=tokenizer, data_path=args.pretrain_dataset, max_datasets_size=16384)
    if dist.is_initialized() and dist.get_world_size() > 1:
        pretrain_sampler = DistributedSampler(pretrain_dataset, shuffle=True, seed=42, drop_last=True)
-    pretrain_dataloader = DataLoader(pretrain_dataset, shuffle=(pretrain_sampler is None), sampler=pretrain_sampler, batch_size=args.ptx_batch_size, collate_fn=data_collator)
-
+    pretrain_dataloader = DataLoader(pretrain_dataset,
+                                     shuffle=(pretrain_sampler is None),
+                                     sampler=pretrain_sampler,
+                                     batch_size=args.ptx_batch_size,
+                                     collate_fn=data_collator)
+
    def tokenize_fn(texts):
        # MUST padding to max length to ensure inputs of all ranks have the same length
        # Different length may lead to hang when using gemini, as different generation steps
        batch = tokenizer(texts, return_tensors='pt', max_length=96, padding='max_length', truncation=True)
        return {k: v.to(torch.cuda.current_device()) for k, v in batch.items()}
-
-    (actor, actor_optim), (critic, critic_optim) = strategy.prepare(
-        (actor, actor_optim), (critic, critic_optim))
+
+    (actor, actor_optim), (critic, critic_optim) = strategy.prepare((actor, actor_optim), (critic, critic_optim))

    # configure trainer
    trainer = PPOTrainer(
@@ -192,7 +198,8 @@ def tokenize_fn(texts):
    parser.add_argument('--pretrain_dataset', type=str, default=None, help='path to the pretrained dataset')
    parser.add_argument('--strategy',
                        choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
-                        default='naive', help='strategy to use')
+                        default='naive',
+                        help='strategy to use')
    parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama'])
    parser.add_argument('--pretrain', type=str, default=None)
    parser.add_argument('--rm_model', default=None, choices=['gpt2', 'bloom', 'opt', 'llama'])
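
Note on the tokenize_fn hunk above: the in-code comment explains that prompts must be padded to one fixed length so every rank feeds identically shaped tensors into generation; otherwise the gemini strategy can hang when ranks take different numbers of generation steps. A minimal, self-contained sketch of that padding behaviour, assuming only the Hugging Face tokenizer API (the 'gpt2' checkpoint and the helper name tokenize_fixed_length are illustrative, not taken from the diff):

from transformers import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token    # GPT-2 ships without a pad token

def tokenize_fixed_length(texts, max_length=96):
    # padding='max_length' (rather than padding=True) pads every batch to the same
    # length, so tensor shapes match across ranks regardless of prompt length
    return tokenizer(texts, return_tensors='pt', max_length=max_length,
                     padding='max_length', truncation=True)

batch = tokenize_fixed_length(['a short prompt', 'a noticeably longer prompt than the first one'])
print(batch['input_ids'].shape)    # torch.Size([2, 96]) for either input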