"""
|
|
Use diffusion exact likelihood for policy gradient.
|
|
|
|
Do not support pixel input yet.
|
|
|
|
"""
|
|
|
|
import os
|
|
import pickle
|
|
import einops
|
|
import numpy as np
|
|
import torch
|
|
import logging
|
|
import wandb
|
|
import math
|
|
|
|
log = logging.getLogger(__name__)
|
|
from util.timer import Timer
|
|
from agent.finetune.train_ppo_diffusion_agent import TrainPPODiffusionAgent
|
|
|
|
|
|


class TrainPPOExactDiffusionAgent(TrainPPODiffusionAgent):

    def __init__(self, cfg):
        super().__init__(cfg)

    def run(self):
        """
        For exact likelihood, we do not need to save the chains.
        """
        # Start training loop
        timer = Timer()
        run_results = []
        cnt_train_step = 0
        last_itr_eval = False
        done_venv = np.zeros((1, self.n_envs))
        while self.itr < self.n_train_itr:

            # Prepare video paths for each env --- only applies for the first set of
            # episodes if allowing reset within iteration and each iteration has
            # multiple episodes from one env
            options_venv = [{} for _ in range(self.n_envs)]
            if self.itr % self.render_freq == 0 and self.render_video:
                for env_ind in range(self.n_render):
                    options_venv[env_ind]["video_path"] = os.path.join(
                        self.render_dir, f"itr-{self.itr}_trial-{env_ind}.mp4"
                    )

            # Define train or eval - all envs restart
            eval_mode = self.itr % self.val_freq == 0 and not self.force_train
            self.model.eval() if eval_mode else self.model.train()
            last_itr_eval = eval_mode

            # Reset env before iteration starts (1) if specified, (2) at eval mode,
            # or (3) right after eval mode
            firsts_trajs = np.zeros((self.n_steps + 1, self.n_envs))
            if self.reset_at_iteration or eval_mode or last_itr_eval:
                prev_obs_venv = self.reset_env_all(options_venv=options_venv)
                firsts_trajs[0] = 1
            else:
                # envs that finished at the end of the last iteration were just reset
                firsts_trajs[0] = done_venv

            # Holders
            obs_trajs = {
                "state": np.zeros(
                    (self.n_steps, self.n_envs, self.n_cond_step, self.obs_dim)
                )
            }
            samples_trajs = np.zeros(
                (
                    self.n_steps,
                    self.n_envs,
                    self.horizon_steps,
                    self.action_dim,
                )
            )
            chains_trajs = np.zeros(
                (
                    self.n_steps,
                    self.n_envs,
                    self.model.ft_denoising_steps + 1,
                    self.horizon_steps,
                    self.action_dim,
                )
            )
            terminated_trajs = np.zeros((self.n_steps, self.n_envs))
            reward_trajs = np.zeros((self.n_steps, self.n_envs))
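            # Note: each chain has ft_denoising_steps + 1 entries, presumably the
            # starting sample of the fine-tuned denoising segment plus one entry per
            # denoising step; with exact likelihoods the chains are not needed for
            # the update itself.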
            # Collect a set of trajectories from env
            for step in range(self.n_steps):
                if step % 10 == 0:
                    print(f"Processed step {step} of {self.n_steps}")

                # Select action
                with torch.no_grad():
                    cond = {
                        "state": torch.from_numpy(prev_obs_venv["state"])
                        .float()
                        .to(self.device)
                    }
                    samples = self.model(
                        cond=cond,
                        deterministic=eval_mode,
                        return_chain=True,
                    )
                    output_venv = (
                        samples.trajectories.cpu().numpy()
                    )  # n_env x horizon x act
                    chains_venv = (
                        samples.chains.cpu().numpy()
                    )  # n_env x denoising x horizon x act
                action_venv = output_venv[:, : self.act_steps]
                samples_trajs[step] = output_venv

                # Apply multi-step action
                obs_venv, reward_venv, terminated_venv, truncated_venv, info_venv = (
                    self.venv.step(action_venv)
                )
                done_venv = terminated_venv | truncated_venv
                obs_trajs["state"][step] = prev_obs_venv["state"]
                chains_trajs[step] = chains_venv
                reward_trajs[step] = reward_venv
                terminated_trajs[step] = terminated_venv
                firsts_trajs[step + 1] = done_venv

                # update for next step
                prev_obs_venv = obs_venv

                # count steps --- not accounting for dones within an action chunk
                cnt_train_step += self.n_envs * self.act_steps if not eval_mode else 0

            # Summarize episode rewards --- this needs to be handled differently
            # depending on whether the environment is reset after each iteration.
            # Only count episodes that finish within the iteration.
            episodes_start_end = []
            for env_ind in range(self.n_envs):
                env_steps = np.where(firsts_trajs[:, env_ind] == 1)[0]
                for i in range(len(env_steps) - 1):
                    start = env_steps[i]
                    end = env_steps[i + 1]
                    if end - start > 1:
                        episodes_start_end.append((env_ind, start, end - 1))
            if len(episodes_start_end) > 0:
                reward_trajs_split = [
                    reward_trajs[start : end + 1, env_ind]
                    for env_ind, start, end in episodes_start_end
                ]
                num_episode_finished = len(reward_trajs_split)
                episode_reward = np.array(
                    [np.sum(reward_traj) for reward_traj in reward_trajs_split]
                )
                episode_best_reward = np.array(
                    [
                        np.max(reward_traj) / self.act_steps
                        for reward_traj in reward_trajs_split
                    ]
                )
                avg_episode_reward = np.mean(episode_reward)
                avg_best_reward = np.mean(episode_best_reward)
                success_rate = np.mean(
                    episode_best_reward >= self.best_reward_threshold_for_success
                )
            else:
                episode_reward = np.array([])
                num_episode_finished = 0
                avg_episode_reward = 0
                avg_best_reward = 0
                success_rate = 0
                log.info("[WARNING] No episode completed within the iteration!")
            # Update models
            if not eval_mode:
                with torch.no_grad():
                    obs_trajs["state"] = (
                        torch.from_numpy(obs_trajs["state"]).float().to(self.device)
                    )

                    # Calculate values and logprobs - split into batches to avoid
                    # running out of memory
                    num_split = math.ceil(
                        self.n_envs * self.n_steps / self.logprob_batch_size
                    )
                    obs_ts = [{} for _ in range(num_split)]
                    obs_k = einops.rearrange(
                        obs_trajs["state"],
                        "s e ... -> (s e) ...",
                    )
                    obs_ts_k = torch.split(obs_k, self.logprob_batch_size, dim=0)
                    for i, obs_t in enumerate(obs_ts_k):
                        obs_ts[i]["state"] = obs_t
                    values_trajs = np.empty((0, self.n_envs))
                    for obs in obs_ts:
                        values = self.model.critic(obs).cpu().numpy().flatten()
                        values_trajs = np.vstack(
                            (values_trajs, values.reshape(-1, self.n_envs))
                        )
                    samples_t = einops.rearrange(
                        torch.from_numpy(samples_trajs).float().to(self.device),
                        "s e h d -> (s e) h d",
                    )
                    samples_ts = torch.split(samples_t, self.logprob_batch_size, dim=0)
                    logprobs_trajs = np.empty((0))
                    for obs, samples in zip(obs_ts, samples_ts):
                        logprobs = (
                            self.model.get_exact_logprobs(obs, samples).cpu().numpy()
                        )
                        logprobs_trajs = np.concatenate((logprobs_trajs, logprobs))
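                    # get_exact_logprobs is assumed to return the exact diffusion
                    # log-likelihood of each sampled action sequence under the
                    # current policy (one scalar per (step, env) sample); its
                    # implementation lives in the diffusion model class.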

                    # normalize reward with running variance if specified
                    if self.reward_scale_running:
                        reward_trajs_transpose = self.running_reward_scaler(
                            reward=reward_trajs.T, first=firsts_trajs[:-1].T
                        )
                        reward_trajs = reward_trajs_transpose.T

                    # bootstrap value with GAE if not terminal - apply reward scaling
                    # with constant if specified
                    obs_venv_ts = {
                        "state": torch.from_numpy(obs_venv["state"])
                        .float()
                        .to(self.device)
                    }
                    advantages_trajs = np.zeros_like(reward_trajs)
                    lastgaelam = 0
                    for t in reversed(range(self.n_steps)):
                        if t == self.n_steps - 1:
                            nextvalues = (
                                self.model.critic(obs_venv_ts)
                                .reshape(1, -1)
                                .cpu()
                                .numpy()
                            )
                        else:
                            nextvalues = values_trajs[t + 1]
                        nonterminal = 1.0 - terminated_trajs[t]
                        # delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
                        delta = (
                            reward_trajs[t] * self.reward_scale_const
                            + self.gamma * nextvalues * nonterminal
                            - values_trajs[t]
                        )
                        # A_t = delta_t + gamma * lambda * delta_{t+1} + ...
                        advantages_trajs[t] = lastgaelam = (
                            delta
                            + self.gamma * self.gae_lambda * nonterminal * lastgaelam
                        )
                    returns_trajs = advantages_trajs + values_trajs
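                    # returns = GAE advantages + value baselines, i.e. the
                    # lambda-return targets used as critic regression targets in the
                    # PPO loss below.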

                # k for environment step (flattened across n_steps and n_envs)
                obs_k = {
                    "state": einops.rearrange(
                        obs_trajs["state"],
                        "s e ... -> (s e) ...",
                    )
                }
                samples_k = einops.rearrange(
                    torch.tensor(samples_trajs, device=self.device).float(),
                    "s e h d -> (s e) h d",
                )
                returns_k = (
                    torch.tensor(returns_trajs, device=self.device).float().reshape(-1)
                )
                values_k = (
                    torch.tensor(values_trajs, device=self.device).float().reshape(-1)
                )
                advantages_k = (
                    torch.tensor(advantages_trajs, device=self.device)
                    .float()
                    .reshape(-1)
                )
                logprobs_k = torch.tensor(logprobs_trajs, device=self.device).float()
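                # Resulting shapes, with S = n_steps, E = n_envs, H = horizon_steps,
                # D = action_dim: obs_k["state"] is (S*E, n_cond_step, obs_dim),
                # samples_k is (S*E, H, D), and returns_k / values_k / advantages_k
                # are (S*E,); logprobs_k matches, assuming get_exact_logprobs returns
                # one scalar per sample (see note above).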

                # Update policy and critic
                total_steps = self.n_steps * self.n_envs
                clipfracs = []
                for update_epoch in range(self.update_epochs):

                    # for each epoch, go through all data in batches
                    flag_break = False
                    inds_k = torch.randperm(total_steps, device=self.device)
                    num_batch = max(1, total_steps // self.batch_size)  # skip last ones
                    for batch in range(num_batch):
                        start = batch * self.batch_size
                        end = start + self.batch_size
                        inds_b = inds_k[start:end]  # b for batch
                        obs_b = {"state": obs_k["state"][inds_b]}
                        samples_b = samples_k[inds_b]
                        returns_b = returns_k[inds_b]
                        values_b = values_k[inds_b]
                        advantages_b = advantages_k[inds_b]
                        logprobs_b = logprobs_k[inds_b]

                        # get loss
                        (
                            pg_loss,
                            v_loss,
                            clipfrac,
                            approx_kl,
                            ratio,
                            bc_loss,
                        ) = self.model.loss(
                            obs_b,
                            samples_b,
                            returns_b,
                            values_b,
                            advantages_b,
                            logprobs_b,
                            use_bc_loss=self.use_bc_loss,
                            reward_horizon=self.reward_horizon,
                        )
                        loss = (
                            pg_loss
                            + v_loss * self.vf_coef
                            + bc_loss * self.bc_loss_coeff
                        )
                        clipfracs += [clipfrac]

                        # update policy and critic
                        self.actor_optimizer.zero_grad()
                        self.critic_optimizer.zero_grad()
                        loss.backward()
                        if self.itr >= self.n_critic_warmup_itr:
                            if self.max_grad_norm is not None:
                                torch.nn.utils.clip_grad_norm_(
                                    self.model.actor_ft.parameters(), self.max_grad_norm
                                )
                            self.actor_optimizer.step()
                        self.critic_optimizer.step()
                        log.info(
                            f"approx_kl: {approx_kl}, update_epoch: {update_epoch}, num_batch: {num_batch}"
                        )

                        # Stop gradient updates early if the approx KL divergence reaches the target
                        if self.target_kl is not None and approx_kl > self.target_kl:
                            flag_break = True
                            break
                    if flag_break:
                        break
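                # Note: the diagnostics below use the pre-update values_k and
                # returns_k (computed before this iteration's gradient steps), so
                # explained variance reflects the critic as it was at data-collection
                # time.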
                # Explained variation of future rewards using value function
                y_pred, y_true = values_k.cpu().numpy(), returns_k.cpu().numpy()
                var_y = np.var(y_true)
                explained_var = (
                    np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y
                )

            # Update lr
            if self.itr >= self.n_critic_warmup_itr:
                self.actor_lr_scheduler.step()
            self.critic_lr_scheduler.step()

            # Save model
            if self.itr % self.save_model_freq == 0 or self.itr == self.n_train_itr - 1:
                self.save_model()

            # Log loss and save metrics
            run_results.append(
                {
                    "itr": self.itr,
                    "step": cnt_train_step,
                }
            )
            if self.save_trajs:
                run_results[-1]["obs_trajs"] = obs_trajs
                run_results[-1]["action_trajs"] = samples_trajs
                run_results[-1]["chains_trajs"] = chains_trajs
                run_results[-1]["reward_trajs"] = reward_trajs
            if self.itr % self.log_freq == 0:
                time = timer()
                run_results[-1]["time"] = time
                if eval_mode:
                    log.info(
                        f"eval: success rate {success_rate:8.4f} | avg episode reward {avg_episode_reward:8.4f} | avg best reward {avg_best_reward:8.4f}"
                    )
                    if self.use_wandb:
                        wandb.log(
                            {
                                "success rate - eval": success_rate,
                                "avg episode reward - eval": avg_episode_reward,
                                "avg best reward - eval": avg_best_reward,
                                "num episode - eval": num_episode_finished,
                            },
                            step=self.itr,
                            commit=False,
                        )
                    run_results[-1]["eval_success_rate"] = success_rate
                    run_results[-1]["eval_episode_reward"] = avg_episode_reward
                    run_results[-1]["eval_best_reward"] = avg_best_reward
                else:
                    log.info(
                        f"{self.itr}: step {cnt_train_step:8d} | loss {loss:8.4f} | pg loss {pg_loss:8.4f} | value loss {v_loss:8.4f} | reward {avg_episode_reward:8.4f} | t:{time:8.4f}"
                    )
                    if self.use_wandb:
                        wandb.log(
                            {
                                "total env step": cnt_train_step,
                                "loss": loss,
                                "pg loss": pg_loss,
                                "value loss": v_loss,
                                "approx kl": approx_kl,
                                "ratio": ratio,
                                "clipfrac": np.mean(clipfracs),
                                "explained variance": explained_var,
                                "avg episode reward - train": avg_episode_reward,
                            },
                            step=self.itr,
                            commit=True,
                        )
                    run_results[-1]["train_episode_reward"] = avg_episode_reward
                with open(self.result_path, "wb") as f:
                    pickle.dump(run_results, f)
            self.itr += 1
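

# Minimal usage sketch (an assumption about the surrounding training pipeline:
# `cfg` is the repo's Hydra-style config with fields such as n_train_itr, n_steps,
# logprob_batch_size, and the logging options already populated):
#
#   agent = TrainPPOExactDiffusionAgent(cfg)
#   agent.run()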