dppo/model/diffusion/diffusion_ppo.py
Allen Z. Ren dc8e0c9edc
v0.6 (#18)
* Sampling over both env and denoising steps in DPPO updates (#13)

* sample one from each chain

* full random sampling

* Add Proficient Human (PH) Configs and Pipeline (#16)

* fix missing cfg

* add ph config

* fix how terminated flags are added to buffer in ibrl

* add ph config

* offline calql for 1M gradient updates

* bug fix: number of calql online gradient steps is the number of new transitions collected

* add sample config for DPPO with ta=1

* Sampling over both env and denoising steps in DPPO updates (#13)

* sample one from each chain

* full random sampling

* fix diffusion loss when predicting initial noise

* fix dppo inds

* fix typo

* remove print statement

---------

Co-authored-by: Justin M. Lidard <jlidard@neuronic.cs.princeton.edu>
Co-authored-by: allenzren <allen.ren@princeton.edu>

* update robomimic configs

* better calql formulation

* optimize calql and ibrl training

* optimize data transfer in ppo agents

* add kitchen configs

* re-organize config folders, rerun calql and rlpd

* add scratch gym locomotion configs

* add kitchen installation dependencies

* use truncated for termination in furniture env

* update furniture and gym configs

* update README and dependencies with kitchen

* add url for new data and checkpoints

* update demo RL configs

* update batch sizes for furniture unet configs

* raise error about dropout in residual mlp

* fix observation bug in bc loss

---------

Co-authored-by: Justin Lidard <60638575+jlidard@users.noreply.github.com>
Co-authored-by: Justin M. Lidard <jlidard@neuronic.cs.princeton.edu>
2024-10-30 19:58:06 -04:00

200 lines
7.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
DPPO: Diffusion Policy Policy Optimization.
K: number of denoising steps
To: observation sequence length
Ta: action chunk size
Do: observation dimension
Da: action dimension
C: image channels
H, W: image height and width
"""
from typing import Optional
import torch
import logging
import math
log = logging.getLogger(__name__)
from model.diffusion.diffusion_vpg import VPGDiffusion
class PPODiffusion(VPGDiffusion):
    """PPO fine-tuning objective for a diffusion policy (DPPO).

    Extends VPGDiffusion with:
      - a clipped-ratio PPO policy loss over sampled (env step, denoising step)
        pairs, whose clipping range is interpolated across denoising steps;
      - a per-denoising-step discount on advantages;
      - optional advantage normalization and quantile clipping;
      - an optionally clipped value loss and an optional BC regularizer.
    """

    def __init__(
        self,
        gamma_denoising: float,
        clip_ploss_coef: float,
        clip_ploss_coef_base: float = 1e-3,
        clip_ploss_coef_rate: float = 3,
        clip_vloss_coef: Optional[float] = None,
        clip_advantage_lower_quantile: float = 0,
        clip_advantage_upper_quantile: float = 1,
        norm_adv: bool = True,
        **kwargs,
    ):
        """
        Args:
            gamma_denoising: discount of the denoising MDP. A sample at denoising
                index i has its advantage scaled by
                gamma_denoising ** (ft_denoising_steps - i - 1).
            clip_ploss_coef: PPO ratio clipping range at the last fine-tuned
                denoising index (ft_denoising_steps - 1). It is also used as-is
                when only one denoising step is fine-tuned.
            clip_ploss_coef_base: clipping range at denoising index 0.
            clip_ploss_coef_rate: exponential rate of the interpolation from
                clip_ploss_coef_base to clip_ploss_coef (0 -> linear).
            clip_vloss_coef: if not None, limit how far new value predictions
                may move from the old ones inside the clipped value loss.
            clip_advantage_lower_quantile: lower quantile for clamping
                advantages. The default 0 clamps to the batch minimum (no-op).
            clip_advantage_upper_quantile: upper quantile for clamping
                advantages. The default 1 clamps to the batch maximum (no-op).
            norm_adv: whether to normalize advantages within the batch.
            **kwargs: forwarded to VPGDiffusion.
        """
        super().__init__(**kwargs)
        # Whether to normalize advantages within batch
        self.norm_adv = norm_adv

        # Clipping values for policy loss (schedule over denoising steps)
        self.clip_ploss_coef = clip_ploss_coef
        self.clip_ploss_coef_base = clip_ploss_coef_base
        self.clip_ploss_coef_rate = clip_ploss_coef_rate

        # Clipping value for value loss
        self.clip_vloss_coef = clip_vloss_coef

        # Discount factor for diffusion MDP
        self.gamma_denoising = gamma_denoising

        # Quantiles for clipping advantages
        self.clip_advantage_lower_quantile = clip_advantage_lower_quantile
        self.clip_advantage_upper_quantile = clip_advantage_upper_quantile

    def loss(
        self,
        obs,
        chains_prev,
        chains_next,
        denoising_inds,
        returns,
        oldvalues,
        advantages,
        oldlogprobs,
        use_bc_loss=False,
        reward_horizon=4,
    ):
        """
        PPO loss over a minibatch of sampled (env step, denoising step) pairs.

        obs: dict with key state/rgb; more recent obs at the end
            state: (B, To, Do)
            rgb: (B, To, C, H, W)
        chains_prev, chains_next: chain states before/after the sampled
            denoising step, passed straight to get_logprobs_subsample
            (presumably (B, Ta, Da) each -- TODO confirm).
        denoising_inds: integer tensor holding the sampled denoising index of
            each sample; must broadcast against advantages (presumably (B,)).
        returns: (B, )
        oldvalues: (B, )
        advantages: (B, )
        oldlogprobs: 3-D tensor laid out like the new logprobs; its second
            dimension is truncated to reward_horizon (presumably (B, Ta, Da)).
        use_bc_loss: whether to add BC regularization loss
        reward_horizon: action horizon that backpropagates gradient

        Returns:
            (pg_loss, entropy_loss, v_loss, clipfrac, approx_kl, ratio_mean,
             bc_loss, eta_mean). The last five non-loss entries are Python
            floats, except bc_loss, which is the int 0 when use_bc_loss is False.
        """
        # Get new logprobs for the sampled denoising steps; eta is returned
        # alongside them when get_ent=True.
        newlogprobs, eta = self.get_logprobs_subsample(
            obs,
            chains_prev,
            chains_next,
            denoising_inds,
            get_ent=True,
        )
        entropy_loss = -eta.mean()
        newlogprobs = newlogprobs.clamp(min=-5, max=2)
        oldlogprobs = oldlogprobs.clamp(min=-5, max=2)

        # only backpropagate through the earlier steps (e.g., ones actually executed in the environment)
        newlogprobs = newlogprobs[:, :reward_horizon, :]
        oldlogprobs = oldlogprobs[:, :reward_horizon, :]

        # Average over the last two dims -> one logprob per sample
        newlogprobs = newlogprobs.mean(dim=(-1, -2)).view(-1)
        oldlogprobs = oldlogprobs.mean(dim=(-1, -2)).view(-1)

        bc_loss = 0
        if use_bc_loss:
            # See Eqn. 2 of https://arxiv.org/pdf/2403.03949.pdf
            # Give a reward for maximizing probability of teacher policy's action with current policy.
            # Actions are chosen along trajectory induced by current policy.

            # Get counterfactual teacher actions
            samples = self.forward(
                cond=obs,
                deterministic=False,
                return_chain=True,
                use_base_policy=True,
            )
            # Get logprobs of teacher actions under this policy
            bc_logprobs = self.get_logprobs(
                obs,
                samples.chains,
                get_ent=False,
                use_base_policy=False,
            )
            bc_logprobs = bc_logprobs.clamp(min=-5, max=2)
            bc_logprobs = bc_logprobs.mean(dim=(-1, -2)).view(-1)
            bc_loss = -bc_logprobs.mean()

        # normalize advantages
        if self.norm_adv:
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        # Clamp advantages to the configured quantiles of the batch
        # (defaults 0 and 1 clamp to min/max, i.e. no clipping)
        advantage_min = torch.quantile(advantages, self.clip_advantage_lower_quantile)
        advantage_max = torch.quantile(advantages, self.clip_advantage_upper_quantile)
        advantages = advantages.clamp(min=advantage_min, max=advantage_max)

        # denoising discount: gamma ** (K - i - 1), vectorized in float32
        # instead of a per-element Python loop over the index tensor
        denoising_inds = denoising_inds.to(self.device)
        discount = self.gamma_denoising ** (
            self.ft_denoising_steps - denoising_inds - 1
        ).float()
        advantages = advantages * discount

        # get ratio
        logratio = newlogprobs - oldlogprobs
        ratio = logratio.exp()

        # exponentially interpolate between the base and the current clipping
        # value over denoising steps: base at index 0, clip_ploss_coef at K-1
        if self.ft_denoising_steps > 1:
            t = denoising_inds.float() / (self.ft_denoising_steps - 1)
            if self.clip_ploss_coef_rate == 0:
                # limit of (exp(r*t) - 1) / (exp(r) - 1) as r -> 0; avoids 0/0
                schedule = t
            else:
                schedule = (torch.exp(self.clip_ploss_coef_rate * t) - 1) / (
                    math.exp(self.clip_ploss_coef_rate) - 1
                )
            clip_ploss_coef = (
                self.clip_ploss_coef_base
                + (self.clip_ploss_coef - self.clip_ploss_coef_base) * schedule
            )
        else:
            # With a single fine-tuned step, index 0 is also the last step, so
            # use the configured clip range (the old code divided 0/0 -> NaN).
            clip_ploss_coef = torch.full_like(
                denoising_inds, self.clip_ploss_coef, dtype=torch.float32
            )

        # get kl difference and whether value clipped
        with torch.no_grad():
            # approx_kl uses the k3 estimator (ratio - 1) - logratio, a
            # lower-variance alternative to the k1 estimator (-logratio).mean();
            # see John Schulman's note http://joschu.net/blog/kl-approx.html
            approx_kl = ((ratio - 1) - logratio).mean()
            clipfrac = ((ratio - 1.0).abs() > clip_ploss_coef).float().mean().item()

        # Policy loss with clipping
        pg_loss1 = -advantages * ratio
        pg_loss2 = -advantages * torch.clamp(
            ratio, 1 - clip_ploss_coef, 1 + clip_ploss_coef
        )
        pg_loss = torch.max(pg_loss1, pg_loss2).mean()

        # Value loss optionally with clipping
        newvalues = self.critic(obs).view(-1)
        if self.clip_vloss_coef is not None:
            v_loss_unclipped = (newvalues - returns) ** 2
            v_clipped = oldvalues + torch.clamp(
                newvalues - oldvalues,
                -self.clip_vloss_coef,
                self.clip_vloss_coef,
            )
            v_loss_clipped = (v_clipped - returns) ** 2
            v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
            v_loss = 0.5 * v_loss_max.mean()
        else:
            v_loss = 0.5 * ((newvalues - returns) ** 2).mean()

        return (
            pg_loss,
            entropy_loss,
            v_loss,
            clipfrac,
            approx_kl.item(),
            ratio.mean().item(),
            bc_loss,
            eta.mean().item(),
        )