@@ -3,9 +3,8 @@
 from torch.utils.data import DataLoader, random_split
 
 # Distributed training
-import torch.multiprocessing as mp
 from torch.utils.data.distributed import DistributedSampler
-from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.nn.parallel import DistributedDataParallel
 from torch.distributed import init_process_group, destroy_process_group
 
 import warnings
@@ -26,7 +25,7 @@
 
 from model import build_transformer
 from dataset import BilingualDataset, causal_mask
-from config import get_config, get_weights_file_path, get_latest_weights_file_path
+from config import get_default_config, get_weights_file_path, get_latest_weights_file_path
 
 def greedy_decode(model, source, source_mask, tokenizer_src, tokenizer_tgt, max_len, device):
     sos_idx = tokenizer_tgt.token_to_id('[SOS]')
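Note: the change from get_config() with dict-style lookups to get_default_config() with attribute access implies the config is now a plain object whose fields live in its __dict__ (which is why config.__dict__.update(vars(args)) works at the bottom of this diff). config.py itself is not part of the diff; the sketch below is a hypothetical reconstruction — only the attribute names are confirmed (they appear in the argparse defaults and elsewhere), every default value is a guess:

    # Hypothetical sketch of config.py after this change; the exported names
    # are confirmed by the diff, the default values are illustrative only.
    class Config:
        def __init__(self):
            self.batch_size = 8
            self.num_epochs = 30
            self.lr = 1e-4
            self.seq_len = 350
            self.d_model = 512
            self.lang_src = 'en'
            self.lang_tgt = 'it'
            self.model_folder = 'weights'
            self.model_basename = 'tmodel_'
            self.preload = 'latest'
            self.tokenizer_file = 'tokenizer_{0}.json'
            self.local_rank = -1   # overwritten from LOCAL_RANK at startup
            self.global_rank = -1  # overwritten from RANK at startup

    def get_default_config():
        return Config()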
@@ -128,7 +127,7 @@ def get_all_sentences(ds, lang):
         yield item['translation'][lang]
 
 def get_or_build_tokenizer(config, ds, lang):
-    tokenizer_path = Path(config['tokenizer_file'].format(lang))
+    tokenizer_path = Path(config.tokenizer_file.format(lang))
     if not Path.exists(tokenizer_path):
         # Most code taken from: https://huggingface.co/docs/tokenizers/quicktour
         tokenizer = Tokenizer(WordLevel(unk_token="[UNK]"))
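The hunk cuts off right after the WordLevel model is constructed. For context, the Hugging Face quicktour that the comment cites continues by attaching a pre-tokenizer, training on the sentence iterator, and saving to tokenizer_path; a minimal sketch of that continuation (the trainer settings here are illustrative, not taken from the diff):

    from tokenizers import Tokenizer
    from tokenizers.models import WordLevel
    from tokenizers.trainers import WordLevelTrainer
    from tokenizers.pre_tokenizers import Whitespace

    # Sketch of the rest of get_or_build_tokenizer: train a word-level
    # tokenizer on the dataset sentences and persist it to disk.
    tokenizer = Tokenizer(WordLevel(unk_token="[UNK]"))
    tokenizer.pre_tokenizer = Whitespace()
    trainer = WordLevelTrainer(special_tokens=["[UNK]", "[PAD]", "[SOS]", "[EOS]"], min_frequency=2)
    tokenizer.train_from_iterator(get_all_sentences(ds, lang), trainer=trainer)
    tokenizer.save(str(tokenizer_path))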
@@ -142,78 +141,79 @@ def get_or_build_tokenizer(config, ds, lang):
 
 def get_ds(config):
     # It only has the train split, so we divide it ourselves
-    ds_raw = load_dataset('opus_books', f"{config['lang_src']}-{config['lang_tgt']}", split='train')
+    ds_raw = load_dataset('opus_books', f"{config.lang_src}-{config.lang_tgt}", split='train')
 
     # Build tokenizers
-    if config['local_rank'] == 0:
+    if config.local_rank == 0:
         print("Loading tokenizers...")
-    tokenizer_src = get_or_build_tokenizer(config, ds_raw, config['lang_src'])
-    tokenizer_tgt = get_or_build_tokenizer(config, ds_raw, config['lang_tgt'])
+    tokenizer_src = get_or_build_tokenizer(config, ds_raw, config.lang_src)
+    tokenizer_tgt = get_or_build_tokenizer(config, ds_raw, config.lang_tgt)
 
     # Keep 90% for training, 10% for validation
     train_ds_size = int(0.9 * len(ds_raw))
     val_ds_size = len(ds_raw) - train_ds_size
     train_ds_raw, val_ds_raw = random_split(ds_raw, [train_ds_size, val_ds_size])
 
-    train_ds = BilingualDataset(train_ds_raw, tokenizer_src, tokenizer_tgt, config['lang_src'], config['lang_tgt'], config['seq_len'])
-    val_ds = BilingualDataset(val_ds_raw, tokenizer_src, tokenizer_tgt, config['lang_src'], config['lang_tgt'], config['seq_len'])
+    train_ds = BilingualDataset(train_ds_raw, tokenizer_src, tokenizer_tgt, config.lang_src, config.lang_tgt, config.seq_len)
+    val_ds = BilingualDataset(val_ds_raw, tokenizer_src, tokenizer_tgt, config.lang_src, config.lang_tgt, config.seq_len)
 
     # Find the maximum length of each sentence in the source and target sentence
     max_len_src = 0
     max_len_tgt = 0
 
     for item in ds_raw:
-        src_ids = tokenizer_src.encode(item['translation'][config['lang_src']]).ids
-        tgt_ids = tokenizer_tgt.encode(item['translation'][config['lang_tgt']]).ids
+        src_ids = tokenizer_src.encode(item['translation'][config.lang_src]).ids
+        tgt_ids = tokenizer_tgt.encode(item['translation'][config.lang_tgt]).ids
         max_len_src = max(max_len_src, len(src_ids))
         max_len_tgt = max(max_len_tgt, len(tgt_ids))
 
-    if config['local_rank'] == 0:
+    if config.local_rank == 0:
         print(f'Max length of source sentence: {max_len_src}')
         print(f'Max length of target sentence: {max_len_tgt}')
 
 
-    train_dataloader = DataLoader(train_ds, batch_size=config['batch_size'], shuffle=False, sampler=DistributedSampler(train_ds, shuffle=True))
+    train_dataloader = DataLoader(train_ds, batch_size=config.batch_size, shuffle=False, sampler=DistributedSampler(train_ds, shuffle=True))
     val_dataloader = DataLoader(val_ds, batch_size=1, shuffle=True)
 
     return train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt
 
 def get_model(config, vocab_src_len, vocab_tgt_len):
-    model = build_transformer(vocab_src_len, vocab_tgt_len, config['seq_len'], config['seq_len'], d_model=config['d_model'])
+    model = build_transformer(vocab_src_len, vocab_tgt_len, config.seq_len, config.seq_len, d_model=config.d_model)
    return model
 
 def train_model(config):
     # Define the device
     assert torch.cuda.is_available(), "Training on CPU is not supported"
     device = torch.device("cuda")
-    if config['local_rank'] == 0:
+    if config.local_rank == 0:
         print("Using device:", device)
 
     # Make sure the weights folder exists
-    Path(config['model_folder']).mkdir(parents=True, exist_ok=True)
+    Path(config.model_folder).mkdir(parents=True, exist_ok=True)
 
     # Load the dataset
-    if config['local_rank'] == 0:
+    if config.local_rank == 0:
         print("Loading dataset...")
     train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt = get_ds(config)
     model = get_model(config, tokenizer_src.get_vocab_size(), tokenizer_tgt.get_vocab_size()).to(device)
 
-    optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'], eps=1e-9)
+    optimizer = torch.optim.Adam(model.parameters(), lr=config.lr, eps=1e-9)
 
-    # If the user specified a model to preload before training, load it
+    # By default, load the latest checkpoint
     initial_epoch = 0
     global_step = 0
     wandb_run_id = None
-    if config['preload'] != '':
+    if config.preload != '':
 
-        if config['preload'] == 'latest':
+        if config.preload == 'latest':
+            # Get the filename of the latest checkpoint
             model_filename = get_latest_weights_file_path(config)
         else:
-            model_filename = get_weights_file_path(config, int(config['preload']))
+            # In case we want to preload a specific checkpoint
+            model_filename = get_weights_file_path(config, int(config.preload))
 
-        # If we couldn't find a model to preload, just start from scratch
         if model_filename is not None:
-            if config['local_rank'] == 0:
+            if config.local_rank == 0:
                 print(f'Preloading model {model_filename}')
             state = torch.load(model_filename)
             model.load_state_dict(state['model_state_dict'])
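A note on the train DataLoader built in this hunk: shuffle=False on the loader paired with DistributedSampler(train_ds, shuffle=True) is the required combination, because the sampler, not the loader, decides ordering and hands each rank a disjoint shard of the dataset. One thing the training loop further down does not do is call set_epoch on that sampler; without it, DistributedSampler reuses the same permutation every epoch. The usual fix is a one-liner at the top of each epoch (a minimal sketch using the names from this diff):

    # Sketch: reseed the DistributedSampler so each epoch gets a new shuffle.
    for epoch in range(initial_epoch, config.num_epochs):
        train_dataloader.sampler.set_epoch(epoch)
        ...  # the rest of the epoch as in train_model below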
@@ -223,11 +223,12 @@ def train_model(config):
             wandb_run_id = state['wandb_run_id']
             del state
         else:
-            if config['local_rank'] == 0:
-                print(f'Could not find model to preload: {config["preload"]}. Starting from scratch')
+            # If we couldn't find a model to preload, just start from scratch
+            if config.local_rank == 0:
+                print(f'Could not find model to preload: {config.preload}. Starting from scratch')
 
-    # Only initialize W&B on the rank 0 node
-    if config['global_rank'] == 0:
+    # Only initialize W&B on the global rank 0 node
+    if config.global_rank == 0:
         wandb.init(
             # set the wandb project where this run will be logged
             project="pytorch-transformer-distributed",
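The wandb.init call is truncated by the hunk, but the wandb_run_id restored from the checkpoint above strongly suggests the run is resumed across restarts. A plausible shape for the full call, using W&B's documented id/resume parameters (the exact arguments in the repo may differ):

    # Sketch: reattach to the same W&B run after preloading a checkpoint.
    wandb.init(
        project="pytorch-transformer-distributed",
        id=wandb_run_id,   # None -> start a fresh run
        resume="allow",    # reuse the run if the id already exists
        config=config.__dict__,
    )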
@@ -240,22 +241,22 @@ def train_model(config):
 
     # Convert the model to DistributedDataParallel
     # Here we can also specify the bucket_cap_mb parameter to control the size of the buckets
-    model = DDP(model, device_ids=[config['local_rank']])
+    model = DistributedDataParallel(model, device_ids=[config.local_rank])
 
     loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer_src.token_to_id('[PAD]'), label_smoothing=0.1).to(device)
 
-    if config['global_rank'] == 0:
+    if config.global_rank == 0:
         # define our custom x axis metric
         wandb.define_metric("global_step")
         # define which metrics will be plotted against it
         wandb.define_metric("validation/*", step_metric="global_step")
         wandb.define_metric("train/*", step_metric="global_step")
 
-    for epoch in range(initial_epoch, config['num_epochs']):
+    for epoch in range(initial_epoch, config.num_epochs):
         torch.cuda.empty_cache()
         model.train()
-        batch_iterator = tqdm(train_dataloader, desc=f"Processing Epoch {epoch:02d} on rank {config['global_rank']}")
-        if config['local_rank'] != 0:
+        batch_iterator = tqdm(train_dataloader, desc=f"Processing Epoch {epoch:02d} on rank {config.global_rank}")
+        if config.local_rank != 0:
             batch_iterator.disable = True
 
         for batch in batch_iterator:
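Spelling out DistributedDataParallel instead of the DDP alias is cosmetic; the wrapper is doing the real work. At construction it broadcasts the rank-0 parameters to every process, and during backward it all-reduces gradients bucket by bucket, overlapping communication with computation (the bucket_cap_mb comment refers to those bucket sizes). One practical consequence: after wrapping, the underlying network lives at model.module, which is why the diff loads the checkpoint into the plain model before wrapping it. A tiny sketch of the unwrapping:

    # Sketch: access the underlying model through .module after DDP wrapping.
    raw_model = model.module             # the un-wrapped nn.Module
    state_dict = raw_model.state_dict()  # keys without the 'module.' prefix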
@@ -277,7 +278,7 @@ def train_model(config):
             loss = loss_fn(proj_output.view(-1, tokenizer_tgt.get_vocab_size()), label.view(-1))
             batch_iterator.set_postfix({"loss": f"{loss.item():6.3f}", "global_step": global_step})
 
-            if config['global_rank'] == 0:
+            if config.global_rank == 0:
                 # Log the loss
                 wandb.log({'train/loss': loss.item(), 'global_step': global_step})
 
@@ -291,9 +292,9 @@ def train_model(config):
             global_step += 1
 
         # Only run validation and checkpoint saving on the rank 0 node
-        if config['global_rank'] == 0:
+        if config.global_rank == 0:
             # Run validation at the end of every epoch
-            run_validation(model, val_dataloader, tokenizer_src, tokenizer_tgt, config['seq_len'], device, lambda msg: batch_iterator.write(msg), global_step)
+            run_validation(model, val_dataloader, tokenizer_src, tokenizer_tgt, config.seq_len, device, lambda msg: batch_iterator.write(msg), global_step)
 
             # Save the model at the end of every epoch
             model_filename = get_weights_file_path(config, epoch)
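The checkpoint dict written to model_filename is elided by the hunk, but the preload path near the top of the diff reads state['model_state_dict'] and state['wandb_run_id']. A sketch of a save consistent with those keys (field names beyond the two confirmed ones are assumptions):

    # Sketch: checkpoint consistent with what the preload path reads.
    # 'model_state_dict' and 'wandb_run_id' are confirmed by the diff;
    # the remaining keys are assumed for completeness.
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.module.state_dict(),  # unwrap DDP
        'optimizer_state_dict': optimizer.state_dict(),
        'global_step': global_step,
        'wandb_run_id': wandb.run.id,
    }, model_filename)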
@@ -308,39 +309,42 @@
 
 if __name__ == '__main__':
     warnings.filterwarnings("ignore")
-    config = get_config()
+    config = get_default_config()
 
     # Read command line arguments and overwrite config accordingly
     parser = argparse.ArgumentParser()
-    parser.add_argument('--batch_size', type=int, default=config['batch_size'])
-    parser.add_argument('--num_epochs', type=int, default=config['num_epochs'])
-    parser.add_argument('--lr', type=float, default=config['lr'])
-    parser.add_argument('--seq_len', type=int, default=config['seq_len'])
-    parser.add_argument('--d_model', type=int, default=config['d_model'])
-    parser.add_argument('--lang_src', type=str, default=config['lang_src'])
-    parser.add_argument('--lang_tgt', type=str, default=config['lang_tgt'])
-    parser.add_argument('--model_folder', type=str, default=config['model_folder'])
-    parser.add_argument('--model_basename', type=str, default=config['model_basename'])
-    parser.add_argument('--preload', type=str, default=config['preload'])
-    parser.add_argument('--tokenizer_file', type=str, default=config['tokenizer_file'])
+    parser.add_argument('--batch_size', type=int, default=config.batch_size)
+    parser.add_argument('--num_epochs', type=int, default=config.num_epochs)
+    parser.add_argument('--lr', type=float, default=config.lr)
+    parser.add_argument('--seq_len', type=int, default=config.seq_len)
+    parser.add_argument('--d_model', type=int, default=config.d_model)
+    parser.add_argument('--lang_src', type=str, default=config.lang_src)
+    parser.add_argument('--lang_tgt', type=str, default=config.lang_tgt)
+    parser.add_argument('--model_folder', type=str, default=config.model_folder)
+    parser.add_argument('--model_basename', type=str, default=config.model_basename)
+    parser.add_argument('--preload', type=str, default=config.preload)
+    parser.add_argument('--tokenizer_file', type=str, default=config.tokenizer_file)
 
     # Update default configuration with command line arguments
     args = parser.parse_args()
-    config.update(vars(args))
+    config.__dict__.update(vars(args))
 
     # Add local rank and global rank to the config
-    config['local_rank'] = int(os.environ['LOCAL_RANK'])
-    config['global_rank'] = int(os.environ['RANK'])
+    config.local_rank = int(os.environ['LOCAL_RANK'])
+    config.global_rank = int(os.environ['RANK'])
+
+    assert config.local_rank != -1, "LOCAL_RANK environment variable not set"
+    assert config.global_rank != -1, "RANK environment variable not set"
 
     # Print configuration
-    if config['local_rank'] == 0:
+    if config.local_rank == 0:
         print("Configuration:")
-        for key, value in config.items():
+        for key, value in config.__dict__.items():
            print(f"{key:>20}: {value}")
 
     # Setup distributed training
     init_process_group(backend='nccl')
-    torch.cuda.set_device(config['local_rank'])
+    torch.cuda.set_device(config.local_rank)
 
     # Train the model
     train_model(config)
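Nothing in the script spawns worker processes itself (the commit even drops the torch.multiprocessing import); it expects LOCAL_RANK and RANK in the environment, which is the contract of torchrun. torchrun starts one process per GPU and sets RANK, LOCAL_RANK, and WORLD_SIZE for each process before the script runs. A typical single-node launch (the GPU count and flag values are placeholders, and the filename train.py is assumed):

    torchrun --nproc_per_node=2 --nnodes=1 train.py --batch_size 8 --num_epochs 30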