# 171 lines · 6.2 KiB · Python
"""
Parent fine-tuning agent class.
"""
|
|
|
|
import os
|
|
import numpy as np
|
|
from omegaconf import OmegaConf
|
|
import torch
|
|
import hydra
|
|
import logging
|
|
import wandb
|
|
import random
|
|
|
|
# Module-level logger, named after this module.
log = logging.getLogger(__name__)
|
|
from env.gym_utils import make_async
|
|
|
|
|
|
class TrainAgent:
    """Parent class for fine-tuning agents.

    The constructor seeds the RNGs, optionally starts a wandb run, builds the
    vectorized environment and the model from ``cfg``, and stores training,
    logging, rendering and checkpoint settings as attributes. ``run`` is a
    no-op here and is presumably overridden by concrete agents.
    """

    def __init__(self, cfg):
        """Set up the agent from the experiment config.

        Args:
            cfg: experiment config; presumably an OmegaConf ``DictConfig``
                (it is passed to ``OmegaConf.to_container`` and queried with
                ``.get`` / ``in``). Must provide at least ``device``,
                ``wandb``, ``env.*``, ``obs_dim``, ``action_dim``,
                ``cond_steps``, ``act_steps``, ``horizon_steps``, ``model``,
                ``train.*`` and ``logdir``.

        Raises:
            AssertionError: if ``train.render.num`` exceeds ``env.n_envs``, or
                if video saving is enabled while ``train.render.num <= 0``.
        """
        super().__init__()
        self.cfg = cfg
        self.device = cfg.device

        # Seed Python, NumPy and torch RNGs for reproducibility.
        self.seed = cfg.get("seed", 42)
        random.seed(self.seed)
        np.random.seed(self.seed)
        torch.manual_seed(self.seed)

        # Wandb
        self.use_wandb = cfg.wandb is not None
        if self.use_wandb:
            wandb.init(
                entity=cfg.wandb.entity,
                project=cfg.wandb.project,
                name=cfg.wandb.run,
                config=OmegaConf.to_container(cfg, resolve=True),
            )

        # Make vectorized env
        self.env_name = cfg.env.name
        env_type = cfg.env.get("env_type", None)
        self.venv = make_async(
            cfg.env.name,
            env_type=env_type,
            num_envs=cfg.env.n_envs,
            asynchronous=True,
            max_episode_steps=cfg.env.max_episode_steps,
            wrappers=cfg.env.get("wrappers", None),
            robomimic_env_cfg_path=cfg.get("robomimic_env_cfg_path", None),
            shape_meta=cfg.get("shape_meta", None),
            use_image_obs=cfg.env.get("use_image_obs", False),
            render=cfg.env.get("render", False),
            render_offscreen=cfg.env.get("save_video", False),
            obs_dim=cfg.obs_dim,
            action_dim=cfg.action_dim,
            # Extra env-specific kwargs, when the config provides them.
            **(cfg.env.specific if "specific" in cfg.env else {}),
        )
        if env_type != "furniture":
            # Distinct seed per env; otherwise parallel envs might have the
            # same initial states! (isaacgym environments do not need seeding)
            self.venv.seed([self.seed + i for i in range(cfg.env.n_envs)])
        self.n_envs = cfg.env.n_envs
        self.n_cond_step = cfg.cond_steps
        self.obs_dim = cfg.obs_dim
        self.action_dim = cfg.action_dim
        self.act_steps = cfg.act_steps
        self.horizon_steps = cfg.horizon_steps
        self.max_episode_steps = cfg.env.max_episode_steps
        self.reset_at_iteration = cfg.env.get("reset_at_iteration", True)
        self.save_full_observations = cfg.env.get("save_full_observations", False)
        # furniture specific, for best reward calculation
        self.furniture_sparse_reward = (
            cfg.env.specific.get("sparse_reward", False)
            if "specific" in cfg.env
            else False
        )

        # Batch size for gradient update
        self.batch_size: int = cfg.train.batch_size

        # Build model (and load checkpoint, presumably inside the model's
        # constructor as configured by cfg.model)
        self.model = hydra.utils.instantiate(cfg.model)

        # Training params
        self.itr = 0
        self.n_train_itr = cfg.train.n_train_itr
        self.val_freq = cfg.train.val_freq
        self.force_train = cfg.train.get("force_train", False)
        self.n_steps = cfg.train.n_steps
        # Furniture envs take the success threshold from the number of part
        # pairs to assemble; other envs read it from the config.
        self.best_reward_threshold_for_success = (
            len(self.venv.pairs_to_assemble)
            if env_type == "furniture"
            else cfg.env.best_reward_threshold_for_success
        )
        self.max_grad_norm = cfg.train.get("max_grad_norm", None)

        # Logging, rendering, checkpoints
        self.logdir = cfg.logdir
        self.render_dir = os.path.join(self.logdir, "render")
        self.checkpoint_dir = os.path.join(self.logdir, "checkpoint")
        self.result_path = os.path.join(self.logdir, "result.pkl")
        os.makedirs(self.render_dir, exist_ok=True)
        os.makedirs(self.checkpoint_dir, exist_ok=True)
        self.save_trajs = cfg.train.get("save_trajs", False)
        self.log_freq = cfg.train.get("log_freq", 1)
        self.save_model_freq = cfg.train.save_model_freq
        self.render_freq = cfg.train.render.freq
        self.n_render = cfg.train.render.num
        self.render_video = cfg.env.get("save_video", False)
        assert self.n_render <= self.n_envs, "n_render must be <= n_envs"
        assert not (
            self.n_render <= 0 and self.render_video
        ), "Need to set n_render > 0 if saving video"
        self.traj_plotter = (
            hydra.utils.instantiate(cfg.train.plotter)
            if "plotter" in cfg.train
            else None
        )

    def run(self):
        """Run training. Does nothing in this base class."""
        pass

    def save_model(self):
        """Save the iteration counter and model weights (no EMA).

        Writes ``{"itr": ..., "model": state_dict}`` to
        ``<checkpoint_dir>/state_<itr>.pt``.
        """
        # Right now `model` includes weights for `network`, `actor`,
        # `actor_ft`. Weights for `network` are redundant; `actor` weights can
        # serve as the base policy (earlier denoising steps) and `actor_ft`
        # weights as the fine-tuned policy (later denoising steps) during
        # evaluation.
        data = {
            "itr": self.itr,
            "model": self.model.state_dict(),
        }
        savepath = os.path.join(self.checkpoint_dir, f"state_{self.itr}.pt")
        torch.save(data, savepath)
        log.info(f"Saved model to {savepath}")

    def load(self, itr):
        """Restore the iteration counter and model weights from disk.

        Args:
            itr: iteration number; reads ``<checkpoint_dir>/state_<itr>.pt``,
                the layout written by ``save_model``.
        """
        loadpath = os.path.join(self.checkpoint_dir, f"state_{itr}.pt")
        # weights_only=True restricts unpickling to tensors / plain containers.
        data = torch.load(loadpath, weights_only=True)

        self.itr = data["itr"]
        self.model.load_state_dict(data["model"])

    def reset_env_all(self, verbose=False, options_venv=None, **kwargs):
        """Reset every parallel env and return their observations.

        Args:
            verbose: if True, log the reset options used for each env.
            options_venv: per-env reset options, indexed by env. When None,
                each env receives its own copy of ``kwargs``.
            **kwargs: reset options shared by all envs; used only when
                ``options_venv`` is None.

        Returns:
            The result of ``venv.reset_arg``. If that is a list of per-env
            observation dicts, it is converted into a single dict mapping each
            key to an array stacked along a new leading (env) axis.
        """
        if options_venv is None:
            # A separate dict per env, so no two envs share one options object.
            options_venv = [dict(kwargs) for _ in range(self.n_envs)]
        obs_venv = self.venv.reset_arg(options_list=options_venv)
        # convert to a dict of stacked arrays if obs_venv is a list of dicts
        if isinstance(obs_venv, list):
            obs_venv = {
                key: np.stack([obs_venv[i][key] for i in range(self.n_envs)])
                for key in obs_venv[0]
            }
        if verbose:
            for index in range(self.n_envs):
                log.info(
                    f"<-- Reset environment {index} with options {options_venv[index]}"
                )
        return obs_venv

    def reset_env(self, env_ind, verbose=False):
        """Reset a single env with empty options and return its observation.

        Args:
            env_ind: index of the env to reset.
            verbose: if True, log the reset.
        """
        task = {}
        obs = self.venv.reset_one_arg(env_ind=env_ind, options=task)
        if verbose:
            log.info(f"<-- Reset environment {env_ind} with task {task}")
        return obs
|