""" Parent pre-training agent class. """ import os import random import numpy as np from omegaconf import OmegaConf import torch import hydra import logging import wandb from copy import deepcopy log = logging.getLogger(__name__) from util.scheduler import CosineAnnealingWarmupRestarts DEVICE = "cuda:0" def to_device(x, device=DEVICE): if torch.is_tensor(x): return x.to(device) elif type(x) is dict: return {k: to_device(v, device) for k, v in x.items()} else: print(f"Unrecognized type in `to_device`: {type(x)}") def batch_to_device(batch, device="cuda:0"): vals = [to_device(getattr(batch, field), device) for field in batch._fields] return type(batch)(*vals) class EMA: """ Empirical moving average """ def __init__(self, cfg): super().__init__() self.beta = cfg.decay def update_model_average(self, ma_model, current_model): for current_params, ma_params in zip( current_model.parameters(), ma_model.parameters() ): old_weight, up_weight = ma_params.data, current_params.data ma_params.data = self.update_average(old_weight, up_weight) def update_average(self, old, new): if old is None: return new return old * self.beta + (1 - self.beta) * new class PreTrainAgent: def __init__(self, cfg): super().__init__() self.seed = cfg.get("seed", 42) random.seed(self.seed) np.random.seed(self.seed) torch.manual_seed(self.seed) # Wandb self.use_wandb = cfg.wandb is not None if cfg.wandb is not None: wandb.init( entity=cfg.wandb.entity, project=cfg.wandb.project, name=cfg.wandb.run, config=OmegaConf.to_container(cfg, resolve=True), ) # Build model self.model = hydra.utils.instantiate(cfg.model) self.ema = EMA(cfg.ema) self.ema_model = deepcopy(self.model) # Training params self.n_epochs = cfg.train.n_epochs self.batch_size = cfg.train.batch_size self.update_ema_freq = cfg.train.update_ema_freq self.epoch_start_ema = cfg.train.epoch_start_ema self.val_freq = cfg.train.get("val_freq", 100) # Logging, checkpoints self.logdir = cfg.logdir self.checkpoint_dir = os.path.join(self.logdir, "checkpoint") os.makedirs(self.checkpoint_dir, exist_ok=True) self.log_freq = cfg.train.get("log_freq", 1) self.save_model_freq = cfg.train.save_model_freq # Build dataset self.dataset_train = hydra.utils.instantiate(cfg.train_dataset) self.dataloader_train = torch.utils.data.DataLoader( self.dataset_train, batch_size=self.batch_size, num_workers=4 if self.dataset_train.device == "cpu" else 0, shuffle=True, pin_memory=True if self.dataset_train.device == "cpu" else False, ) self.dataloader_val = None if "train_split" in cfg.train and cfg.train.train_split < 1: val_indices = self.dataset_train.set_train_val_split(cfg.train.train_split) self.dataset_val = deepcopy(self.dataset_train) self.dataset_val.set_indices(val_indices) self.dataloader_val = torch.utils.data.DataLoader( self.dataset_val, batch_size=self.batch_size, num_workers=4 if self.dataset_val.device == "cpu" else 0, shuffle=True, pin_memory=True if self.dataset_val.device == "cpu" else False, ) self.optimizer = torch.optim.AdamW( self.model.parameters(), lr=cfg.train.learning_rate, weight_decay=cfg.train.weight_decay, ) self.lr_scheduler = CosineAnnealingWarmupRestarts( self.optimizer, first_cycle_steps=cfg.train.lr_scheduler.first_cycle_steps, cycle_mult=1.0, max_lr=cfg.train.learning_rate, min_lr=cfg.train.lr_scheduler.min_lr, warmup_steps=cfg.train.lr_scheduler.warmup_steps, gamma=1.0, ) self.reset_parameters() def run(self): raise NotImplementedError def reset_parameters(self): self.ema_model.load_state_dict(self.model.state_dict()) def step_ema(self): if self.epoch < self.epoch_start_ema: self.reset_parameters() return self.ema.update_model_average(self.ema_model, self.model) def save_model(self): """ saves model and ema to disk; """ data = { "epoch": self.epoch, "model": self.model.state_dict(), "ema": self.ema_model.state_dict(), } savepath = os.path.join(self.checkpoint_dir, f"state_{self.epoch}.pt") torch.save(data, savepath) log.info(f"Saved model to {savepath}") def load(self, epoch): """ loads model and ema from disk """ loadpath = os.path.join(self.checkpoint_dir, f"state_{epoch}.pt") data = torch.load(loadpath, weights_only=True) self.epoch = data["epoch"] self.model.load_state_dict(data["model"]) self.ema_model.load_state_dict(data["ema"])