
Use LRScheduler.load_state_dict() #2730

Open

Description

@joecummings

Currently, when resuming from a previous run that utilizes a learning rate scheduler, we do NOT load a state dict for the scheduler.

But wait, does that mean our code is BROKEN?

Actually, because we save the state dict for the optimizer that initializes the learning rate scheduler AND we track the global steps taken, which tells the scheduler where it is in its schedule, the behavior is largely the same. However, if you were to inspect the learning rate scheduler's state dict before and after resuming, they would not match, because the parameter _step_count is not updated.

Does this matter?

_step_count looks like it's mostly there for debugging purposes. It's only actually used in one standard PyTorch learning rate scheduler: CosineAnnealingLR. I'm opening this issue because, even though our training code works fine for 99% of use cases, we really should utilize the state dict to cover all the cases.
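As a quick sanity check of the discrepancy described above, here is a minimal sketch using a standard PyTorch scheduler (LinearLR, chosen just for illustration; this is not our actual recipe code). A freshly constructed scheduler starts with _step_count == 1, and only load_state_dict() brings it in sync with the saved run:

```python
import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import LinearLR

# Original run: step the scheduler 5 times, then save its state.
model = torch.nn.Linear(2, 2)
opt = SGD(model.parameters(), lr=0.1)
sched = LinearLR(opt, start_factor=0.5, total_iters=10)
for _ in range(5):
    opt.step()
    sched.step()
saved = sched.state_dict()

# Resumed run: recreate optimizer and scheduler from scratch.
model2 = torch.nn.Linear(2, 2)
opt2 = SGD(model2.parameters(), lr=0.1)
sched2 = LinearLR(opt2, start_factor=0.5, total_iters=10)

# Without loading the state dict, the fresh scheduler is out of sync.
print(sched2.state_dict()["_step_count"])  # -> 1
sched2.load_state_dict(saved)
print(sched2.state_dict()["_step_count"])  # -> 6
```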

Goal

Update our recipes to save lr_scheduler.state_dict() in the intermediate checkpoint state dict and, upon resuming from a checkpoint, call lr_scheduler.load_state_dict() to restore the scheduler's state.
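A minimal sketch of what the save/resume path could look like (the function and checkpoint key names here are assumptions for illustration, not the actual recipe code):

```python
import torch

def save_checkpoint(path, model, optimizer, lr_scheduler, global_step):
    # Hypothetical sketch: key names are assumptions, not real recipe keys.
    torch.save(
        {
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "lr_scheduler": lr_scheduler.state_dict(),  # proposed: persist scheduler state
            "global_step": global_step,
        },
        path,
    )

def load_checkpoint(path, model, optimizer, lr_scheduler):
    ckpt = torch.load(path, weights_only=True)
    model.load_state_dict(ckpt["model"])
    optimizer.load_state_dict(ckpt["optimizer"])
    lr_scheduler.load_state_dict(ckpt["lr_scheduler"])  # proposed: restore _step_count etc.
    return ckpt["global_step"]
```

Restoring the scheduler state this way keeps _step_count (and any other scheduler-internal fields) identical across a save/resume boundary, rather than relying on the optimizer state plus global step count to reconstruct the schedule.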

Labels

best practice: Things we should be doing but aren't
community help wanted: We would love the community's help completing this issue
