Load weights from checkpoint using lightning-cli #9555
Unanswered
man-sean
asked this question in
Lightning Trainer API: Trainer, LightningModule, LightningDataModule
Replies: 1 comment 1 reply
-
I think you should add a parameter to your model to do this. All parameters you define get exposed in the CLI. Something like:

```python
from typing import Optional

import torch
from pytorch_lightning import LightningModule


class MyModel(LightningModule):
    def __init__(
        self,
        weights_from_checkpoint_path: Optional[str] = None,
    ):
        super().__init__()
        # ... define the model's submodules/layers here, before loading ...
        if weights_from_checkpoint_path:
            checkpoint = torch.load(weights_from_checkpoint_path)
            self.load_state_dict(checkpoint["state_dict"])
```
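Once the parameter exists it is picked up by the CLI like any other `__init__` argument. A minimal sketch of setting it from the config file, assuming the model class is passed directly to `LightningCLI` (the path is a placeholder):

```yaml
# config.yaml (sketch): the new __init__ argument shows up under the model section
model:
  weights_from_checkpoint_path: path/to/pretrained.ckpt  # placeholder
```

The same value can also be given on the command line via `--model.weights_from_checkpoint_path path/to/pretrained.ckpt`; if the model is configured in subclass mode, the argument goes under `init_args` instead.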
-
Hey 👋🏼
I trained a model using a regression loss and I want to use the learned weights as initialization for a GAN's generator (of the same type).
In regular training code I would call `gan.generator.load_from_checkpoint(PATH)` for that. But I started using Lightning CLI and I'm a bit confused about how to replicate this behavior.
My GAN model is structured like this:
And my config file has the following structure:
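For illustration only (the class and submodule names below are assumptions, not the original code), such a GAN module might look roughly like this, with the config file listing its `__init__` arguments under the `model:` section:

```python
# Illustrative sketch of a GAN LightningModule that wraps a generator submodule.
import torch.nn as nn
import pytorch_lightning as pl


class GAN(pl.LightningModule):
    def __init__(self, latent_dim: int = 100):
        super().__init__()
        # stand-in submodules; the real generator/discriminator are not shown here
        self.generator = nn.Sequential(nn.Linear(latent_dim, 784), nn.Tanh())
        self.discriminator = nn.Sequential(nn.Linear(784, 1), nn.Sigmoid())
```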
I would expect `pl.LightningModule` to have an argument `checkpoint_path` or something similar to perform `load_from_checkpoint` on its own in `self.__init__`. It seems it does not have such functionality. How should I handle this staged-training case using Lightning CLI?
Note: As I understand `trainer.resume_from_checkpoint`, it won't solve my problem, since it will not handle the new structure of my model correctly (and I don't want to resume the optimizer state, etc.).