"""
|
|
MIT License
|
|
|
|
Copyright (c) 2022 Naoki Katsura
|
|
|
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
of this software and associated documentation files (the "Software"), to deal
|
|
in the Software without restriction, including without limitation the rights
|
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
copies of the Software, and to permit persons to whom the Software is
|
|
furnished to do so, subject to the following conditions:
|
|
|
|
The above copyright notice and this permission notice shall be included in all
|
|
copies or substantial portions of the Software.
|
|
|
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
SOFTWARE.
|
|
"""
|
|
|
|
# From https://github.com/katsura-jp/pytorch-cosine-annealing-with-warmup
|
|
|
|
import math
|
|
import torch
|
|
from torch.optim.lr_scheduler import _LRScheduler
|
|
|
|
|
|
class CosineAnnealingWarmupRestarts(_LRScheduler):
    """Cosine annealing schedule with linear warmup and warm restarts.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        first_cycle_steps (int): First cycle step size.
        cycle_mult (float): Cycle steps magnification. Default: 1.0.
        max_lr (float): First cycle's max learning rate. Default: 0.1.
        min_lr (float): Min learning rate. Default: 0.001.
        warmup_steps (int): Linear warmup step size. Default: 0.
        gamma (float): Decrease rate of max learning rate by cycle. Default: 1.0.
        last_epoch (int): The index of last epoch. Default: -1.
    """

    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        first_cycle_steps: int,
        cycle_mult: float = 1.0,
        max_lr: float = 0.1,
        min_lr: float = 0.001,
        warmup_steps: int = 0,
        gamma: float = 1.0,
        last_epoch: int = -1,
    ):
        assert (
            warmup_steps < first_cycle_steps
        ), "warmup_steps must be smaller than first_cycle_steps"

        self.first_cycle_steps = first_cycle_steps  # first cycle step size
        self.cycle_mult = cycle_mult  # cycle steps magnification
        self.base_max_lr = max_lr  # first max learning rate
        self.max_lr = max_lr  # max learning rate in the current cycle
        self.min_lr = min_lr  # min learning rate
        self.warmup_steps = warmup_steps  # warmup step size
        self.gamma = gamma  # decrease rate of max learning rate by cycle

        self.cur_cycle_steps = first_cycle_steps  # size of the current cycle
        self.cycle = 0  # cycle count
        self.step_in_cycle = last_epoch  # step index within the current cycle

        super().__init__(optimizer, last_epoch)

        # start every param group at min_lr, the warmup starting point
        self.init_lr()

    def init_lr(self):
        self.base_lrs = []
        for param_group in self.optimizer.param_groups:
            param_group["lr"] = self.min_lr
            self.base_lrs.append(self.min_lr)
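
    # Note: init_lr overwrites the base_lrs that _LRScheduler.__init__
    # records from the optimizer, so every param group anneals between
    # min_lr and max_lr regardless of the lr the optimizer was built with.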

    def get_lr(self):
        if self.step_in_cycle == -1:
            return self.base_lrs
        elif self.step_in_cycle < self.warmup_steps:
            # linear warmup: interpolate from base_lr up to max_lr
            return [
                (self.max_lr - base_lr) * self.step_in_cycle / self.warmup_steps
                + base_lr
                for base_lr in self.base_lrs
            ]
        else:
            # cosine annealing from max_lr back down to base_lr over the
            # remainder of the cycle
            return [
                base_lr
                + (self.max_lr - base_lr)
                * (
                    1
                    + math.cos(
                        math.pi
                        * (self.step_in_cycle - self.warmup_steps)
                        / (self.cur_cycle_steps - self.warmup_steps)
                    )
                )
                / 2
                for base_lr in self.base_lrs
            ]
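
    # Within one cycle this is the standard SGDR cosine schedule
    # (Loshchilov & Hutter, ICLR 2017):
    #   lr(t) = base_lr + (max_lr - base_lr) * (1 + cos(pi * t / T)) / 2
    # with t = step_in_cycle - warmup_steps and
    # T = cur_cycle_steps - warmup_steps.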

    def step(self, epoch=None):
        if epoch is None:
            # implicit stepping: advance one step within the current cycle
            epoch = self.last_epoch + 1
            self.step_in_cycle = self.step_in_cycle + 1
            if self.step_in_cycle >= self.cur_cycle_steps:
                # restart: begin a new cycle whose annealing span is
                # cycle_mult times longer than the previous one
                self.cycle += 1
                self.step_in_cycle = self.step_in_cycle - self.cur_cycle_steps
                self.cur_cycle_steps = (
                    int((self.cur_cycle_steps - self.warmup_steps) * self.cycle_mult)
                    + self.warmup_steps
                )
        else:
            # explicit stepping: recover the cycle index and in-cycle
            # position from the absolute epoch
            if epoch >= self.first_cycle_steps:
                if self.cycle_mult == 1.0:
                    self.step_in_cycle = epoch % self.first_cycle_steps
                    self.cycle = epoch // self.first_cycle_steps
                else:
                    # cycle lengths form a geometric series: the first n
                    # cycles span first_cycle_steps * (cycle_mult**n - 1)
                    # / (cycle_mult - 1) steps, so solving for n gives the
                    # log below
                    n = int(
                        math.log(
                            (
                                epoch / self.first_cycle_steps * (self.cycle_mult - 1)
                                + 1
                            ),
                            self.cycle_mult,
                        )
                    )
                    self.cycle = n
                    self.step_in_cycle = epoch - int(
                        self.first_cycle_steps
                        * (self.cycle_mult**n - 1)
                        / (self.cycle_mult - 1)
                    )
                    self.cur_cycle_steps = self.first_cycle_steps * self.cycle_mult**n
            else:
                self.cur_cycle_steps = self.first_cycle_steps
                self.step_in_cycle = epoch

        # shrink the peak learning rate by gamma at every restart
        self.max_lr = self.base_max_lr * (self.gamma**self.cycle)
        self.last_epoch = math.floor(epoch)
        for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
            param_group["lr"] = lr
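

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file). The toy model and the
# hyperparameter values below are placeholders chosen only to illustrate
# the API; call scheduler.step() once per optimizer step.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model = torch.nn.Linear(10, 2)  # stand-in for a real network
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scheduler = CosineAnnealingWarmupRestarts(
        optimizer,
        first_cycle_steps=100,  # length of the first cycle, in steps
        cycle_mult=2.0,  # each cycle twice as long as the previous one
        max_lr=0.1,
        min_lr=0.001,
        warmup_steps=10,  # linear warmup at the start of every cycle
        gamma=0.5,  # halve the peak lr at every restart
    )
    for _ in range(300):
        optimizer.step()  # a real run would compute gradients first
        scheduler.step()  # advance the schedule by one step
    print("lr after 300 steps:", optimizer.param_groups[0]["lr"])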