# Python source: 30 lines, 1.2 KiB.
import math
|
|
from torch.optim.lr_scheduler import LambdaLR
|
|
|
|
class LinearWarmUpCosineAnnealingLR(LambdaLR):
    """Linear warm-up followed by cosine annealing, implemented on ``LambdaLR``.

    For an effective step ``s`` (the scheduler step plus ``offset``, optionally
    wrapped modulo ``epoch_size``, and clamped to be non-negative):

    * ``s < warm_up_steps``: linear ramp from ``init`` toward ``peak``.
    * ``warm_up_steps <= s < max_steps``: half-cosine decay from ``peak`` to
      ``final``.
    * ``s >= max_steps``: constant ``final``.

    ``LambdaLR`` multiplies the value returned by the lambda with each
    parameter group's initial learning rate, so ``init``/``peak``/``final``
    are multipliers of the optimizer's base lr. They are absolute learning
    rates only when the optimizer was built with ``lr=1.0``.

    Args:
        optimizer: Wrapped optimizer.
        peak: Multiplier reached at the end of warm-up (start of the decay).
        final: Multiplier at and after ``max_steps``.
        warm_up_steps: Length of the linear warm-up phase; ``0`` disables it.
        max_steps: Step at which the cosine decay reaches ``final``.
        init: Multiplier at the first warm-up step.
        offset: Added to the scheduler's step counter before evaluation.
        epoch_size: If positive, the shifted step is taken modulo this value,
            restarting the schedule every ``epoch_size`` steps.
        **kwargs: Forwarded to ``LambdaLR`` (e.g. ``last_epoch``).

    Raises:
        ValueError: If ``peak >= final >= init >= 0`` or
            ``max_steps >= warm_up_steps`` does not hold.
    """

    def __init__(self, optimizer, *, peak, final, warm_up_steps, max_steps, init=1e-8, offset=0, epoch_size=0, **kwargs):
        # Explicit checks rather than ``assert`` so validation survives ``python -O``.
        if not peak >= final >= init >= 0:
            raise ValueError(
                f"expected peak >= final >= init >= 0, got "
                f"peak={peak}, final={final}, init={init}"
            )
        if not max_steps >= warm_up_steps:
            raise ValueError(
                f"expected max_steps >= warm_up_steps, got "
                f"max_steps={max_steps}, warm_up_steps={warm_up_steps}"
            )

        # These must be set before ``super().__init__``: the base scheduler
        # performs an initial step during construction, which calls the lambda.
        self.init = init
        self.peak = peak
        self.final = final
        self.warm_up_steps = warm_up_steps
        self.max_steps = max_steps
        self.offset = offset
        self.epoch_size = epoch_size

        kwargs['optimizer'] = optimizer
        kwargs['lr_lambda'] = self._step_inner
        super().__init__(**kwargs)

    def _step_inner(self, steps):
        """Return the lr multiplier for scheduler step ``steps``."""
        steps += self.offset
        if self.epoch_size > 0:
            # Restart the schedule every ``epoch_size`` steps.
            steps %= self.epoch_size

        # A negative ``offset`` can push the effective step below zero. Clamp it
        # so the warm-up ramp never extrapolates below ``init`` (possibly to a
        # negative lr), and the cosine branch is never entered with
        # ``max_steps - warm_up_steps == 0`` (division by zero).
        steps = max(steps, 0)

        if self.warm_up_steps > 0 and steps < self.warm_up_steps:
            return self.init + (self.peak - self.init) / self.warm_up_steps * steps

        if steps < self.max_steps:
            cos_steps = steps - self.warm_up_steps
            cos_max_steps = self.max_steps - self.warm_up_steps
            # Half cosine: factor goes 1 -> 0 as cos_steps goes 0 -> cos_max_steps.
            return self.final + 0.5 * (self.peak - self.final) * (1 + math.cos(cos_steps / cos_max_steps * math.pi))

        return self.final
|