Commit fe73982

Log on all ranks

1 parent 4ddbe9a commit fe73982

File tree

1 file changed

train.py: 4 additions & 2 deletions
@@ -222,13 +222,14 @@ def train_model(config: ModelConfig):
         print(f'GPU {config.local_rank} - Could not find model to preload: {config.preload}. Starting from scratch')
 
     # Only initialize W&B on the global rank 0 node
-    if config.global_rank == 0:
+    if config.local_rank == 0:
         wandb.init(
             # set the wandb project where this run will be logged
             project="pytorch-transformer-distributed",
             # allow resuming existing run with the same name (in case the rank 0 node crashed)
             id=wandb_run_id,
             resume="allow",
+            group=config.wandb_group,
             # track hyperparameters and run metadata
             config=config
         )
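
For context (not part of this commit), a minimal sketch of where the two rank values conventionally come from under a torchrun launch: torchrun sets RANK (unique across the whole job) and LOCAL_RANK (unique within one node), so switching the check from global_rank to local_rank makes it true once per machine instead of once per job, and the shared group name lets W&B collect the per-node runs of one experiment together.

import os

# Set by torchrun for every process it spawns:
global_rank = int(os.environ["RANK"])       # 0 .. world_size - 1, unique across all nodes
local_rank = int(os.environ["LOCAL_RANK"])  # 0 .. nproc_per_node - 1, unique within a node

# After this change, `local_rank == 0` holds for the first process on
# every node, so each node opens its own W&B run, and the shared
# wandb.init(group=...) name groups those runs in the W&B UI.
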
@@ -271,7 +272,7 @@ def train_model(config: ModelConfig):
             loss = loss_fn(proj_output.view(-1, tokenizer_tgt.get_vocab_size()), label.view(-1))
             batch_iterator.set_postfix({"loss": f"{loss.item():6.3f}", "global_step": global_step})
 
-            if config.global_rank == 0:
+            if config.local_rank == 0:
                 # Log the loss
                 wandb.log({'train/loss': loss.item(), 'global_step': global_step})
 
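The same switch applies to the per-step logging: each node now logs the loss computed by its own first process, so the grouped runs show one curve per node rather than a single global curve. If a single cross-node average were wanted instead, the usual pattern is a collective reduction before logging; a sketch under the assumption that torch.distributed is already initialized (the helper name is invented here):

import torch
import torch.distributed as dist

def mean_loss_across_ranks(loss: torch.Tensor) -> float:
    # Every rank must call this: sum the detached loss over all
    # processes, then divide by world size for the global mean.
    reduced = loss.detach().clone()
    dist.all_reduce(reduced, op=dist.ReduceOp.SUM)
    return (reduced / dist.get_world_size()).item()
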
@@ -321,6 +322,7 @@ def train_model(config: ModelConfig):
     parser.add_argument('--model_basename', type=str, default=config.model_basename)
     parser.add_argument('--preload', type=str, default=config.preload)
     parser.add_argument('--tokenizer_file', type=str, default=config.tokenizer_file)
+    parser.add_argument('--wandb_group', type=str, default="exp1")
     args = parser.parse_args()
 
     # Update default configuration with command line arguments
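
With the new --wandb_group flag, each experiment can be given its own group name at launch; a usage sketch (process counts, rendezvous address, and port are placeholders):

# on the first of two nodes; the second node uses --node_rank=1
torchrun --nproc_per_node=2 --nnodes=2 --node_rank=0 \
    --master_addr=10.0.0.1 --master_port=29500 \
    train.py --wandb_group exp2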
