"""
|
|
Parent pre-training agent class.
|
|
|
|
"""
|
|
|
|
import os
|
|
import random
|
|
import numpy as np
|
|
from omegaconf import OmegaConf
|
|
import torch
|
|
import hydra
|
|
import logging
|
|
import wandb
|
|
from copy import deepcopy
|
|
|
|
log = logging.getLogger(__name__)
|
|
from util.scheduler import CosineAnnealingWarmupRestarts
|
|
|
|
DEVICE = "cuda:0"


def to_device(x, device=DEVICE):
    """Recursively move a tensor, or a dict of tensors, to the given device."""
    if torch.is_tensor(x):
        return x.to(device)
    elif type(x) is dict:
        return {k: to_device(v, device) for k, v in x.items()}
    else:
        # Unsupported types are only reported; the function implicitly returns None
        print(f"Unrecognized type in `to_device`: {type(x)}")


def batch_to_device(batch, device="cuda:0"):
    """Move every field of a namedtuple-style batch to the given device."""
    vals = [to_device(getattr(batch, field), device) for field in batch._fields]
    return type(batch)(*vals)
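
# Illustrative only: `batch_to_device` assumes the batch is a namedtuple
# (it reads `_fields`); `Batch`, `actions`, and `conditions` below are
# made-up names, not types defined in this repo.
#
#   from collections import namedtuple
#   Batch = namedtuple("Batch", ["actions", "conditions"])
#   batch = Batch(actions=torch.zeros(8, 4), conditions={"state": torch.zeros(8, 11)})
#   batch = batch_to_device(batch)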


class EMA:
    """
    Exponential moving average of model parameters.
    """

    def __init__(self, cfg):
        super().__init__()
        self.beta = cfg.decay

    def update_model_average(self, ma_model, current_model):
        # Blend each EMA parameter toward the corresponding model parameter
        for current_params, ma_params in zip(
            current_model.parameters(), ma_model.parameters()
        ):
            old_weight, up_weight = ma_params.data, current_params.data
            ma_params.data = self.update_average(old_weight, up_weight)

    def update_average(self, old, new):
        # new_ema = beta * old_ema + (1 - beta) * new
        if old is None:
            return new
        return old * self.beta + (1 - self.beta) * new


class PreTrainAgent:
    """Parent class for pre-training agents; subclasses implement `run()`."""

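    # Rough sketch (not authoritative) of the config fields this class reads,
    # inferred from the attribute accesses in __init__; values are omitted:
    #
    #   seed, logdir, wandb.{entity, project, run}
    #   model._target_ (instantiated via hydra), ema.decay, train_dataset._target_
    #   train.{n_epochs, batch_size, learning_rate, weight_decay,
    #          epoch_start_ema, update_ema_freq, save_model_freq,
    #          val_freq, log_freq, train_split,
    #          lr_scheduler.{first_cycle_steps, warmup_steps, min_lr}}
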
    def __init__(self, cfg):
        super().__init__()

        # Seed python, numpy, and torch RNGs for reproducibility
        self.seed = cfg.get("seed", 42)
        random.seed(self.seed)
        np.random.seed(self.seed)
        torch.manual_seed(self.seed)

        # Wandb
        self.use_wandb = cfg.wandb is not None
        if self.use_wandb:
            wandb.init(
                entity=cfg.wandb.entity,
                project=cfg.wandb.project,
                name=cfg.wandb.run,
                config=OmegaConf.to_container(cfg, resolve=True),
            )

        # Build model and its EMA copy (updated in step_ema, saved alongside the model)
        self.model = hydra.utils.instantiate(cfg.model)
        self.ema = EMA(cfg.ema)
        self.ema_model = deepcopy(self.model)

        # Training params
        self.n_epochs = cfg.train.n_epochs
        self.batch_size = cfg.train.batch_size
        self.update_ema_freq = cfg.train.update_ema_freq
        self.epoch_start_ema = cfg.train.epoch_start_ema
        self.val_freq = cfg.train.get("val_freq", 100)

        # Logging, checkpoints
        self.logdir = cfg.logdir
        self.checkpoint_dir = os.path.join(self.logdir, "checkpoint")
        os.makedirs(self.checkpoint_dir, exist_ok=True)
        self.log_freq = cfg.train.get("log_freq", 1)
        self.save_model_freq = cfg.train.save_model_freq

        # Build dataset; only use worker processes and pinned memory when the
        # dataset itself lives on CPU
        self.dataset_train = hydra.utils.instantiate(cfg.train_dataset)
        self.dataloader_train = torch.utils.data.DataLoader(
            self.dataset_train,
            batch_size=self.batch_size,
            num_workers=4 if self.dataset_train.device == "cpu" else 0,
            shuffle=True,
            pin_memory=True if self.dataset_train.device == "cpu" else False,
        )

        # Optional validation split carved out of the training dataset
        self.dataloader_val = None
        if "train_split" in cfg.train and cfg.train.train_split < 1:
            val_indices = self.dataset_train.set_train_val_split(cfg.train.train_split)
            self.dataset_val = deepcopy(self.dataset_train)
            self.dataset_val.set_indices(val_indices)
            self.dataloader_val = torch.utils.data.DataLoader(
                self.dataset_val,
                batch_size=self.batch_size,
                num_workers=4 if self.dataset_val.device == "cpu" else 0,
                shuffle=True,
                pin_memory=True if self.dataset_val.device == "cpu" else False,
            )

        # Optimizer and cosine LR schedule with warmup
        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=cfg.train.learning_rate,
            weight_decay=cfg.train.weight_decay,
        )
        self.lr_scheduler = CosineAnnealingWarmupRestarts(
            self.optimizer,
            first_cycle_steps=cfg.train.lr_scheduler.first_cycle_steps,
            cycle_mult=1.0,
            max_lr=cfg.train.learning_rate,
            min_lr=cfg.train.lr_scheduler.min_lr,
            warmup_steps=cfg.train.lr_scheduler.warmup_steps,
            gamma=1.0,
        )

        # Start the EMA model from the same weights as the model
        self.reset_parameters()

    def run(self):
        raise NotImplementedError
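
    # Assumed, simplified sketch of what a subclass's `run()` might look like,
    # using the pieces built in __init__; the model's loss interface is an
    # assumption, not something defined in this file:
    #
    #   for self.epoch in range(self.n_epochs):
    #       for batch_index, batch in enumerate(self.dataloader_train):
    #           batch = batch_to_device(batch)
    #           loss = self.model.loss(*batch)
    #           loss.backward()
    #           self.optimizer.step()
    #           self.optimizer.zero_grad()
    #           if batch_index % self.update_ema_freq == 0:
    #               self.step_ema()
    #       self.lr_scheduler.step()
    #       if self.epoch % self.save_model_freq == 0:
    #           self.save_model()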

    def reset_parameters(self):
        # Copy the current model weights into the EMA model
        self.ema_model.load_state_dict(self.model.state_dict())

    def step_ema(self):
        # Keep the EMA model synced to the model until `epoch_start_ema`,
        # then start the exponential moving average updates
        if self.epoch < self.epoch_start_ema:
            self.reset_parameters()
            return
        self.ema.update_model_average(self.ema_model, self.model)

    def save_model(self):
        """
        Saves the model and the EMA model to disk.
        """
        data = {
            "epoch": self.epoch,
            "model": self.model.state_dict(),
            "ema": self.ema_model.state_dict(),
        }
        savepath = os.path.join(self.checkpoint_dir, f"state_{self.epoch}.pt")
        torch.save(data, savepath)
        log.info(f"Saved model to {savepath}")

    def load(self, epoch):
        """
        Loads the model and the EMA model from disk.
        """
        loadpath = os.path.join(self.checkpoint_dir, f"state_{epoch}.pt")
        data = torch.load(loadpath, weights_only=True)

        self.epoch = data["epoch"]
        self.model.load_state_dict(data["model"])
        self.ema_model.load_state_dict(data["ema"])