@@ -222,13 +222,14 @@ def train_model(config: ModelConfig):
         print(f'GPU {config.local_rank} - Could not find model to preload: {config.preload}. Starting from scratch')
 
     # Only initialize W&B on the global rank 0 node
-    if config.global_rank == 0:
+    if config.local_rank == 0:
         wandb.init(
            # set the wandb project where this run will be logged
            project="pytorch-transformer-distributed",
            # allow resuming existing run with the same name (in case the rank 0 node crashed)
            id=wandb_run_id,
            resume="allow",
+           group=config.wandb_group,
            # track hyperparameters and run metadata
            config=config
        )
@@ -271,7 +272,7 @@ def train_model(config: ModelConfig):
             loss = loss_fn(proj_output.view(-1, tokenizer_tgt.get_vocab_size()), label.view(-1))
             batch_iterator.set_postfix({"loss": f"{loss.item():6.3f}", "global_step": global_step})
 
-            if config.global_rank == 0:
+            if config.local_rank == 0:
                 # Log the loss
                 wandb.log({'train/loss': loss.item(), 'global_step': global_step})
 
@@ -321,6 +322,7 @@ def train_model(config: ModelConfig):
     parser.add_argument('--model_basename', type=str, default=config.model_basename)
     parser.add_argument('--preload', type=str, default=config.preload)
     parser.add_argument('--tokenizer_file', type=str, default=config.tokenizer_file)
+    parser.add_argument('--wandb_group', type=str, default="exp1")
     args = parser.parse_args()
 
     # Update default configuration with command line arguments
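After this change, W&B is initialized by the local rank 0 process of every node rather than only by the single global rank 0 process, and the per-node runs are tied together in the W&B UI through the shared group name passed via --wandb_group. The standalone sketch below illustrates that pattern; it assumes the torchrun environment-variable convention, and the run id scheme and config values are placeholders for illustration, not taken from the repository.

import os
import wandb

local_rank = int(os.environ.get("LOCAL_RANK", 0))   # rank within this node (set by torchrun)
global_rank = int(os.environ.get("RANK", 0))        # rank across all nodes

if local_rank == 0:
    wandb.init(
        project="pytorch-transformer-distributed",
        id=f"node-{global_rank}",    # placeholder run id scheme, for illustration only
        resume="allow",              # lets a restarted process reattach to the same run id
        group="exp1",                # shared group name: all per-node runs appear together
        config={"lr": 1e-4},         # placeholder hyperparameters
    )
    wandb.log({"train/loss": 0.0, "global_step": 0})  # logging stays guarded by the same rank check
    wandb.finish()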