dppo/agent/finetune/train_ppo_gaussian_img_agent.py
Allen Z. Ren dc8e0c9edc
v0.6 (#18)
* Sampling over both env and denoising steps in DPPO updates (#13)

* sample one from each chain

* full random sampling

* Add Proficient Human (PH) Configs and Pipeline (#16)

* fix missing cfg

* add ph config

* fix how terminated flags are added to buffer in ibrl

* add ph config

* offline calql for 1M gradient updates

* bug fix: number of calql online gradient steps is the number of new transitions collected

* add sample config for DPPO with ta=1

* Sampling over both env and denoising steps in DPPO updates (#13)

* sample one from each chain

* full random sampling

* fix diffusion loss when predicting initial noise

* fix dppo inds

* fix typo

* remove print statement

---------

Co-authored-by: Justin M. Lidard <jlidard@neuronic.cs.princeton.edu>
Co-authored-by: allenzren <allen.ren@princeton.edu>

* update robomimic configs

* better calql formulation

* optimize calql and ibrl training

* optimize data transfer in ppo agents

* add kitchen configs

* re-organize config folders, rerun calql and rlpd

* add scratch gym locomotion configs

* add kitchen installation dependencies

* use truncated for termination in furniture env

* update furniture and gym configs

* update README and dependencies with kitchen

* add url for new data and checkpoints

* update demo RL configs

* update batch sizes for furniture unet configs

* raise error about dropout in residual mlp

* fix observation bug in bc loss

---------

Co-authored-by: Justin Lidard <60638575+jlidard@users.noreply.github.com>
Co-authored-by: Justin M. Lidard <jlidard@neuronic.cs.princeton.edu>
2024-10-30 19:58:06 -04:00

439 lines
19 KiB
Python

"""
PPO training for Gaussian/GMM policy with pixel observations.
"""
import os
import pickle
import einops
import numpy as np
import torch
import logging
import wandb
import math
log = logging.getLogger(__name__)
from util.timer import Timer
from agent.finetune.train_ppo_gaussian_agent import TrainPPOGaussianAgent
from model.common.modules import RandomShiftsAug
class TrainPPOImgGaussianAgent(TrainPPOGaussianAgent):
    """PPO fine-tuning agent for a Gaussian/GMM policy with pixel observations.

    Observations are handled as a dict keyed by the entries of
    ``cfg.shape_meta.obs``. The ``"rgb"`` entry can optionally be perturbed
    with ``RandomShiftsAug`` before the PPO update, and optimizer steps are
    taken every ``cfg.train.grad_accumulate`` minibatches to limit GPU memory
    usage.
    """

    def __init__(self, cfg):
        """Read the image-specific training options from ``cfg``.

        Args:
            cfg: Experiment config. Besides whatever the parent class reads,
                this uses ``cfg.train.augment``, ``cfg.train.grad_accumulate``
                and ``cfg.shape_meta.obs[<key>]["shape"]``.
        """
        super().__init__(cfg)

        # Image randomization
        self.augment = cfg.train.augment
        if self.augment:
            self.aug = RandomShiftsAug(pad=4)

        # Set obs dim - we will save the different obs in batch in a dict
        shape_meta = cfg.shape_meta
        self.obs_dims = {k: shape_meta.obs[k]["shape"] for k in shape_meta.obs.keys()}

        # Gradient accumulation to deal with large GPU RAM usage
        self.grad_accumulate = cfg.train.grad_accumulate

    def run(self):
        """Run the train/eval loop until ``self.itr`` reaches ``self.n_train_itr``.

        Each iteration:
          1. decides whether it is an evaluation or a training iteration,
          2. resets the environments when required,
          3. rolls out ``self.n_steps`` action chunks in every environment,
          4. summarizes the episodes that finished inside the iteration,
          5. (training only) computes values, log-probabilities and GAE
             advantages, then runs ``self.update_epochs`` epochs of PPO
             minibatch updates with gradient accumulation and optional
             early stopping on ``self.target_kl``,
          6. steps the LR schedulers, saves the model, and logs metrics to
             the logger, wandb (if enabled) and ``self.result_path``.

        Returns:
            None. Results are persisted through ``self.save_model()`` and the
            pickle file at ``self.result_path``.
        """
        # Start training loop
        timer = Timer()
        run_results = []
        cnt_train_step = 0
        last_itr_eval = False
        done_venv = np.zeros((1, self.n_envs))
        # No observation exists before the first reset; ``None`` forces a
        # reset on the first iteration even if no other reset condition holds.
        prev_obs_venv = None
        while self.itr < self.n_train_itr:
            # Prepare video paths for each envs --- only applies for the first
            # set of episodes if allowing reset within iteration and each
            # iteration has multiple episodes from one env
            options_venv = [{} for _ in range(self.n_envs)]
            if self.itr % self.render_freq == 0 and self.render_video:
                for env_ind in range(self.n_render):
                    options_venv[env_ind]["video_path"] = os.path.join(
                        self.render_dir, f"itr-{self.itr}_trial-{env_ind}.mp4"
                    )

            # Define train or eval - all envs restart
            eval_mode = self.itr % self.val_freq == 0 and not self.force_train
            if eval_mode:
                self.model.eval()
            else:
                self.model.train()

            # Reset env before iteration starts (1) if specified, (2) at eval
            # mode, (3) right after eval mode, or (4) when nothing has been
            # observed yet. ``last_itr_eval`` still refers to the PREVIOUS
            # iteration here; it is only refreshed after this decision.
            firsts_trajs = np.zeros((self.n_steps + 1, self.n_envs))
            if (
                self.reset_at_iteration
                or eval_mode
                or last_itr_eval
                or prev_obs_venv is None
            ):
                prev_obs_venv = self.reset_env_all(options_venv=options_venv)
                firsts_trajs[0] = 1
            else:
                # if done at the end of last iteration, the envs are just reset
                firsts_trajs[0] = done_venv
            last_itr_eval = eval_mode

            # Holder
            obs_trajs = {
                k: np.zeros(
                    (self.n_steps, self.n_envs, self.n_cond_step, *self.obs_dims[k])
                )
                for k in self.obs_dims
            }
            samples_trajs = np.zeros(
                (
                    self.n_steps,
                    self.n_envs,
                    self.horizon_steps,
                    self.action_dim,
                )
            )
            terminated_trajs = np.zeros((self.n_steps, self.n_envs))
            reward_trajs = np.zeros((self.n_steps, self.n_envs))

            # Collect a set of trajectories from env
            for step in range(self.n_steps):
                if step % 10 == 0:
                    print(f"Processed step {step} of {self.n_steps}")

                # Select action
                with torch.no_grad():
                    cond = {
                        key: torch.from_numpy(prev_obs_venv[key])
                        .float()
                        .to(self.device)
                        for key in self.obs_dims
                    }
                    samples = self.model(
                        cond=cond,
                        deterministic=eval_mode,
                    )
                    output_venv = samples.cpu().numpy()
                # Only the first ``act_steps`` of the predicted chunk are executed
                action_venv = output_venv[:, : self.act_steps]

                # Apply multi-step action
                obs_venv, reward_venv, terminated_venv, truncated_venv, info_venv = (
                    self.venv.step(action_venv)
                )
                done_venv = terminated_venv | truncated_venv
                for k in obs_trajs:
                    obs_trajs[k][step] = prev_obs_venv[k]
                samples_trajs[step] = output_venv
                reward_trajs[step] = reward_venv
                terminated_trajs[step] = terminated_venv
                firsts_trajs[step + 1] = done_venv

                # update for next step
                prev_obs_venv = obs_venv

                # count steps --- not accounting for done within action chunk
                cnt_train_step += self.n_envs * self.act_steps if not eval_mode else 0

            # Summarize episode reward --- this needs to be handled differently
            # depending on whether the environment is reset after each
            # iteration. Only count episodes that finish within the iteration.
            episodes_start_end = []
            for env_ind in range(self.n_envs):
                env_steps = np.where(firsts_trajs[:, env_ind] == 1)[0]
                for i in range(len(env_steps) - 1):
                    start = env_steps[i]
                    end = env_steps[i + 1]
                    if end - start > 1:
                        episodes_start_end.append((env_ind, start, end - 1))
            if len(episodes_start_end) > 0:
                reward_trajs_split = [
                    reward_trajs[start : end + 1, env_ind]
                    for env_ind, start, end in episodes_start_end
                ]
                num_episode_finished = len(reward_trajs_split)
                episode_reward = np.array(
                    [np.sum(reward_traj) for reward_traj in reward_trajs_split]
                )
                # NOTE(review): dividing by act_steps presumably undoes a sum
                # of rewards over the action chunk — confirm against the env
                # wrapper.
                episode_best_reward = np.array(
                    [
                        np.max(reward_traj) / self.act_steps
                        for reward_traj in reward_trajs_split
                    ]
                )
                avg_episode_reward = np.mean(episode_reward)
                avg_best_reward = np.mean(episode_best_reward)
                success_rate = np.mean(
                    episode_best_reward >= self.best_reward_threshold_for_success
                )
            else:
                episode_reward = np.array([])
                num_episode_finished = 0
                avg_episode_reward = 0
                avg_best_reward = 0
                success_rate = 0
                log.info("[WARNING] No episode completed within the iteration!")

            # Update models
            if not eval_mode:
                with torch.no_grad():
                    # Move every observation modality to the device once, so
                    # that all keys in ``obs_dims`` can be split and indexed
                    # as tensors below.
                    obs_trajs = {
                        k: torch.from_numpy(v).float().to(self.device)
                        for k, v in obs_trajs.items()
                    }

                    # apply image randomization
                    if self.augment:
                        rgb = einops.rearrange(
                            obs_trajs["rgb"],
                            "s e t c h w -> (s e t) c h w",
                        )
                        rgb = self.aug(rgb)
                        obs_trajs["rgb"] = einops.rearrange(
                            rgb,
                            "(s e t) c h w -> s e t c h w",
                            s=self.n_steps,
                            e=self.n_envs,
                        )

                    # Calculate value and logprobs - split into batches to
                    # prevent out of memory
                    num_split = math.ceil(
                        self.n_envs * self.n_steps / self.logprob_batch_size
                    )
                    obs_ts = [{} for _ in range(num_split)]
                    for k in obs_trajs:
                        obs_k = einops.rearrange(
                            obs_trajs[k],
                            "s e ... -> (s e) ...",
                        )
                        obs_ts_k = torch.split(obs_k, self.logprob_batch_size, dim=0)
                        for i, obs_t in enumerate(obs_ts_k):
                            obs_ts[i][k] = obs_t

                    # Concatenate the flat per-chunk values and reshape once:
                    # a chunk does not need to hold a whole number of
                    # environment rows. The leading empty float64 array keeps
                    # the result in float64.
                    value_chunks = [
                        self.model.critic(obs, no_augment=True)
                        .cpu()
                        .numpy()
                        .flatten()
                        for obs in obs_ts
                    ]
                    values_trajs = np.concatenate(
                        [np.empty(0), *value_chunks]
                    ).reshape(self.n_steps, self.n_envs)

                    samples_flat = einops.rearrange(
                        torch.from_numpy(samples_trajs).float().to(self.device),
                        "s e h d -> (s e) h d",
                    )
                    samples_ts = torch.split(
                        samples_flat, self.logprob_batch_size, dim=0
                    )
                    logprob_chunks = [
                        self.model.get_logprobs(obs_chunk, samples_chunk)[0]
                        .cpu()
                        .numpy()
                        .reshape(-1)
                        for obs_chunk, samples_chunk in zip(obs_ts, samples_ts)
                    ]
                    logprobs_trajs = np.concatenate([np.empty(0), *logprob_chunks])

                    # normalize reward with running variance if specified
                    if self.reward_scale_running:
                        reward_trajs_transpose = self.running_reward_scaler(
                            reward=reward_trajs.T, first=firsts_trajs[:-1].T
                        )
                        reward_trajs = reward_trajs_transpose.T

                    # bootstrap value with GAE if not terminal - apply reward
                    # scaling with constant if specified
                    obs_venv_ts = {
                        key: torch.from_numpy(obs_venv[key]).float().to(self.device)
                        for key in self.obs_dims
                    }
                    # Value of the observation that follows the last stored
                    # step; only needed for t == n_steps - 1.
                    bootstrap_values = (
                        self.model.critic(obs_venv_ts, no_augment=True)
                        .reshape(1, -1)
                        .cpu()
                        .numpy()
                    )
                    advantages_trajs = np.zeros_like(reward_trajs)
                    lastgaelam = 0
                    for t in reversed(range(self.n_steps)):
                        if t == self.n_steps - 1:
                            nextvalues = bootstrap_values
                        else:
                            nextvalues = values_trajs[t + 1]
                        nonterminal = 1.0 - terminated_trajs[t]
                        # delta = r + gamma*V(st+1) - V(st)
                        delta = (
                            reward_trajs[t] * self.reward_scale_const
                            + self.gamma * nextvalues * nonterminal
                            - values_trajs[t]
                        )
                        # A = delta_t + gamma*lambda*delta_{t+1} + ...
                        advantages_trajs[t] = lastgaelam = (
                            delta
                            + self.gamma * self.gae_lambda * nonterminal * lastgaelam
                        )
                    returns_trajs = advantages_trajs + values_trajs

                # k for environment step
                obs_k = {
                    k: einops.rearrange(
                        obs_trajs[k],
                        "s e ... -> (s e) ...",
                    )
                    for k in obs_trajs
                }
                samples_k = einops.rearrange(
                    torch.tensor(samples_trajs, device=self.device).float(),
                    "s e h d -> (s e) h d",
                )
                returns_k = (
                    torch.tensor(returns_trajs, device=self.device).float().reshape(-1)
                )
                values_k = (
                    torch.tensor(values_trajs, device=self.device).float().reshape(-1)
                )
                advantages_k = (
                    torch.tensor(advantages_trajs, device=self.device)
                    .float()
                    .reshape(-1)
                )
                logprobs_k = torch.tensor(logprobs_trajs, device=self.device).float()

                # Update policy and critic
                total_steps = self.n_steps * self.n_envs
                clipfracs = []
                # Discard gradients left over from an incomplete accumulation
                # group of an earlier iteration, so stale rollouts never
                # contribute to this iteration's first optimizer step.
                self.actor_optimizer.zero_grad()
                self.critic_optimizer.zero_grad()
                for update_epoch in range(self.update_epochs):
                    # for each epoch, go through all data in batches
                    flag_break = False
                    inds_k = torch.randperm(total_steps, device=self.device)
                    num_batch = max(1, total_steps // self.batch_size)  # skip last ones
                    for batch in range(num_batch):
                        start = batch * self.batch_size
                        end = start + self.batch_size
                        inds_b = inds_k[start:end]  # b for batch
                        obs_b = {k: obs_k[k][inds_b] for k in obs_k}
                        samples_b = samples_k[inds_b]
                        returns_b = returns_k[inds_b]
                        values_b = values_k[inds_b]
                        advantages_b = advantages_k[inds_b]
                        logprobs_b = logprobs_k[inds_b]

                        # get loss
                        (
                            pg_loss,
                            entropy_loss,
                            v_loss,
                            clipfrac,
                            approx_kl,
                            ratio,
                            bc_loss,
                            std,
                        ) = self.model.loss(
                            obs_b,
                            samples_b,
                            returns_b,
                            values_b,
                            advantages_b,
                            logprobs_b,
                            use_bc_loss=self.use_bc_loss,
                        )
                        loss = (
                            pg_loss
                            + entropy_loss * self.ent_coef
                            + v_loss * self.vf_coef
                            + bc_loss * self.bc_loss_coeff
                        )
                        clipfracs += [clipfrac]

                        # update policy and critic
                        loss.backward()
                        if (batch + 1) % self.grad_accumulate == 0:
                            # The actor is frozen during critic warm-up; the
                            # critic is stepped on every accumulation boundary.
                            if self.itr >= self.n_critic_warmup_itr:
                                if self.max_grad_norm is not None:
                                    torch.nn.utils.clip_grad_norm_(
                                        self.model.actor_ft.parameters(),
                                        self.max_grad_norm,
                                    )
                                self.actor_optimizer.step()
                            self.critic_optimizer.step()
                            self.actor_optimizer.zero_grad()
                            self.critic_optimizer.zero_grad()
                            log.info(f"run grad update at batch {batch}")
                            log.info(
                                f"approx_kl: {approx_kl}, update_epoch: {update_epoch}, num_batch: {num_batch}"
                            )

                            # Stop gradient update if KL difference reaches target
                            if (
                                self.target_kl is not None
                                and approx_kl > self.target_kl
                                and self.itr >= self.n_critic_warmup_itr
                            ):
                                flag_break = True
                                break
                    if flag_break:
                        break

                # Explained variation of future rewards using value function
                y_pred, y_true = values_k.cpu().numpy(), returns_k.cpu().numpy()
                var_y = np.var(y_true)
                explained_var = (
                    np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y
                )

            # Update lr
            if self.itr >= self.n_critic_warmup_itr:
                self.actor_lr_scheduler.step()
            self.critic_lr_scheduler.step()

            # Save model
            if self.itr % self.save_model_freq == 0 or self.itr == self.n_train_itr - 1:
                self.save_model()

            # Log loss and save metrics
            run_results.append(
                {
                    "itr": self.itr,
                    "step": cnt_train_step,
                }
            )
            if self.itr % self.log_freq == 0:
                time = timer()
                run_results[-1]["time"] = time
                if eval_mode:
                    log.info(
                        f"eval: success rate {success_rate:8.4f} | avg episode reward {avg_episode_reward:8.4f} | avg best reward {avg_best_reward:8.4f}"
                    )
                    if self.use_wandb:
                        wandb.log(
                            {
                                "success rate - eval": success_rate,
                                "avg episode reward - eval": avg_episode_reward,
                                "avg best reward - eval": avg_best_reward,
                                "num episode - eval": num_episode_finished,
                            },
                            step=self.itr,
                            commit=False,
                        )
                    run_results[-1]["eval_success_rate"] = success_rate
                    run_results[-1]["eval_episode_reward"] = avg_episode_reward
                    run_results[-1]["eval_best_reward"] = avg_best_reward
                else:
                    log.info(
                        f"{self.itr}: step {cnt_train_step:8d} | loss {loss:8.4f} | pg loss {pg_loss:8.4f} | value loss {v_loss:8.4f} | bc loss {bc_loss:8.4f} | reward {avg_episode_reward:8.4f} | t:{time:8.4f}"
                    )
                    if self.use_wandb:
                        wandb.log(
                            {
                                "total env step": cnt_train_step,
                                "loss": loss,
                                "pg loss": pg_loss,
                                "value loss": v_loss,
                                "bc loss": bc_loss,
                                "std": std,
                                "approx kl": approx_kl,
                                "ratio": ratio,
                                "clipfrac": np.mean(clipfracs),
                                "explained variance": explained_var,
                                "avg episode reward - train": avg_episode_reward,
                                "num episode - train": num_episode_finished,
                                "actor lr": self.actor_optimizer.param_groups[0]["lr"],
                                "critic lr": self.critic_optimizer.param_groups[0][
                                    "lr"
                                ],
                            },
                            step=self.itr,
                            commit=True,
                        )
                    run_results[-1]["train_episode_reward"] = avg_episode_reward
                with open(self.result_path, "wb") as f:
                    pickle.dump(run_results, f)
            self.itr += 1