dppo/agent/finetune/train_calql_agent.py
Allen Z. Ren dc8e0c9edc
v0.6 (#18)
* Sampling over both env and denoising steps in DPPO updates (#13)

* sample one from each chain

* full random sampling

* Add Proficient Human (PH) Configs and Pipeline (#16)

* fix missing cfg

* add ph config

* fix how terminated flags are added to buffer in ibrl

* add ph config

* offline calql for 1M gradient updates

* bug fix: number of calql online gradient steps is the number of new transitions collected

* add sample config for DPPO with ta=1

* Sampling over both env and denoising steps in DPPO updates (#13)

* sample one from each chain

* full random sampling

* fix diffusion loss when predicting initial noise

* fix dppo inds

* fix typo

* remove print statement

---------

Co-authored-by: Justin M. Lidard <jlidard@neuronic.cs.princeton.edu>
Co-authored-by: allenzren <allen.ren@princeton.edu>

* update robomimic configs

* better calql formulation

* optimize calql and ibrl training

* optimize data transfer in ppo agents

* add kitchen configs

* re-organize config folders, rerun calql and rlpd

* add scratch gym locomotion configs

* add kitchen installation dependencies

* use truncated for termination in furniture env

* update furniture and gym configs

* update README and dependencies with kitchen

* add url for new data and checkpoints

* update demo RL configs

* update batch sizes for furniture unet configs

* raise error about dropout in residual mlp

* fix observation bug in bc loss

---------

Co-authored-by: Justin Lidard <60638575+jlidard@users.noreply.github.com>
Co-authored-by: Justin M. Lidard <jlidard@neuronic.cs.princeton.edu>
2024-10-30 19:58:06 -04:00

504 lines
21 KiB
Python

"""
Reinforcement Learning with Prior Data (RLPD) agent training script.
Does not support image observations right now.
"""
import os
import pickle
import numpy as np
import torch
import logging
import wandb
import hydra
from collections import deque
log = logging.getLogger(__name__)
from util.timer import Timer
from agent.finetune.train_agent import TrainAgent
from util.scheduler import CosineAnnealingWarmupRestarts
class TrainCalQLAgent(TrainAgent):
    """Trainer for Calibrated Conservative Q-Learning (Cal-QL).

    Runs either offline pretraining on a demonstration dataset or online
    fine-tuning, with a SAC-style entropy temperature. Only a single
    environment is supported, and image observations are not handled.
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        assert self.n_envs == 1, "Cal-QL only supports single env for now"

        # Whether this run is online fine-tuning (True) or offline pretraining.
        self.train_online = cfg.train.train_online

        # Offline demonstration dataset (instantiated from the hydra config).
        self.dataset_offline = hydra.utils.instantiate(cfg.offline_dataset)

        # NOTE: gamma discounts per action chunk (every act_steps env steps),
        # not per single env step.
        self.gamma = cfg.train.gamma

        # Actor optimizer with cosine-annealing warmup schedule.
        self.actor_optimizer = torch.optim.AdamW(
            self.model.network.parameters(),
            lr=cfg.train.actor_lr,
            weight_decay=cfg.train.actor_weight_decay,
        )
        actor_sched_cfg = cfg.train.actor_lr_scheduler
        self.actor_lr_scheduler = CosineAnnealingWarmupRestarts(
            self.actor_optimizer,
            first_cycle_steps=actor_sched_cfg.first_cycle_steps,
            cycle_mult=1.0,
            max_lr=cfg.train.actor_lr,
            min_lr=actor_sched_cfg.min_lr,
            warmup_steps=actor_sched_cfg.warmup_steps,
            gamma=1.0,
        )

        # Critic optimizer with matching schedule shape.
        self.critic_optimizer = torch.optim.AdamW(
            self.model.critic.parameters(),
            lr=cfg.train.critic_lr,
            weight_decay=cfg.train.critic_weight_decay,
        )
        critic_sched_cfg = cfg.train.critic_lr_scheduler
        self.critic_lr_scheduler = CosineAnnealingWarmupRestarts(
            self.critic_optimizer,
            first_cycle_steps=critic_sched_cfg.first_cycle_steps,
            cycle_mult=1.0,
            max_lr=cfg.train.critic_lr,
            min_lr=critic_sched_cfg.min_lr,
            warmup_steps=critic_sched_cfg.warmup_steps,
            gamma=1.0,
        )

        # Polyak averaging rate for the target critic update.
        self.target_ema_rate = cfg.train.target_ema_rate

        # Number of random actions sampled for the Cal-QL conservative term.
        self.n_random_actions = cfg.train.n_random_actions

        # Multiplier applied to env rewards before they enter the buffer.
        self.scale_reward_factor = cfg.train.scale_reward_factor

        # Gradient updates per iteration (offline mode; online mode instead
        # uses one update per newly collected transition).
        self.num_update = cfg.train.num_update

        # Capacity of the FIFO online replay buffer.
        self.buffer_size = cfg.train.buffer_size

        # Online-only configuration.
        if self.train_online:
            # Episodes to collect per training epoch.
            self.n_episode_per_epoch = cfg.train.n_episode_per_epoch

        # Episodes collected per evaluation pass.
        self.n_eval_episode = cfg.train.n_eval_episode

        # Initial exploration steps using uniformly random actions.
        self.n_explore_steps = cfg.train.n_explore_steps

        # Entropy temperature alpha, optimized in log space (SAC-style).
        initial_alpha = cfg.train.init_temperature
        self.log_alpha = torch.tensor(np.log(initial_alpha)).to(self.device)
        self.log_alpha.requires_grad = True
        self.automatic_entropy_tuning = cfg.train.automatic_entropy_tuning
        self.target_entropy = cfg.train.target_entropy
        self.log_alpha_optimizer = torch.optim.Adam(
            [self.log_alpha],
            lr=cfg.train.critic_lr,
        )
    def run(self):
        """Main Cal-QL training loop.

        Offline mode: no env interaction; performs ``num_update`` gradient
        steps per iteration on the offline dataset. Online mode: collects
        ``n_episode_per_epoch`` episodes per iteration, then performs one
        gradient step per new transition, sampling half of each batch from
        the offline dataset and half from the online replay buffer.
        Periodically evaluates, saves checkpoints, and logs to wandb/pickle.
        """
        # make a FIFO replay buffer for obs, action, and reward
        obs_buffer = deque(maxlen=self.buffer_size)
        next_obs_buffer = deque(maxlen=self.buffer_size)
        action_buffer = deque(maxlen=self.buffer_size)
        reward_buffer = deque(maxlen=self.buffer_size)
        reward_to_go_buffer = deque(maxlen=self.buffer_size)
        terminated_buffer = deque(maxlen=self.buffer_size)
        # NOTE(review): on the offline path these arrays are built once from
        # the still-empty deques and never refreshed; they look unused there,
        # since offline updates sample only the *_off arrays below — confirm.
        if not self.train_online:
            obs_array = np.array(obs_buffer)
            next_obs_array = np.array(next_obs_buffer)
            actions_array = np.array(action_buffer)
            rewards_array = np.array(reward_buffer)
            reward_to_go_array = np.array(reward_to_go_buffer)
            terminated_array = np.array(terminated_buffer)

        # load offline dataset into replay buffer
        # batch_size equals the dataset length, so the loader yields exactly
        # one batch and the loop body runs once.
        dataloader_offline = torch.utils.data.DataLoader(
            self.dataset_offline,
            batch_size=len(self.dataset_offline),
            drop_last=False,
        )
        for batch in dataloader_offline:
            actions, states_and_next, rewards, terminated, reward_to_go = batch
            states = states_and_next["state"]
            next_states = states_and_next["next_state"]
            obs_buffer_off = states.cpu().numpy()
            next_obs_buffer_off = next_states.cpu().numpy()
            action_buffer_off = actions.cpu().numpy()
            reward_buffer_off = rewards.cpu().numpy().flatten()
            reward_to_go_buffer_off = reward_to_go.cpu().numpy().flatten()
            terminated_buffer_off = terminated.cpu().numpy().flatten()

        # Start training loop
        timer = Timer()
        run_results = []
        cnt_train_step = 0
        done_venv = np.zeros((1, self.n_envs))
        while self.itr < self.n_train_itr:
            if self.itr % 1000 == 0:
                print(f"Finished training iteration {self.itr} of {self.n_train_itr}")

            # Prepare video paths for each envs --- only applies for the first set of episodes if allowing reset within iteration and each iteration has multiple episodes from one env
            options_venv = [{} for _ in range(self.n_envs)]
            if self.itr % self.render_freq == 0 and self.render_video:
                for env_ind in range(self.n_render):
                    options_venv[env_ind]["video_path"] = os.path.join(
                        self.render_dir, f"itr-{self.itr}_trial-{env_ind}.mp4"
                    )

            # Define train or eval - all envs restart
            eval_mode = (
                self.itr % self.val_freq == 0
                and self.itr >= self.n_explore_steps
                and not self.force_train
            )
            # during eval, we collect a fixed number of episodes, so we set n_steps to a large value
            if eval_mode:
                n_steps = int(1e5)
            elif not self.train_online:
                # offline training: no environment interaction at all
                n_steps = 0
            else:
                n_steps = int(1e5)  # use episodes
            self.model.eval() if eval_mode else self.model.train()

            # Reset env before iteration starts (1) if specified, (2) at eval mode, or (3) at the beginning
            firsts_trajs = np.zeros((n_steps + 1, self.n_envs))
            if self.reset_at_iteration or eval_mode or self.itr == 0:
                prev_obs_venv = self.reset_env_all(options_venv=options_venv)
                firsts_trajs[0] = 1
            else:
                # if done at the end of last iteration, the envs are just reset
                firsts_trajs[0] = done_venv
            reward_trajs = np.zeros((n_steps, self.n_envs))

            # Collect a set of trajectories from env
            cnt_episode = 0
            for step in range(n_steps):
                if step % 100 == 0:
                    print(f"Completed environment step {step}")

                # Select action: uniform-random during warmup, else policy
                if self.itr < self.n_explore_steps:
                    action_venv = self.venv.action_space.sample()
                else:
                    with torch.no_grad():
                        cond = {
                            "state": torch.from_numpy(prev_obs_venv["state"])
                            .float()
                            .to(self.device)
                        }
                        samples = (
                            self.model(
                                cond=cond,
                                deterministic=eval_mode,
                            )
                            .cpu()
                            .numpy()
                        )  # n_env x horizon x act
                    # execute only the first act_steps of the predicted chunk
                    action_venv = samples[:, : self.act_steps]

                # Apply multi-step action
                (
                    obs_venv,
                    reward_venv,
                    terminated_venv,
                    truncated_venv,
                    info_venv,
                ) = self.venv.step(action_venv)
                done_venv = terminated_venv | truncated_venv
                reward_trajs[step] = reward_venv
                firsts_trajs[step + 1] = done_venv

                # add to buffer in train mode
                if not eval_mode:
                    for i in range(self.n_envs):
                        obs_buffer.append(prev_obs_venv["state"][i])
                        if truncated_venv[i]:
                            # env auto-reset: the true successor state lives in info
                            next_obs_buffer.append(info_venv[i]["final_obs"]["state"])
                        else:  # first obs in new episode
                            next_obs_buffer.append(obs_venv["state"][i])
                        action_buffer.append(action_venv[i])
                    reward_buffer.extend(
                        (reward_venv * self.scale_reward_factor).tolist()
                    )
                    terminated_buffer.extend(terminated_venv.tolist())

                # update for next step
                prev_obs_venv = obs_venv

                # count steps --- not accounting for done within action chunk
                cnt_train_step += self.n_envs * self.act_steps if not eval_mode else 0

                # check if enough eval episodes are done
                cnt_episode += np.sum(done_venv)
                if eval_mode and cnt_episode >= self.n_eval_episode:
                    break
                if not eval_mode and cnt_episode >= self.n_episode_per_epoch:
                    break

            # Summarize episode reward --- this needs to be handled differently depending on whether the environment is reset after each iteration. Only count episodes that finish within the iteration.
            episodes_start_end = []
            for env_ind in range(self.n_envs):
                env_steps = np.where(firsts_trajs[:, env_ind] == 1)[0]
                for i in range(len(env_steps) - 1):
                    start = env_steps[i]
                    end = env_steps[i + 1]
                    if end - start > 1:
                        episodes_start_end.append((env_ind, start, end - 1))
            if len(episodes_start_end) > 0:
                reward_trajs_split = [
                    reward_trajs[start : end + 1, env_ind]
                    for env_ind, start, end in episodes_start_end
                ]
                # compute episode returns (discounted reward-to-go, back to front)
                returns_trajs_split = [
                    np.zeros_like(reward_trajs) for reward_trajs in reward_trajs_split
                ]
                for traj_rewards, traj_returns in zip(
                    reward_trajs_split, returns_trajs_split
                ):
                    prev_return = 0
                    for t in range(len(traj_rewards)):
                        traj_returns[-t - 1] = (
                            traj_rewards[-t - 1] + self.gamma * prev_return
                        )
                        prev_return = traj_returns[-t - 1]
                # flatten (note: only works for single env!)
                returns_trajs_split = np.concatenate(returns_trajs_split)

                # extend buffer
                reward_to_go_buffer.extend(returns_trajs_split)

                num_episode_finished = len(reward_trajs_split)
                episode_reward = np.array(
                    [np.sum(reward_traj) for reward_traj in reward_trajs_split]
                )
                episode_best_reward = np.array(
                    [
                        np.max(reward_traj) / self.act_steps
                        for reward_traj in reward_trajs_split
                    ]
                )
                avg_episode_reward = np.mean(episode_reward)
                avg_best_reward = np.mean(episode_best_reward)
                success_rate = np.mean(
                    episode_best_reward >= self.best_reward_threshold_for_success
                )
            else:
                episode_reward = np.array([])
                num_episode_finished = 0
                avg_episode_reward = 0
                avg_best_reward = 0
                success_rate = 0

            # Update models
            if not eval_mode and self.itr >= self.n_explore_steps:
                # TODO: is this slow in online?
                # Snapshot the online deques as arrays for fancy indexing below.
                if self.train_online:
                    obs_array = np.array(obs_buffer)
                    next_obs_array = np.array(next_obs_buffer)
                    actions_array = np.array(action_buffer)
                    rewards_array = np.array(reward_buffer)
                    reward_to_go_array = np.array(reward_to_go_buffer)
                    terminated_array = np.array(terminated_buffer)

                # override num_update
                if self.train_online:
                    # the amount of new transitions(single env)
                    # NOTE(review): assumes at least one episode finished this
                    # iteration; reward_trajs_split is undefined otherwise — verify
                    num_update = len(reward_trajs_split[0])
                else:
                    num_update = self.num_update
                for _ in range(num_update):
                    # Sample from OFFLINE buffer (half batch if also online)
                    inds = np.random.choice(
                        len(obs_buffer_off),
                        self.batch_size // 2 if self.train_online else self.batch_size,
                    )
                    obs_b = (
                        torch.from_numpy(obs_buffer_off[inds]).float().to(self.device)
                    )
                    next_obs_b = (
                        torch.from_numpy(next_obs_buffer_off[inds])
                        .float()
                        .to(self.device)
                    )
                    actions_b = (
                        torch.from_numpy(action_buffer_off[inds])
                        .float()
                        .to(self.device)
                    )
                    rewards_b = (
                        torch.from_numpy(reward_buffer_off[inds])
                        .float()
                        .to(self.device)
                    )
                    terminated_b = (
                        torch.from_numpy(terminated_buffer_off[inds])
                        .float()
                        .to(self.device)
                    )
                    reward_to_go_b = (
                        torch.from_numpy(reward_to_go_buffer_off[inds])
                        .float()
                        .to(self.device)
                    )

                    # Sample from ONLINE buffer (other half of the batch)
                    if self.train_online:
                        inds = np.random.choice(len(obs_buffer), self.batch_size // 2)
                        obs_b_on = (
                            torch.from_numpy(obs_array[inds]).float().to(self.device)
                        )
                        next_obs_b_on = (
                            torch.from_numpy(next_obs_array[inds])
                            .float()
                            .to(self.device)
                        )
                        actions_b_on = (
                            torch.from_numpy(actions_array[inds])
                            .float()
                            .to(self.device)
                        )
                        rewards_b_on = (
                            torch.from_numpy(rewards_array[inds])
                            .float()
                            .to(self.device)
                        )
                        terminated_b_on = (
                            torch.from_numpy(terminated_array[inds])
                            .float()
                            .to(self.device)
                        )
                        reward_to_go_b_on = (
                            torch.from_numpy(reward_to_go_array[inds])
                            .float()
                            .to(self.device)
                        )

                        # merge offline and online data
                        obs_b = torch.cat([obs_b, obs_b_on], dim=0)
                        next_obs_b = torch.cat([next_obs_b, next_obs_b_on], dim=0)
                        actions_b = torch.cat([actions_b, actions_b_on], dim=0)
                        rewards_b = torch.cat([rewards_b, rewards_b_on], dim=0)
                        terminated_b = torch.cat([terminated_b, terminated_b_on], dim=0)
                        reward_to_go_b = torch.cat(
                            [reward_to_go_b, reward_to_go_b_on], dim=0
                        )

                    # Get a random action for Cal-QL conservative regularizer
                    random_actions = (
                        torch.rand(
                            (
                                self.batch_size,
                                self.n_random_actions,
                                self.horizon_steps,
                                self.action_dim,
                            )
                        ).to(self.device)
                        * 2
                        - 1
                    )  # scale to [-1, 1]

                    # Update critic
                    alpha = self.log_alpha.exp().item()
                    loss_critic = self.model.loss_critic(
                        {"state": obs_b},
                        {"state": next_obs_b},
                        actions_b,
                        random_actions,
                        rewards_b,
                        reward_to_go_b,
                        terminated_b,
                        self.gamma,
                    )
                    self.critic_optimizer.zero_grad()
                    loss_critic.backward()
                    self.critic_optimizer.step()

                    # Update target critic (Polyak averaging)
                    self.model.update_target_critic(self.target_ema_rate)

                    # Update actor
                    loss_actor = self.model.loss_actor(
                        {"state": obs_b},
                        alpha,
                    )
                    self.actor_optimizer.zero_grad()
                    loss_actor.backward()
                    self.actor_optimizer.step()

                    # Update temperature parameter
                    if self.automatic_entropy_tuning:
                        self.log_alpha_optimizer.zero_grad()
                        loss_alpha = self.model.loss_temperature(
                            {"state": obs_b},
                            self.log_alpha.exp(),  # with grad
                            self.target_entropy,
                        )
                        loss_alpha.backward()
                        self.log_alpha_optimizer.step()

                # Update lr (once per iteration)
                self.actor_lr_scheduler.step()
                self.critic_lr_scheduler.step()

            # Save model
            if self.itr % self.save_model_freq == 0 or self.itr == self.n_train_itr - 1:
                self.save_model()

            # Log loss and save metrics
            run_results.append(
                {
                    "itr": self.itr,
                    "step": cnt_train_step,
                }
            )
            if self.itr % self.log_freq == 0 and self.itr >= self.n_explore_steps:
                time = timer()
                run_results[-1]["time"] = time
                if eval_mode:
                    log.info(
                        f"eval: success rate {success_rate:8.4f} | avg episode reward {avg_episode_reward:8.4f} | avg best reward {avg_best_reward:8.4f}"
                    )
                    if self.use_wandb:
                        wandb.log(
                            {
                                "success rate - eval": success_rate,
                                "avg episode reward - eval": avg_episode_reward,
                                "avg best reward - eval": avg_best_reward,
                                "num episode - eval": num_episode_finished,
                            },
                            step=self.itr,
                            commit=False,
                        )
                    run_results[-1]["eval_success_rate"] = success_rate
                    run_results[-1]["eval_episode_reward"] = avg_episode_reward
                    run_results[-1]["eval_best_reward"] = avg_best_reward
                else:
                    log.info(
                        f"{self.itr}: step {cnt_train_step:8d} | loss actor {loss_actor:8.4f} | loss critic {loss_critic:8.4f} | reward {avg_episode_reward:8.4f} | alpha {alpha:8.4f} | t:{time:8.4f}"
                    )
                    if self.use_wandb:
                        wandb.log(
                            {
                                "total env step": cnt_train_step,
                                "loss - actor": loss_actor,
                                "loss - critic": loss_critic,
                                "entropy coeff": alpha,
                                "avg episode reward - train": avg_episode_reward,
                                "num episode - train": num_episode_finished,
                            },
                            step=self.itr,
                            commit=True,
                        )
                    run_results[-1]["train_episode_reward"] = avg_episode_reward
                with open(self.result_path, "wb") as f:
                    pickle.dump(run_results, f)
            self.itr += 1