@@ -144,8 +144,7 @@ def get_ds(config: ModelConfig):
     ds_raw = load_dataset('opus_books', f"{config.lang_src}-{config.lang_tgt}", split='train')
 
     # Build tokenizers
-    if config.local_rank == 0:
-        print("Loading tokenizers...")
+    print(f"GPU {config.local_rank} - Loading tokenizers...")
     tokenizer_src = get_or_build_tokenizer(config, ds_raw, config.lang_src)
     tokenizer_tgt = get_or_build_tokenizer(config, ds_raw, config.lang_tgt)
 
@@ -167,9 +166,8 @@ def get_ds(config: ModelConfig):
         max_len_src = max(max_len_src, len(src_ids))
         max_len_tgt = max(max_len_tgt, len(tgt_ids))
 
-    if config.local_rank == 0:
-        print(f'Max length of source sentence: {max_len_src}')
-        print(f'Max length of target sentence: {max_len_tgt}')
+    print(f'GPU {config.local_rank} - Max length of source sentence: {max_len_src}')
+    print(f'GPU {config.local_rank} - Max length of target sentence: {max_len_tgt}')
 
 
     train_dataloader = DataLoader(train_ds, batch_size=config.batch_size, shuffle=False, sampler=DistributedSampler(train_ds, shuffle=True))
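For context, the hunk above ends on the standard DDP loading pattern: the DataLoader keeps shuffle=False because the DistributedSampler owns the shuffling and shards the data across ranks. A minimal, self-contained sketch of that pattern follows (the toy dataset and the explicit num_replicas/rank arguments are illustrative stand-ins; under a real torchrun launch they default to the process group's world size and rank):

import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

# Toy dataset standing in for the tokenized opus_books split.
train_ds = TensorDataset(torch.arange(1000))

# The sampler shards the dataset across processes and shuffles within
# each shard, so the DataLoader itself must keep shuffle=False.
sampler = DistributedSampler(train_ds, num_replicas=2, rank=0, shuffle=True)
loader = DataLoader(train_ds, batch_size=8, shuffle=False, sampler=sampler)

for epoch in range(3):
    # Without set_epoch, every epoch replays the same shuffled order.
    sampler.set_epoch(epoch)
    for batch in loader:
        pass  # training step goes here

One caveat: reshuffling between epochs requires the sampler.set_epoch(epoch) call shown above; the diff does not show whether the surrounding training loop does this.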
@@ -185,15 +183,13 @@ def train_model(config: ModelConfig):
     # Define the device
     assert torch.cuda.is_available(), "Training on CPU is not supported"
     device = torch.device("cuda")
-    if config.local_rank == 0:
-        print("Using device:", device)
+    print(f"GPU {config.local_rank} - Using device: {device}")
 
     # Make sure the weights folder exists
     Path(config.model_folder).mkdir(parents=True, exist_ok=True)
 
     # Load the dataset
-    if config.local_rank == 0:
-        print("Loading dataset...")
+    print(f"GPU {config.local_rank} - Loading dataset...")
     train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt = get_ds(config)
     model = get_model(config, tokenizer_src.get_vocab_size(), tokenizer_tgt.get_vocab_size()).to(device)
 
@@ -213,8 +209,7 @@ def train_model(config: ModelConfig):
         model_filename = get_weights_file_path(config, int(config.preload))
 
     if model_filename is not None:
-        if config.local_rank == 0:
-            print(f'Preloading model {model_filename}')
+        print(f'GPU {config.local_rank} - Preloading model {model_filename}')
         state = torch.load(model_filename)
         model.load_state_dict(state['model_state_dict'])
         initial_epoch = state['epoch'] + 1
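A note on the state = torch.load(model_filename) context line above: when several processes on one machine restore the same checkpoint, it is common to pass map_location so tensors land on each process's own GPU rather than the device they were saved from. A hedged sketch of that variant (the checkpoint path is hypothetical; this diff itself keeps the default behavior):

import os
import torch

local_rank = int(os.environ.get("LOCAL_RANK", 0))  # set by torchrun

# map_location redirects tensors serialized on another device onto
# this process's GPU during deserialization.
state = torch.load("weights/tmodel_05.pt",  # hypothetical path
                   map_location=f"cuda:{local_rank}")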
@@ -224,8 +219,7 @@ def train_model(config: ModelConfig):
         del state
     else:
         # If we couldn't find a model to preload, just start from scratch
-        if config.local_rank == 0:
-            print(f'Could not find model to preload: {config.preload}. Starting from scratch')
+        print(f'GPU {config.local_rank} - Could not find model to preload: {config.preload}. Starting from scratch')
 
     # Only initialize W&B on the global rank 0 node
     if config.global_rank == 0:
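The global_rank == 0 guard in the context lines above is the usual way to keep a single Weights & Biases run across all servers: local rank 0 exists once per machine, but global rank 0 exists exactly once in the whole job. A minimal sketch of that gating, with an assumed project name:

import os
import wandb

global_rank = int(os.environ.get("RANK", -1))  # set by torchrun

if global_rank == 0:
    # Exactly one process in the whole cluster creates the W&B run;
    # all other ranks skip logging entirely.
    wandb.init(project="distributed-transformer")  # hypothetical project name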
@@ -255,12 +249,11 @@ def train_model(config: ModelConfig):
     for epoch in range(initial_epoch, config.num_epochs):
         torch.cuda.empty_cache()
         model.train()
-        batch_iterator = tqdm(train_dataloader, desc=f"Processing Epoch {epoch:02d} on rank {config.global_rank}")
-        if config.local_rank != 0:
-            batch_iterator.disable = True
 
-        for batch in batch_iterator:
+        # Disable tqdm on all nodes except the rank 0 GPU on each server
+        batch_iterator = tqdm(train_dataloader, desc=f"Processing Epoch {epoch:02d} on rank {config.global_rank}", disable=config.local_rank != 0)
 
+        for batch in batch_iterator:
             encoder_input = batch['encoder_input'].to(device) # (b, seq_len)
             decoder_input = batch['decoder_input'].to(device) # (B, seq_len)
             encoder_mask = batch['encoder_mask'].to(device) # (B, 1, 1, seq_len)
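The rewrite above also swaps the post-hoc batch_iterator.disable = True mutation for tqdm's own disable keyword, which silences the bar at construction time instead of after it has already been set up. A standalone illustration of the keyword:

from tqdm import tqdm

local_rank = 1  # pretend this process is not the first GPU on its server

# With disable=True, tqdm becomes a transparent wrapper: iteration
# still works, but no progress bar is rendered on this rank.
for item in tqdm(range(100), desc="Processing", disable=local_rank != 0):
    pass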
@@ -336,7 +329,7 @@ def train_model(config: ModelConfig):
     assert config.local_rank != -1, "LOCAL_RANK environment variable not set"
     assert config.global_rank != -1, "RANK environment variable not set"
 
-    # Print configuration
+    # Print configuration (only once per server)
     if config.local_rank == 0:
         print("Configuration:")
         for key, value in config.__dict__.items():
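The two asserts at the top of the final hunk imply that config.local_rank and config.global_rank come from the LOCAL_RANK and RANK environment variables, both of which torchrun sets for every spawned process. A simplified sketch of that wiring (the dataclass here is a stripped-down stand-in for the repo's ModelConfig, and the -1 defaults mirror the assert messages):

import os
from dataclasses import dataclass

@dataclass
class ModelConfig:
    local_rank: int = -1   # GPU index on this server (LOCAL_RANK)
    global_rank: int = -1  # process index across the whole job (RANK)

config = ModelConfig(
    local_rank=int(os.environ.get("LOCAL_RANK", -1)),
    global_rank=int(os.environ.get("RANK", -1)),
)
# With these defaults, running the script without torchrun trips the
# asserts instead of silently training in a misconfigured single process.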