Description
Currently, when resuming from a previous run that uses a learning rate scheduler, we do NOT load the scheduler's state dict.
But wait, does that mean our code is BROKEN?
Actually, because we save the state dict for the optimizer that initializes the learning rate scheduler AND we track the global steps taken, which tells the learning rate scheduler where in its schedule it is, the behavior is largely the same. However, if you were to inspect the learning rate scheduler's state dict before and after, they will not be the same, because the parameter `_step_count` is not updated.
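To make the mismatch concrete, here is a minimal repro sketch in plain PyTorch (not our recipe code); the choice of `StepLR`, the step counts, and the expected printed values are illustrative and may vary slightly across PyTorch versions:

```python
import torch

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)

# Simulate 5 training steps.
for _ in range(5):
    model(torch.randn(2, 4)).sum().backward()
    optimizer.step()
    optimizer.zero_grad()
    scheduler.step()

saved_optim = optimizer.state_dict()
saved_sched = scheduler.state_dict()

# "Resume" the way we currently do: restore the optimizer and rebuild the
# scheduler, fast-forwarding it via last_epoch from the tracked global step.
resumed_optim = torch.optim.SGD(model.parameters(), lr=0.1)
resumed_optim.load_state_dict(saved_optim)
resumed_sched = torch.optim.lr_scheduler.StepLR(
    resumed_optim, step_size=10, last_epoch=4
)

# The learning rates agree, but _step_count does not.
print(saved_sched["_step_count"])                 # expected: 6
print(resumed_sched.state_dict()["_step_count"])  # expected: 1
```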
Does this matter?
`_step_count` looks like it's mostly there for debugging purposes. It's only actually used in one standard PyTorch learning rate scheduler: `CosineAnnealingLR`, here. I'm opening this issue because even though our training code works fine for 99% of use cases, we really should utilize the state dict to cover all the cases.
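If you want to double-check where `_step_count` is read, one quick (version-dependent) sanity check is to scan the scheduler sources; this sketch assumes standard PyTorch and is not part of the proposed fix:

```python
import inspect

import torch.optim.lr_scheduler as lr_scheduler

# Print every scheduler class whose own get_lr references _step_count;
# per the claim above, only CosineAnnealingLR should show up.
for name, cls in vars(lr_scheduler).items():
    if inspect.isclass(cls) and "get_lr" in cls.__dict__:
        try:
            src = inspect.getsource(cls.get_lr)
        except (OSError, TypeError):
            continue
        if "_step_count" in src:
            print(name)
```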
Goal
Update our recipes to save `lr_scheduler.state_dict()` to the intermediate state dict and, upon resuming from checkpoint, call `lr_scheduler.load_state_dict()` on the learning rate scheduler.
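A rough sketch of what that could look like in a recipe; the model/optimizer setup, dict keys, and file name are illustrative placeholders, not the actual recipe code:

```python
import torch

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)

# Saving: persist the scheduler state alongside what we already save.
torch.save(
    {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "lr_scheduler": lr_scheduler.state_dict(),  # proposed addition
    },
    "intermediate_ckpt.pt",
)

# Resuming: restore the scheduler state explicitly instead of relying on
# the restored optimizer plus the global step count alone.
state = torch.load("intermediate_ckpt.pt")
model.load_state_dict(state["model"])
optimizer.load_state_dict(state["optimizer"])
lr_scheduler.load_state_dict(state["lr_scheduler"])  # proposed addition
```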