Loading DeepSpeed sharded weights #17477
Unanswered
prabhuteja12
asked this question in
Lightning Trainer API: Trainer, LightningModule, LightningDataModule
Replies: 0 comments
Hello!
I'm saving a model that is being trained with DeepSpeed stage 3 using the `ModelCheckpoint` callback. It saves multiple `.pt` files in a folder. However, I can't figure out how to load those weights back in a sharded fashion. My code currently calls `convert_zero_checkpoint_to_fp32_state_dict` (imported via `from pytorch_lightning.utilities.deepspeed import convert_zero_checkpoint_to_fp32_state_dict`) to consolidate the checkpoint into a single file, loads it with `MyModule.load_from_checkpoint`, and then passes the model to `Trainer` to continue training. My understanding is that DeepSpeed allows loading those shards directly without going through this conversion step. Can someone help me with how to do this?
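For reference, the current (conversion-based) workflow described above can be sketched roughly as follows. This is a hedged sketch, not a verified recipe: `MyModule` stands in for the asker's own `LightningModule` subclass, and the checkpoint paths are placeholders.

```python
# Sketch of the workflow described in the question: consolidate a DeepSpeed
# ZeRO stage-3 checkpoint into a single fp32 file, then reload it to resume
# training. `MyModule` and all paths are placeholders for illustration only.
from pytorch_lightning import Trainer
from pytorch_lightning.utilities.deepspeed import (
    convert_zero_checkpoint_to_fp32_state_dict,
)

# Directory written by ModelCheckpoint under the DeepSpeed strategy; it holds
# the sharded parameter/optimizer state as multiple .pt files.
sharded_ckpt_dir = "checkpoints/last.ckpt"
consolidated_path = "checkpoints/consolidated.ckpt"

# Merge the ZeRO shards into one fp32 checkpoint on disk.
convert_zero_checkpoint_to_fp32_state_dict(sharded_ckpt_dir, consolidated_path)

# Reload the consolidated checkpoint and hand the model back to a Trainer.
model = MyModule.load_from_checkpoint(consolidated_path)  # user's LightningModule
trainer = Trainer(strategy="deepspeed_stage_3", accelerator="gpu", devices=8)
trainer.fit(model)
```

Note that this consolidation materializes the full fp32 state dict, which is exactly the overhead the question is asking how to avoid by loading the shards directly.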