""" DPPO: Diffusion Policy Policy Optimization. K: number of denoising steps To: observation sequence length Ta: action chunk size Do: observation dimension Da: action dimension C: image channels H, W: image height and width """ from typing import Optional import torch import logging import math log = logging.getLogger(__name__) from model.diffusion.diffusion_vpg import VPGDiffusion class PPODiffusion(VPGDiffusion): def __init__( self, gamma_denoising: float, clip_ploss_coef: float, clip_ploss_coef_base: float = 1e-3, clip_ploss_coef_rate: float = 3, clip_vloss_coef: Optional[float] = None, clip_advantage_lower_quantile: float = 0, clip_advantage_upper_quantile: float = 1, norm_adv: bool = True, **kwargs, ): super().__init__(**kwargs) # Whether to normalize advantages within batch self.norm_adv = norm_adv # Clipping value for policy loss self.clip_ploss_coef = clip_ploss_coef self.clip_ploss_coef_base = clip_ploss_coef_base self.clip_ploss_coef_rate = clip_ploss_coef_rate # Clipping value for value loss self.clip_vloss_coef = clip_vloss_coef # Discount factor for diffusion MDP self.gamma_denoising = gamma_denoising # Quantiles for clipping advantages self.clip_advantage_lower_quantile = clip_advantage_lower_quantile self.clip_advantage_upper_quantile = clip_advantage_upper_quantile def loss( self, obs, chains_prev, chains_next, denoising_inds, returns, oldvalues, advantages, oldlogprobs, use_bc_loss=False, reward_horizon=4, ): """ PPO loss obs: dict with key state/rgb; more recent obs at the end state: (B, To, Do) rgb: (B, To, C, H, W) chains: (B, K+1, Ta, Da) returns: (B, ) values: (B, ) advantages: (B,) oldlogprobs: (B, K, Ta, Da) use_bc_loss: whether to add BC regularization loss reward_horizon: action horizon that backpropagates gradient """ # Get new logprobs for denoising steps from T-1 to 0 - entropy is fixed fod diffusion newlogprobs, eta = self.get_logprobs_subsample( obs, chains_prev, chains_next, denoising_inds, get_ent=True, ) entropy_loss = -eta.mean() newlogprobs = newlogprobs.clamp(min=-5, max=2) oldlogprobs = oldlogprobs.clamp(min=-5, max=2) # only backpropagate through the earlier steps (e.g., ones actually executed in the environment) newlogprobs = newlogprobs[:, :reward_horizon, :] oldlogprobs = oldlogprobs[:, :reward_horizon, :] # Get the logprobs - batch over B and denoising steps newlogprobs = newlogprobs.mean(dim=(-1, -2)).view(-1) oldlogprobs = oldlogprobs.mean(dim=(-1, -2)).view(-1) bc_loss = 0 if use_bc_loss: # See Eqn. 2 of https://arxiv.org/pdf/2403.03949.pdf # Give a reward for maximizing probability of teacher policy's action with current policy. # Actions are chosen along trajectory induced by current policy. 
            # Get counterfactual teacher actions
            samples = self.forward(
                cond=obs,
                deterministic=False,
                return_chain=True,
                use_base_policy=True,
            )
            # Get logprobs of teacher actions under this policy
            bc_logprobs = self.get_logprobs(
                obs,
                samples.chains,
                get_ent=False,
                use_base_policy=False,
            )
            bc_logprobs = bc_logprobs.clamp(min=-5, max=2)
            bc_logprobs = bc_logprobs.mean(dim=(-1, -2)).view(-1)
            bc_loss = -bc_logprobs.mean()

        # Normalize advantages within the batch
        if self.norm_adv:
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        # Clip advantages at the configured quantiles (e.g., the 5th and 95th
        # percentiles; the defaults 0 and 1 disable clipping)
        advantage_min = torch.quantile(advantages, self.clip_advantage_lower_quantile)
        advantage_max = torch.quantile(advantages, self.clip_advantage_upper_quantile)
        advantages = advantages.clamp(min=advantage_min, max=advantage_max)

        # Denoising discount
        discount = torch.tensor(
            [
                self.gamma_denoising ** (self.ft_denoising_steps - i - 1)
                for i in denoising_inds
            ]
        ).to(self.device)
        advantages *= discount

        # Get ratio
        logratio = newlogprobs - oldlogprobs
        ratio = logratio.exp()

        # Exponentially interpolate the clipping value between the base value and
        # the final value over the denoising steps
        if self.ft_denoising_steps > 1:
            t = (denoising_inds.float() / (self.ft_denoising_steps - 1)).to(
                self.device
            )
            clip_ploss_coef = self.clip_ploss_coef_base + (
                self.clip_ploss_coef - self.clip_ploss_coef_base
            ) * (torch.exp(self.clip_ploss_coef_rate * t) - 1) / (
                math.exp(self.clip_ploss_coef_rate) - 1
            )
        else:
            # With a single fine-tuned denoising step there is nothing to
            # interpolate, so use the final clipping value directly (this also
            # avoids a division by zero in the interpolation above)
            clip_ploss_coef = self.clip_ploss_coef

        # Get KL difference and the fraction of clipped ratios
        with torch.no_grad():
            # old_approx_kl: the approximate Kullback–Leibler divergence, measured
            # by (-logratio).mean(), which corresponds to the k1 estimator in John
            # Schulman's blog post on approximating KL:
            # http://joschu.net/blog/kl-approx.html
            # approx_kl: better alternative to old_approx_kl, measured by
            # ((logratio.exp() - 1) - logratio).mean(), which corresponds to the
            # k3 estimator in the same post
            # old_approx_kl = (-logratio).mean()
            approx_kl = ((ratio - 1) - logratio).mean()
            clipfrac = ((ratio - 1.0).abs() > clip_ploss_coef).float().mean().item()

        # Policy loss with clipping
        pg_loss1 = -advantages * ratio
        pg_loss2 = -advantages * torch.clamp(
            ratio, 1 - clip_ploss_coef, 1 + clip_ploss_coef
        )
        pg_loss = torch.max(pg_loss1, pg_loss2).mean()

        # Value loss, optionally with clipping
        newvalues = self.critic(obs).view(-1)
        if self.clip_vloss_coef is not None:
            v_loss_unclipped = (newvalues - returns) ** 2
            v_clipped = oldvalues + torch.clamp(
                newvalues - oldvalues,
                -self.clip_vloss_coef,
                self.clip_vloss_coef,
            )
            v_loss_clipped = (v_clipped - returns) ** 2
            v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
            v_loss = 0.5 * v_loss_max.mean()
        else:
            v_loss = 0.5 * ((newvalues - returns) ** 2).mean()
        return (
            pg_loss,
            entropy_loss,
            v_loss,
            clipfrac,
            approx_kl.item(),
            ratio.mean().item(),
            bc_loss,
            eta.mean().item(),
        )
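

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of this module): how the returned terms
# are typically combined into a single PPO objective. The coefficient names
# (`vf_coef`, `ent_coef`, `bc_coef`), `model`, `optimizer`, and the batch
# variables are hypothetical stand-ins for the surrounding training loop.
#
#   pg_loss, entropy_loss, v_loss, clipfrac, approx_kl, ratio, bc_loss, eta = (
#       model.loss(
#           obs, chains_prev, chains_next, denoising_inds,
#           returns, oldvalues, advantages, oldlogprobs,
#       )
#   )
#   loss = pg_loss + vf_coef * v_loss + ent_coef * entropy_loss + bc_coef * bc_loss
#   optimizer.zero_grad()
#   loss.backward()
#   optimizer.step()
# ---------------------------------------------------------------------------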