dppo/agent/finetune/train_ppo_exact_diffusion_agent.py

"""
Use diffusion exact likelihood for policy gradient.
Do not support pixel input yet.
"""
import os
import pickle
import einops
import numpy as np
import torch
import logging
import wandb
import math
log = logging.getLogger(__name__)
from util.timer import Timer
from agent.finetune.train_ppo_diffusion_agent import TrainPPODiffusionAgent
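
# Minimal usage sketch (hypothetical; in the repo this agent is normally
# instantiated from a Hydra config by the training launcher):
#
#   agent = TrainPPOExactDiffusionAgent(cfg)
#   agent.run()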


class TrainPPOExactDiffusionAgent(TrainPPODiffusionAgent):

    def __init__(self, cfg):
        super().__init__(cfg)

    def run(self):
        """
        For exact likelihood, we do not need to save the chains.
        """
        # Start training loop
        timer = Timer()
        run_results = []
        cnt_train_step = 0
        last_itr_eval = False
        done_venv = np.zeros((1, self.n_envs))
        while self.itr < self.n_train_itr:
            # Prepare video paths for each env --- only applies for the first set
            # of episodes if allowing reset within iteration and each iteration
            # has multiple episodes from one env
            options_venv = [{} for _ in range(self.n_envs)]
            if self.itr % self.render_freq == 0 and self.render_video:
                for env_ind in range(self.n_render):
                    options_venv[env_ind]["video_path"] = os.path.join(
                        self.render_dir, f"itr-{self.itr}_trial-{env_ind}.mp4"
                    )

            # Define train or eval - all envs restart
            eval_mode = self.itr % self.val_freq == 0 and not self.force_train
            self.model.eval() if eval_mode else self.model.train()

            # Reset env before iteration starts (1) if specified, (2) at eval
            # mode, or (3) right after eval mode
            firsts_trajs = np.zeros((self.n_steps + 1, self.n_envs))
            if self.reset_at_iteration or eval_mode or last_itr_eval:
                prev_obs_venv = self.reset_env_all(options_venv=options_venv)
                firsts_trajs[0] = 1
            else:
                # if done at the end of last iteration, the envs are just reset
                firsts_trajs[0] = done_venv
            last_itr_eval = eval_mode
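
            # firsts_trajs[t] == 1 marks envs that begin a new episode at step t;
            # used after collection to split reward_trajs into per-episode chunks.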

            # Trajectory holders
            obs_trajs = {
                "state": np.zeros(
                    (self.n_steps, self.n_envs, self.n_cond_step, self.obs_dim)
                )
            }
            samples_trajs = np.zeros(
                (
                    self.n_steps,
                    self.n_envs,
                    self.horizon_steps,
                    self.action_dim,
                )
            )
            chains_trajs = np.zeros(
                (
                    self.n_steps,
                    self.n_envs,
                    self.model.ft_denoising_steps + 1,
                    self.horizon_steps,
                    self.action_dim,
                )
            )
            terminated_trajs = np.zeros((self.n_steps, self.n_envs))
            reward_trajs = np.zeros((self.n_steps, self.n_envs))
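
            # The recorded chains (initial noise plus the output of each of the
            # ft_denoising_steps denoising steps) are kept only for optional
            # trajectory dumps; the exact-likelihood update below does not use them.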

            # Collect a set of trajectories from env
            for step in range(self.n_steps):
                if step % 10 == 0:
                    print(f"Processed step {step} of {self.n_steps}")

                # Select action
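                # (the policy output is the fully denoised action sequence;
                # return_chain=True additionally returns the intermediate
                # denoised actions from every denoising step)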
                with torch.no_grad():
                    cond = {
                        "state": torch.from_numpy(prev_obs_venv["state"])
                        .float()
                        .to(self.device)
                    }
                    samples = self.model(
                        cond=cond,
                        deterministic=eval_mode,
                        return_chain=True,
                    )
                    output_venv = (
                        samples.trajectories.cpu().numpy()
                    )  # n_env x horizon x act
                    chains_venv = (
                        samples.chains.cpu().numpy()
                    )  # n_env x denoising x horizon x act
                action_venv = output_venv[:, : self.act_steps]
                samples_trajs[step] = output_venv

                # Apply multi-step action
                obs_venv, reward_venv, terminated_venv, truncated_venv, info_venv = (
                    self.venv.step(action_venv)
                )
                done_venv = terminated_venv | truncated_venv
                obs_trajs["state"][step] = prev_obs_venv["state"]
                chains_trajs[step] = chains_venv
                reward_trajs[step] = reward_venv
                terminated_trajs[step] = terminated_venv
                firsts_trajs[step + 1] = done_venv

                # update for next step
                prev_obs_venv = obs_venv

                # count steps --- not accounting for done within action chunk
                cnt_train_step += self.n_envs * self.act_steps if not eval_mode else 0

            # Summarize episode reward --- this needs to be handled differently
            # depending on whether the environment is reset after each iteration.
            # Only count episodes that finish within the iteration.
            episodes_start_end = []
            for env_ind in range(self.n_envs):
                env_steps = np.where(firsts_trajs[:, env_ind] == 1)[0]
                for i in range(len(env_steps) - 1):
                    start = env_steps[i]
                    end = env_steps[i + 1]
                    if end - start > 1:
                        episodes_start_end.append((env_ind, start, end - 1))
            if len(episodes_start_end) > 0:
                reward_trajs_split = [
                    reward_trajs[start : end + 1, env_ind]
                    for env_ind, start, end in episodes_start_end
                ]
                num_episode_finished = len(reward_trajs_split)
                episode_reward = np.array(
                    [np.sum(reward_traj) for reward_traj in reward_trajs_split]
                )
                episode_best_reward = np.array(
                    [
                        np.max(reward_traj) / self.act_steps
                        for reward_traj in reward_trajs_split
                    ]
                )
                avg_episode_reward = np.mean(episode_reward)
                avg_best_reward = np.mean(episode_best_reward)
                success_rate = np.mean(
                    episode_best_reward >= self.best_reward_threshold_for_success
                )
            else:
                episode_reward = np.array([])
                num_episode_finished = 0
                avg_episode_reward = 0
                avg_best_reward = 0
                success_rate = 0
                log.info("[WARNING] No episode completed within the iteration!")

            # Update models
            if not eval_mode:
                with torch.no_grad():
                    obs_trajs["state"] = (
                        torch.from_numpy(obs_trajs["state"]).float().to(self.device)
                    )

                    # Calculate value and logprobs - split into batches to
                    # prevent out of memory
                    num_split = math.ceil(
                        self.n_envs * self.n_steps / self.logprob_batch_size
                    )
                    obs_ts = [{} for _ in range(num_split)]
                    obs_k = einops.rearrange(
                        obs_trajs["state"],
                        "s e ... -> (s e) ...",
                    )
                    obs_ts_k = torch.split(obs_k, self.logprob_batch_size, dim=0)
                    for i, obs_t in enumerate(obs_ts_k):
                        obs_ts[i]["state"] = obs_t
                    values_trajs = np.empty((0, self.n_envs))
                    for obs in obs_ts:
                        values = self.model.critic(obs).cpu().numpy().flatten()
                        values_trajs = np.vstack(
                            (values_trajs, values.reshape(-1, self.n_envs))
                        )
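
                    # The flattened samples are ordered step-major, so stacking
                    # values as (-1, n_envs) assumes logprob_batch_size is a
                    # multiple of n_envs.

                    # Exact log-probabilities of the sampled action trajectories
                    # under the current policy, one scalar per flattened
                    # (step, env) sample -- presumably evaluated via the
                    # probability-flow ODE inside get_exact_logprobs. Unlike the
                    # chain-based parent agent, no per-denoising-step log-probs
                    # are needed.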
                    samples_t = einops.rearrange(
                        torch.from_numpy(samples_trajs).float().to(self.device),
                        "s e h d -> (s e) h d",
                    )
                    samples_ts = torch.split(
                        samples_t, self.logprob_batch_size, dim=0
                    )
                    logprobs_trajs = np.empty((0))
                    for obs, samples in zip(obs_ts, samples_ts):
                        logprobs = (
                            self.model.get_exact_logprobs(obs, samples).cpu().numpy()
                        )
                        logprobs_trajs = np.concatenate((logprobs_trajs, logprobs))

                    # normalize reward with running variance if specified
                    if self.reward_scale_running:
                        reward_trajs_transpose = self.running_reward_scaler(
                            reward=reward_trajs.T, first=firsts_trajs[:-1].T
                        )
                        reward_trajs = reward_trajs_transpose.T
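                        # (the scaler operates on per-env reward sequences; the
                        # `first` flags mark episode starts, presumably so its
                        # running statistics handle episode boundaries)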

                    # bootstrap value with GAE if not terminal - apply reward
                    # scaling with constant if specified
                    obs_venv_ts = {
                        "state": torch.from_numpy(obs_venv["state"])
                        .float()
                        .to(self.device)
                    }
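                    # Generalized Advantage Estimation, computed backwards in time:
                    #   delta_t = r_t + gamma * V(s_{t+1}) * (1 - term_t) - V(s_t)
                    #   A_t = delta_t + gamma * lambda * (1 - term_t) * A_{t+1}
                    # with the final value bootstrapped from the critic at the
                    # last observation, and returns R_t = A_t + V(s_t).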
                    advantages_trajs = np.zeros_like(reward_trajs)
                    lastgaelam = 0
                    for t in reversed(range(self.n_steps)):
                        if t == self.n_steps - 1:
                            nextvalues = (
                                self.model.critic(obs_venv_ts)
                                .reshape(1, -1)
                                .cpu()
                                .numpy()
                            )
                        else:
                            nextvalues = values_trajs[t + 1]
                        nonterminal = 1.0 - terminated_trajs[t]
                        # delta = r + gamma*V(st+1) - V(st)
                        delta = (
                            reward_trajs[t] * self.reward_scale_const
                            + self.gamma * nextvalues * nonterminal
                            - values_trajs[t]
                        )
                        # A = delta_t + gamma*lambda*delta_{t+1} + ...
                        advantages_trajs[t] = lastgaelam = (
                            delta
                            + self.gamma * self.gae_lambda * nonterminal * lastgaelam
                        )
                    returns_trajs = advantages_trajs + values_trajs

                # k for environment step
                obs_k = {
                    "state": einops.rearrange(
                        obs_trajs["state"],
                        "s e ... -> (s e) ...",
                    )
                }
                samples_k = einops.rearrange(
                    torch.tensor(samples_trajs, device=self.device).float(),
                    "s e h d -> (s e) h d",
                )
                returns_k = (
                    torch.tensor(returns_trajs, device=self.device).float().reshape(-1)
                )
                values_k = (
                    torch.tensor(values_trajs, device=self.device).float().reshape(-1)
                )
                advantages_k = (
                    torch.tensor(advantages_trajs, device=self.device)
                    .float()
                    .reshape(-1)
                )
                logprobs_k = torch.tensor(logprobs_trajs, device=self.device).float()
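
                # Minibatches below are drawn by shuffling the flattened
                # (step, env) samples. Unlike the chain-based DPPO parent agent,
                # there is no denoising-step dimension to sample over: each
                # action sample carries a single exact log-prob for its chain.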

                # Update policy and critic
                total_steps = self.n_steps * self.n_envs
                clipfracs = []
                for update_epoch in range(self.update_epochs):
                    # for each epoch, go through all data in batches
                    flag_break = False
                    inds_k = torch.randperm(total_steps, device=self.device)
                    num_batch = max(1, total_steps // self.batch_size)  # skip last ones
                    for batch in range(num_batch):
                        start = batch * self.batch_size
                        end = start + self.batch_size
                        inds_b = inds_k[start:end]  # b for batch
                        obs_b = {"state": obs_k["state"][inds_b]}
                        samples_b = samples_k[inds_b]
                        returns_b = returns_k[inds_b]
                        values_b = values_k[inds_b]
                        advantages_b = advantages_k[inds_b]
                        logprobs_b = logprobs_k[inds_b]

                        # get loss
                        (
                            pg_loss,
                            v_loss,
                            clipfrac,
                            approx_kl,
                            ratio,
                            bc_loss,
                        ) = self.model.loss(
                            obs_b,
                            samples_b,
                            returns_b,
                            values_b,
                            advantages_b,
                            logprobs_b,
                            use_bc_loss=self.use_bc_loss,
                            reward_horizon=self.reward_horizon,
                        )
                        loss = (
                            pg_loss
                            + v_loss * self.vf_coef
                            + bc_loss * self.bc_loss_coeff
                        )
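                        # (pg_loss is the clipped PPO surrogate on the ratio of
                        # new to old log-probs; v_loss the critic regression
                        # loss, weighted by vf_coef; bc_loss an optional
                        # behavior-cloning regularizer, weighted by bc_loss_coeff)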
                        clipfracs += [clipfrac]

                        # update policy and critic
                        self.actor_optimizer.zero_grad()
                        self.critic_optimizer.zero_grad()
                        loss.backward()
                        if self.itr >= self.n_critic_warmup_itr:
                            if self.max_grad_norm is not None:
                                torch.nn.utils.clip_grad_norm_(
                                    self.model.actor_ft.parameters(),
                                    self.max_grad_norm,
                                )
                            self.actor_optimizer.step()
                        self.critic_optimizer.step()
                        log.info(
                            f"approx_kl: {approx_kl}, update_epoch: {update_epoch}, num_batch: {num_batch}"
                        )

                        # Stop gradient update if KL difference reaches target
                        if self.target_kl is not None and approx_kl > self.target_kl:
                            flag_break = True
                            break
                    if flag_break:
                        break

                # Explained variance of future rewards using the value function
                y_pred, y_true = values_k.cpu().numpy(), returns_k.cpu().numpy()
                var_y = np.var(y_true)
                explained_var = (
                    np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y
                )

                # Update lr
                if self.itr >= self.n_critic_warmup_itr:
                    self.actor_lr_scheduler.step()
                    self.critic_lr_scheduler.step()

            # Save model
            if self.itr % self.save_model_freq == 0 or self.itr == self.n_train_itr - 1:
                self.save_model()

            # Log loss and save metrics
            run_results.append(
                {
                    "itr": self.itr,
                    "step": cnt_train_step,
                }
            )
            if self.save_trajs:
                run_results[-1]["obs_trajs"] = obs_trajs
                run_results[-1]["action_trajs"] = samples_trajs
                run_results[-1]["chains_trajs"] = chains_trajs
                run_results[-1]["reward_trajs"] = reward_trajs
            if self.itr % self.log_freq == 0:
                time = timer()
                run_results[-1]["time"] = time
                if eval_mode:
                    log.info(
                        f"eval: success rate {success_rate:8.4f} | avg episode reward {avg_episode_reward:8.4f} | avg best reward {avg_best_reward:8.4f}"
                    )
                    if self.use_wandb:
                        wandb.log(
                            {
                                "success rate - eval": success_rate,
                                "avg episode reward - eval": avg_episode_reward,
                                "avg best reward - eval": avg_best_reward,
                                "num episode - eval": num_episode_finished,
                            },
                            step=self.itr,
                            commit=False,
                        )
                    run_results[-1]["eval_success_rate"] = success_rate
                    run_results[-1]["eval_episode_reward"] = avg_episode_reward
                    run_results[-1]["eval_best_reward"] = avg_best_reward
                else:
                    log.info(
                        f"{self.itr}: step {cnt_train_step:8d} | loss {loss:8.4f} | pg loss {pg_loss:8.4f} | value loss {v_loss:8.4f} | reward {avg_episode_reward:8.4f} | t:{time:8.4f}"
                    )
                    if self.use_wandb:
                        wandb.log(
                            {
                                "total env step": cnt_train_step,
                                "loss": loss,
                                "pg loss": pg_loss,
                                "value loss": v_loss,
                                "approx kl": approx_kl,
                                "ratio": ratio,
                                "clipfrac": np.mean(clipfracs),
                                "explained variance": explained_var,
                                "avg episode reward - train": avg_episode_reward,
                            },
                            step=self.itr,
                            commit=True,
                        )
                    run_results[-1]["train_episode_reward"] = avg_episode_reward
                with open(self.result_path, "wb") as f:
                    pickle.dump(run_results, f)
            self.itr += 1