Skip to content

Commit 40c19f3

Browse files
authored
Add wandb project name argument and allow change wandb run name
1 parent 6f80214 commit 40c19f3

File tree

1 file changed

+10
-3
lines changed

1 file changed

+10
-3
lines changed

train.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -388,6 +388,8 @@
388388
help='use the multi-epochs-loader to save time at the beginning of every epoch')
389389
group.add_argument('--log-wandb', action='store_true', default=False,
390390
help='log training and validation metrics to wandb')
391+
group.add_argument('--wandb-project', default=None, type=str,
392+
help='wandb project name')
391393
group.add_argument('--wandb-tags', default=[], type=str, nargs='+',
392394
help='wandb tags')
393395
group.add_argument('--wandb-resume-id', default='', type=str, metavar='ID',
@@ -823,9 +825,14 @@ def main():
823825
if utils.is_primary(args) and args.log_wandb:
824826
if has_wandb:
825827
assert not args.wandb_resume_id or args.resume
826-
wandb.init(project=args.experiment, config=args, tags=args.wandb_tags,
827-
resume='must' if args.wandb_resume_id else None,
828-
id=args.wandb_resume_id if args.wandb_resume_id else None)
828+
wandb.init(
829+
project=args.wandb_project,
830+
name=args.experiment,
831+
config=args,
832+
tags=args.wandb_tags,
833+
resume="must" if args.wandb_resume_id else None,
834+
id=args.wandb_resume_id if args.wandb_resume_id else None,
835+
)
829836
else:
830837
_logger.warning(
831838
"You've requested to log metrics to wandb but package not found. "

0 commit comments

Comments
 (0)