DDP training and storing rank specific info in checkpoints #21097
Unanswered
bardsleypt asked this question in DDP / multi-GPU / multi-node
Replies: 1 comment
-
I believe I have at least narrowed down why my approach is not working: it appears that Lightning only writes the checkpoint file from the global-rank-0 process, so whatever the other ranks add to their checkpoint dictionaries in `on_save_checkpoint` never makes it to disk.

This at least explains why I only see the rank-0 information in the saved checkpoint. So my question can now be reduced to: is there any way to synchronize, or otherwise send, the checkpoint dictionaries from each rank to the global-rank-0 process? As a workaround, I can do some pretty hacky temporary-save and load routines in the callback hooks.
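One direction I'm considering (untested sketch; class and key names are just illustrative, and it assumes the process group set up by the DDP strategy is available) is to gather every rank's RNG state inside `on_save_checkpoint` with `torch.distributed.all_gather_object`, so that the dictionary rank 0 actually writes to disk contains all of them:

```python
import random

import numpy as np
import torch
import torch.distributed as dist
from pytorch_lightning.callbacks import Callback


class GatheredRNGStates(Callback):
    """Sketch: collect every rank's RNG states into the rank-0 checkpoint."""

    def on_save_checkpoint(self, trainer, pl_module, checkpoint):
        local_state = {
            "torch": torch.get_rng_state(),
            "torch_cuda": torch.cuda.get_rng_state() if torch.cuda.is_available() else None,
            "numpy": np.random.get_state(),
            "python": random.getstate(),
        }
        if dist.is_available() and dist.is_initialized():
            # Every rank contributes its state; afterwards each rank holds the
            # full list, and rank 0's copy is the one Lightning writes to disk.
            gathered = [None] * dist.get_world_size()
            dist.all_gather_object(gathered, local_state)
            checkpoint["rng_states_per_rank"] = gathered
        else:
            checkpoint["rng_states_per_rank"] = [local_state]
```

On restore, each rank could then index the list with `trainer.global_rank` inside `on_load_checkpoint` and re-seed its own generators.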
-
I'm working on preserving state between start/stop of training runs in a manner that guarantees reproducible results. That is, I'd like to be able to stop my training at any given checkpoint, restart the training from that checkpoint, run it to completion, and have the results match (exactly) the results obtained from a single continuous run. I've been able to do this on single-node setups by storing the outputs of

- `torch.get_rng_state()`
- `torch.cuda.get_rng_state()`
- `np.random.get_state()`
- `random.getstate()`

within the model checkpoint and calling the corresponding `set` methods upon loading the checkpoint. I've been performing the save/load routines within a custom `pytorch_lightning.callbacks.Callback` by overriding the `on_save_checkpoint` and `on_load_checkpoint` hooks appropriately.
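On a single node, the restore side of that callback looks roughly like this (simplified; the `rng_state` checkpoint key is just the name I use for illustration):

```python
import random

import numpy as np
import torch
from pytorch_lightning.callbacks import Callback


class RNGStateRestore(Callback):
    """Sketch of the load side: re-seed every generator from the stored states."""

    def on_load_checkpoint(self, trainer, pl_module, checkpoint):
        state = checkpoint["rng_state"]
        torch.set_rng_state(state["torch"])
        if state.get("torch_cuda") is not None and torch.cuda.is_available():
            torch.cuda.set_rng_state(state["torch_cuda"])
        np.random.set_state(state["numpy"])
        random.setstate(state["python"])
```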
I'm now trying to perform the same checkpoint save/load procedure in a multi-node setup with a DDP strategy. My attempt was to append the global-rank-specific RNG states to the checkpoint dictionary, which I had thought would then be saved appropriately. However, when I execute the code, the only RNG state preserved within the checkpoint dictionary is the rank-0 state. Can someone please advise on how to preserve the RNG states from the other ranks within the checkpoint in a DDP setup? As a higher-level question: if there is a better way to preserve these states between training runs than checkpoint storage and re-instantiation, that information would also be welcome.
The main Callback save routine I'm using is posted below. I've been checking the contents of the saved checkpoint dictionary with a manual `torch.load()` call.

python version: 3.9.12
pytorch version: 2.2.0+cu121
pytorch_lightning version: 2.2.0
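In simplified form (class and key names are illustrative), the save routine amounts to tagging the states with the global rank:

```python
import random

import numpy as np
import torch
from pytorch_lightning.callbacks import Callback


class RNGStateSave(Callback):
    """Sketch of the save side: stash this rank's RNG states in the checkpoint dict."""

    def on_save_checkpoint(self, trainer, pl_module, checkpoint):
        checkpoint[f"rng_state_rank_{trainer.global_rank}"] = {
            "torch": torch.get_rng_state(),
            "torch_cuda": torch.cuda.get_rng_state() if torch.cuda.is_available() else None,
            "numpy": np.random.get_state(),
            "python": random.getstate(),
        }
        # In the saved file I only ever see the rank-0 entry.
```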