"""
|
||
DPPO: Diffusion Policy Policy Optimization.
|
||
|
||
K: number of denoising steps
|
||
To: observation sequence length
|
||
Ta: action chunk size
|
||
Do: observation dimension
|
||
Da: action dimension
|
||
|
||
C: image channels
|
||
H, W: image height and width
|
||
|
||
"""
|
||
|
||
from typing import Optional
import torch
import logging
import math

log = logging.getLogger(__name__)

from model.diffusion.diffusion_vpg import VPGDiffusion


class PPODiffusion(VPGDiffusion):
    def __init__(
        self,
        gamma_denoising: float,
        clip_ploss_coef: float,
        clip_ploss_coef_base: float = 1e-3,
        clip_ploss_coef_rate: float = 3,
        clip_vloss_coef: Optional[float] = None,
        clip_advantage_lower_quantile: float = 0,
        clip_advantage_upper_quantile: float = 1,
        norm_adv: bool = True,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Whether to normalize advantages within batch
        self.norm_adv = norm_adv

        # Clipping value for policy loss
        self.clip_ploss_coef = clip_ploss_coef
        self.clip_ploss_coef_base = clip_ploss_coef_base
        self.clip_ploss_coef_rate = clip_ploss_coef_rate

        # Clipping value for value loss
        self.clip_vloss_coef = clip_vloss_coef

        # Discount factor for diffusion MDP
        self.gamma_denoising = gamma_denoising

        # Quantiles for clipping advantages
        self.clip_advantage_lower_quantile = clip_advantage_lower_quantile
        self.clip_advantage_upper_quantile = clip_advantage_upper_quantile

    def loss(
        self,
        obs,
        chains_prev,
        chains_next,
        denoising_inds,
        returns,
        oldvalues,
        advantages,
        oldlogprobs,
        use_bc_loss=False,
        reward_horizon=4,
    ):
        """
        PPO loss

        obs: dict with key state/rgb; more recent obs at the end
            state: (B, To, Do)
            rgb: (B, To, C, H, W)
        chains_prev: (B, Ta, Da)
        chains_next: (B, Ta, Da)
        denoising_inds: (B, )
        returns: (B, )
        oldvalues: (B, )
        advantages: (B, )
        oldlogprobs: (B, Ta, Da)
        use_bc_loss: whether to add BC regularization loss
        reward_horizon: action horizon that backpropagates gradient
        """
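        # Note: following the DPPO update scheme (sampling over both env and denoising steps),
        # each batch element carries a single sampled denoising transition
        # (chains_prev -> chains_next at denoising_inds), so the log-probs below are for that
        # one step rather than the full K-step chain.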
        # Get new logprobs for denoising steps from T-1 to 0 - entropy is fixed for diffusion
        newlogprobs, eta = self.get_logprobs_subsample(
            obs,
            chains_prev,
            chains_next,
            denoising_inds,
            get_ent=True,
        )
        entropy_loss = -eta.mean()
        newlogprobs = newlogprobs.clamp(min=-5, max=2)
        oldlogprobs = oldlogprobs.clamp(min=-5, max=2)

        # Only backpropagate through the earlier action steps (i.e., the ones actually executed
        # in the environment)
        newlogprobs = newlogprobs[:, :reward_horizon, :]
        oldlogprobs = oldlogprobs[:, :reward_horizon, :]

        # Average log-probs over action chunk and action dims - resulting shape (B, )
        newlogprobs = newlogprobs.mean(dim=(-1, -2)).view(-1)
        oldlogprobs = oldlogprobs.mean(dim=(-1, -2)).view(-1)

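        # Optional BC regularization: keep a high log-likelihood under the current policy for
        # actions sampled from the frozen base (teacher) policy, i.e.
        # bc_loss = -E[log pi_theta(a_base | s)] with a_base ~ pi_base(. | s).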
        bc_loss = 0
        if use_bc_loss:
            # See Eqn. 2 of https://arxiv.org/pdf/2403.03949.pdf
            # Give a reward for maximizing probability of teacher policy's action with current policy.
            # Actions are chosen along trajectory induced by current policy.

            # Get counterfactual teacher actions
            samples = self.forward(
                cond=obs,
                deterministic=False,
                return_chain=True,
                use_base_policy=True,
            )
            # Get logprobs of teacher actions under this policy
            bc_logprobs = self.get_logprobs(
                obs,
                samples.chains,
                get_ent=False,
                use_base_policy=False,
            )
            bc_logprobs = bc_logprobs.clamp(min=-5, max=2)
            bc_logprobs = bc_logprobs.mean(dim=(-1, -2)).view(-1)
            bc_loss = -bc_logprobs.mean()

        # Normalize advantages within the batch
        if self.norm_adv:
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        # Clip advantages at the configured lower/upper quantiles
        advantage_min = torch.quantile(advantages, self.clip_advantage_lower_quantile)
        advantage_max = torch.quantile(advantages, self.clip_advantage_upper_quantile)
        advantages = advantages.clamp(min=advantage_min, max=advantage_max)

        # Denoising discount
        discount = torch.tensor(
            [
                self.gamma_denoising ** (self.ft_denoising_steps - i - 1)
                for i in denoising_inds
            ]
        ).to(self.device)
        advantages *= discount
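        # For example, with gamma_denoising = 0.9 and ft_denoising_steps = 10 (assumed values,
        # for illustration only), the last denoising step (ind 9) keeps its full advantage
        # (0.9 ** 0 = 1.0), while the earliest fine-tuned step (ind 0) is scaled by 0.9 ** 9 ~= 0.39.
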
        # Get ratio
        logratio = newlogprobs - oldlogprobs
        ratio = logratio.exp()

        # Exponentially interpolate the clipping value between the base and the configured value
        # over denoising steps
        if self.ft_denoising_steps > 1:
            t = (denoising_inds.float() / (self.ft_denoising_steps - 1)).to(self.device)
            clip_ploss_coef = self.clip_ploss_coef_base + (
                self.clip_ploss_coef - self.clip_ploss_coef_base
            ) * (torch.exp(self.clip_ploss_coef_rate * t) - 1) / (
                math.exp(self.clip_ploss_coef_rate) - 1
            )
        else:
            # Single fine-tuned denoising step: nothing to interpolate, use the configured value
            clip_ploss_coef = self.clip_ploss_coef
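        # (At denoising_inds = 0 the exponential term vanishes and the coefficient equals
        # clip_ploss_coef_base; at denoising_inds = ft_denoising_steps - 1 we have t = 1 and the
        # coefficient equals clip_ploss_coef, typically allowing larger updates at later steps.)
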
        # Diagnostics: approximate KL divergence and the fraction of samples where the ratio is clipped
        with torch.no_grad():
            # old_approx_kl: the approximate Kullback–Leibler divergence, measured by
            # (-logratio).mean(), which corresponds to the k1 estimator in John Schulman's blog
            # post on approximating KL: http://joschu.net/blog/kl-approx.html
            # approx_kl: better alternative to old_approx_kl, measured by
            # ((logratio.exp() - 1) - logratio).mean(), which corresponds to the k3 estimator
            # in the same post
            # old_approx_kl = (-logratio).mean()
            approx_kl = ((ratio - 1) - logratio).mean()
            clipfrac = ((ratio - 1.0).abs() > clip_ploss_coef).float().mean().item()

        # Policy loss with clipping
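        # Clipped surrogate objective written as a loss (signs flipped relative to the usual
        # PPO objective): pg_loss = -E[min(ratio * A, clip(ratio, 1 - eps, 1 + eps) * A)],
        # with a per-sample eps = clip_ploss_coef from the denoising-step schedule above.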
        pg_loss1 = -advantages * ratio
        pg_loss2 = -advantages * torch.clamp(
            ratio, 1 - clip_ploss_coef, 1 + clip_ploss_coef
        )
        pg_loss = torch.max(pg_loss1, pg_loss2).mean()

        # Value loss optionally with clipping
        newvalues = self.critic(obs).view(-1)
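        # As in common PPO implementations, the clipped variant limits how far the new value
        # prediction can move from the value recorded at rollout time (oldvalues); the larger
        # of the clipped and unclipped squared errors is used per sample.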
        if self.clip_vloss_coef is not None:
            v_loss_unclipped = (newvalues - returns) ** 2
            v_clipped = oldvalues + torch.clamp(
                newvalues - oldvalues,
                -self.clip_vloss_coef,
                self.clip_vloss_coef,
            )
            v_loss_clipped = (v_clipped - returns) ** 2
            v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
            v_loss = 0.5 * v_loss_max.mean()
        else:
            v_loss = 0.5 * ((newvalues - returns) ** 2).mean()

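        # Return loss terms and diagnostics: policy loss, entropy loss, value loss,
        # clip fraction, approximate KL, mean ratio, BC loss, and mean eta.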
        return (
            pg_loss,
            entropy_loss,
            v_loss,
            clipfrac,
            approx_kl.item(),
            ratio.mean().item(),
            bc_loss,
            eta.mean().item(),
        )
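

# Illustrative sanity check with assumed hyperparameter values: it only exercises the
# per-denoising-step clip schedule and the denoising discount used in PPODiffusion.loss,
# without constructing a policy, so the numbers below are examples rather than defaults.
if __name__ == "__main__":
    ft_denoising_steps = 10  # assumed number of fine-tuned denoising steps
    gamma_denoising = 0.9  # assumed denoising discount factor
    clip_coef, clip_base, clip_rate = 0.01, 1e-3, 3  # assumed clipping schedule parameters
    denoising_inds = torch.arange(ft_denoising_steps)

    # Exponential interpolation of the clip coefficient over denoising steps
    t = denoising_inds.float() / (ft_denoising_steps - 1)
    clip_schedule = clip_base + (clip_coef - clip_base) * (
        torch.exp(clip_rate * t) - 1
    ) / (math.exp(clip_rate) - 1)

    # Denoising discount applied to the advantages
    discount = torch.tensor(
        [gamma_denoising ** (ft_denoising_steps - i - 1) for i in range(ft_denoising_steps)]
    )

    print("clip schedule:", [round(c, 4) for c in clip_schedule.tolist()])
    print("denoising discount:", [round(d, 4) for d in discount.tolist()])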