Possible to change starting epoch? #17396
-
Hey everyone! A project I've been helping out on, https://github.com/34j/so-vits-svc-fork, manually loads model checkpoints at training start / initialization. Currently it uses a rather hacky way to set the epoch that training continues from:

```python
def on_train_start(self) -> None:
    if not self.tuning:
        # self._temp_epoch is loaded from the model
        self.set_current_epoch(self._temp_epoch)
        total_batch_idx = self.current_epoch * len(self.trainer.train_dataloader)
        self.set_total_batch_idx(total_batch_idx)
        global_step = total_batch_idx * self.optimizers_count
        self.set_global_step(global_step)

def set_current_epoch(self, epoch: int):
    LOG.info(f"Setting current epoch to {epoch}")
    self.trainer.fit_loop.epoch_progress.current.completed = epoch
    assert self.current_epoch == epoch, f"{self.current_epoch} != {epoch}"

def set_global_step(self, global_step: int):
    LOG.info(f"Setting global step to {global_step}")
    self.trainer.fit_loop.epoch_loop.manual_optimization.optim_step_progress.total.completed = (
        global_step
    )
    self.trainer.fit_loop.epoch_loop.automatic_optimization.optim_progress.optimizer.step.total.completed = (
        global_step
    )
    assert self.global_step == global_step, f"{self.global_step} != {global_step}"

def set_total_batch_idx(self, total_batch_idx: int):
    LOG.info(f"Setting total batch idx to {total_batch_idx}")
    self.trainer.fit_loop.epoch_loop.batch_progress.total.ready = total_batch_idx + 1
    self.trainer.fit_loop.epoch_loop.batch_progress.total.completed = total_batch_idx
    assert (
        self.total_batch_idx == total_batch_idx + 1
    ), f"{self.total_batch_idx} != {total_batch_idx + 1}"

@property
def total_batch_idx(self) -> int:
    return self.trainer.fit_loop.epoch_loop.total_batch_idx + 1
```

However, this breaks support for … Is there a way to properly override the current epoch that training "starts" (or, in this case, continues) from?
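For comparison: when a checkpoint was saved by Lightning itself, none of this bookkeeping is needed, since passing `ckpt_path` to `Trainer.fit` restores the epoch, global step, and optimizer state. A minimal sketch (assuming a recent PyTorch Lightning; `model` and the path are placeholders):

```python
import pytorch_lightning as pl

# `model` is the LightningModule being trained; the checkpoint path is a placeholder.
trainer = pl.Trainer(max_epochs=10000)
trainer.fit(model, ckpt_path="checkpoints/last.ckpt")  # loop counters resume automatically
```

Our checkpoints are loaded manually rather than through Lightning, though, which is why the helpers above poke at the loop state directly.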
Replies: 1 comment · 1 reply
-
After looking at the Lightning code, I managed to fix it by also setting `self.trainer.fit_loop.epoch_progress.current.processed` to the specific epoch 😁
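In full, the fixed helper ends up looking something like this — a sketch under the same assumptions as the snippet above (a module-level `LOG` logger, and Lightning's loop progress trackers, which are private API and may move between versions):

```python
def set_current_epoch(self, epoch: int):
    LOG.info(f"Setting current epoch to {epoch}")
    # Keep `processed` in sync with `completed`: a mismatch makes the
    # fit loop treat the epoch as interrupted mid-way on restart.
    self.trainer.fit_loop.epoch_progress.current.processed = epoch
    self.trainer.fit_loop.epoch_progress.current.completed = epoch
    assert self.current_epoch == epoch, f"{self.current_epoch} != {epoch}"
```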