File tree Expand file tree Collapse file tree 4 files changed +14
-7
lines changed Expand file tree Collapse file tree 4 files changed +14
-7
lines changed Original file line number Diff line number Diff line change @@ -111,6 +111,7 @@ def _get_lr(self, t: int) -> List[float]:
111
111
def get_cycle_length (self , cycles = 0 ):
112
112
cycles = max (1 , cycles or self .cycle_limit )
113
113
if self .cycle_mul == 1.0 :
114
- return self .t_initial * cycles
114
+ t = self .t_initial * cycles
115
115
else :
116
- return int (math .floor (- self .t_initial * (self .cycle_mul ** cycles - 1 ) / (1 - self .cycle_mul )))
116
+ t = int (math .floor (- self .t_initial * (self .cycle_mul ** cycles - 1 ) / (1 - self .cycle_mul )))
117
+ return t + self .warmup_t if self .warmup_prefix else t
Original file line number Diff line number Diff line change @@ -107,6 +107,7 @@ def _get_lr(self, t: int) -> List[float]:
107
107
def get_cycle_length (self , cycles = 0 ):
108
108
cycles = max (1 , cycles or self .cycle_limit )
109
109
if self .cycle_mul == 1.0 :
110
- return self .t_initial * cycles
110
+ t = self .t_initial * cycles
111
111
else :
112
- return int (math .floor (- self .t_initial * (self .cycle_mul ** cycles - 1 ) / (1 - self .cycle_mul )))
112
+ t = int (math .floor (- self .t_initial * (self .cycle_mul ** cycles - 1 ) / (1 - self .cycle_mul )))
113
+ return t + self .warmup_t if self .warmup_prefix else t
Original file line number Diff line number Diff line change @@ -196,11 +196,15 @@ def create_scheduler_v2(
196
196
)
197
197
198
198
if hasattr (lr_scheduler , 'get_cycle_length' ):
199
- # for cycle based schedulers (cosine, tanh, poly) recalculate total epochs w/ cycles & cooldown
199
+ # For cycle based schedulers (cosine, tanh, poly) recalculate total epochs w/ cycles & cooldown
200
+ # NOTE: Warmup prefix added in get_cycle_lengths() if enabled
200
201
t_with_cycles_and_cooldown = lr_scheduler .get_cycle_length () + cooldown_t
201
202
if step_on_epochs :
202
203
num_epochs = t_with_cycles_and_cooldown
203
204
else :
204
205
num_epochs = t_with_cycles_and_cooldown // updates_per_epoch
206
+ else :
207
+ if warmup_prefix :
208
+ num_epochs += warmup_epochs
205
209
206
210
return lr_scheduler , num_epochs
Original file line number Diff line number Diff line change @@ -108,6 +108,7 @@ def _get_lr(self, t: int) -> List[float]:
108
108
def get_cycle_length (self , cycles = 0 ):
109
109
cycles = max (1 , cycles or self .cycle_limit )
110
110
if self .cycle_mul == 1.0 :
111
- return self .t_initial * cycles
111
+ t = self .t_initial * cycles
112
112
else :
113
- return int (math .floor (- self .t_initial * (self .cycle_mul ** cycles - 1 ) / (1 - self .cycle_mul )))
113
+ t = int (math .floor (- self .t_initial * (self .cycle_mul ** cycles - 1 ) / (1 - self .cycle_mul )))
114
+ return t + self .warmup_t if self .warmup_prefix else t
You can’t perform that action at this time.
0 commit comments