Skip to content

Commit 363b043

Browse files
committed
Extend train epoch schedule by warmup_epochs if warmup_prefix enable, allows schedule to reach end w/ prefix enabledy
1 parent 7f0c1b1 commit 363b043

File tree

4 files changed

+14
-7
lines changed

4 files changed

+14
-7
lines changed

timm/scheduler/cosine_lr.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ def _get_lr(self, t: int) -> List[float]:
111111
def get_cycle_length(self, cycles=0):
112112
cycles = max(1, cycles or self.cycle_limit)
113113
if self.cycle_mul == 1.0:
114-
return self.t_initial * cycles
114+
t = self.t_initial * cycles
115115
else:
116-
return int(math.floor(-self.t_initial * (self.cycle_mul ** cycles - 1) / (1 - self.cycle_mul)))
116+
t = int(math.floor(-self.t_initial * (self.cycle_mul ** cycles - 1) / (1 - self.cycle_mul)))
117+
return t + self.warmup_t if self.warmup_prefix else t

timm/scheduler/poly_lr.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ def _get_lr(self, t: int) -> List[float]:
107107
def get_cycle_length(self, cycles=0):
108108
cycles = max(1, cycles or self.cycle_limit)
109109
if self.cycle_mul == 1.0:
110-
return self.t_initial * cycles
110+
t = self.t_initial * cycles
111111
else:
112-
return int(math.floor(-self.t_initial * (self.cycle_mul ** cycles - 1) / (1 - self.cycle_mul)))
112+
t = int(math.floor(-self.t_initial * (self.cycle_mul ** cycles - 1) / (1 - self.cycle_mul)))
113+
return t + self.warmup_t if self.warmup_prefix else t

timm/scheduler/scheduler_factory.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -196,11 +196,15 @@ def create_scheduler_v2(
196196
)
197197

198198
if hasattr(lr_scheduler, 'get_cycle_length'):
199-
# for cycle based schedulers (cosine, tanh, poly) recalculate total epochs w/ cycles & cooldown
199+
# For cycle based schedulers (cosine, tanh, poly) recalculate total epochs w/ cycles & cooldown
200+
# NOTE: Warmup prefix added in get_cycle_lengths() if enabled
200201
t_with_cycles_and_cooldown = lr_scheduler.get_cycle_length() + cooldown_t
201202
if step_on_epochs:
202203
num_epochs = t_with_cycles_and_cooldown
203204
else:
204205
num_epochs = t_with_cycles_and_cooldown // updates_per_epoch
206+
else:
207+
if warmup_prefix:
208+
num_epochs += warmup_epochs
205209

206210
return lr_scheduler, num_epochs

timm/scheduler/tanh_lr.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ def _get_lr(self, t: int) -> List[float]:
108108
def get_cycle_length(self, cycles=0):
109109
cycles = max(1, cycles or self.cycle_limit)
110110
if self.cycle_mul == 1.0:
111-
return self.t_initial * cycles
111+
t = self.t_initial * cycles
112112
else:
113-
return int(math.floor(-self.t_initial * (self.cycle_mul ** cycles - 1) / (1 - self.cycle_mul)))
113+
t = int(math.floor(-self.t_initial * (self.cycle_mul ** cycles - 1) / (1 - self.cycle_mul)))
114+
return t + self.warmup_t if self.warmup_prefix else t

0 commit comments

Comments
 (0)