dppo/model/rl/gaussian_ppo.py
2024-09-03 21:03:27 -04:00

121 lines
3.6 KiB
Python

"""
PPO for Gaussian policy.
"""
from typing import Optional
import torch
from model.rl.gaussian_vpg import VPG_Gaussian
class PPO_Gaussian(VPG_Gaussian):
    """Gaussian policy trained with the clipped PPO surrogate objective."""

    def __init__(
        self,
        clip_ploss_coef: float,
        clip_vloss_coef: Optional[float] = None,
        norm_adv: Optional[bool] = True,
        **kwargs,
    ):
        """
        clip_ploss_coef: half-width of the clipping interval around 1 applied
            to the probability ratio in the policy loss
        clip_vloss_coef: maximum allowed change of the value prediction
            relative to the old value; None turns value clipping off
        norm_adv: whether advantages are normalized within each batch
        kwargs: passed through unchanged to the parent constructor
        """
        super().__init__(**kwargs)

        # Clipping value for policy loss
        self.clip_ploss_coef = clip_ploss_coef

        # Clipping value for value loss (None -> unclipped squared error)
        self.clip_vloss_coef = clip_vloss_coef

        # Whether to normalize advantages within batch
        self.norm_adv = norm_adv

    def loss(
        self,
        obs,
        actions,
        returns,
        oldvalues,
        advantages,
        oldlogprobs,
        use_bc_loss=False,
    ):
        """
        PPO loss terms and logging statistics for one batch.

        obs: (B, obs_step, obs_dim)
        actions: (B, horizon_step, action_dim)
        returns: (B, )
        oldvalues: (B, )
        advantages: (B,)
        oldlogprobs: (B, )
        use_bc_loss: if True, additionally compute a behavior-cloning term on
            actions sampled from the base policy

        Returns an 8-tuple:
            pg_loss: clipped policy-gradient loss (tensor)
            entropy_loss: negated entropy from get_logprobs (tensor)
            v_loss: value loss, optionally clipped (tensor)
            clipfrac: fraction of samples with |ratio - 1| > clip_ploss_coef (float)
            approx_kl: nan-ignoring mean of (ratio - 1) - logratio (float)
            ratio mean: mean probability ratio (float)
            bc_loss: behavior-cloning loss (tensor), or 0.0 when disabled
            std: third output of get_logprobs converted via .item() (float)
        """
        ploss_clip = self.clip_ploss_coef
        vloss_clip = self.clip_vloss_coef

        # Log-probs of the taken actions under the current policy; both new
        # and old log-probs are clamped to [-5, 2] before forming the ratio.
        newlogprobs, entropy, std = self.get_logprobs(obs, actions)
        newlogprobs = torch.clamp(newlogprobs, min=-5, max=2)
        oldlogprobs = torch.clamp(oldlogprobs, min=-5, max=2)

        # get ratio
        logratio = newlogprobs - oldlogprobs
        ratio = torch.exp(logratio)

        # Diagnostics only: approximate KL and fraction of clipped ratios
        with torch.no_grad():
            approx_kl = ((ratio - 1) - logratio).nanmean()
            outside_clip = (ratio - 1.0).abs() > ploss_clip
            clipfrac = outside_clip.float().mean().item()

        # normalize advantages
        if self.norm_adv:
            adv_mean = advantages.mean()
            adv_scale = advantages.std() + 1e-8
            advantages = (advantages - adv_mean) / adv_scale

        # Policy loss with clipping: element-wise worst case of the
        # unclipped and clipped surrogate, averaged over the batch
        clipped_ratio = torch.clamp(ratio, 1 - ploss_clip, 1 + ploss_clip)
        surrogate = -advantages * ratio
        surrogate_clipped = -advantages * clipped_ratio
        pg_loss = torch.max(surrogate, surrogate_clipped).mean()

        # Value loss optionally with clipping
        newvalues = self.critic(obs).view(-1)
        value_err = (newvalues - returns) ** 2
        if vloss_clip is not None:
            # Limit how far the new prediction may move from the old value
            value_delta = torch.clamp(newvalues - oldvalues, -vloss_clip, vloss_clip)
            v_clipped = oldvalues + value_delta
            value_err = torch.max(value_err, (v_clipped - returns) ** 2)
        v_loss = 0.5 * value_err.mean()

        bc_loss = 0.0
        if use_bc_loss:
            # See Eqn. 2 of https://arxiv.org/pdf/2403.03949.pdf
            # Give a reward for maximizing probability of teacher policy's
            # action with current policy. Actions are chosen along the
            # trajectory induced by the current policy.

            # Get counterfactual teacher actions
            # (original note on the conditioning: B x horizon=1 x obs_dim)
            teacher_cond = obs.float().unsqueeze(1).to(self.device)
            teacher_actions = self.forward(
                cond=teacher_cond,
                deterministic=False,
                use_base_policy=True,
            )

            # Get logprobs of teacher actions under this policy
            teacher_logprobs, _, _ = self.get_logprobs(
                obs, teacher_actions, use_base_policy=False
            )
            teacher_logprobs = torch.clamp(teacher_logprobs, min=-5, max=2)
            bc_loss = -teacher_logprobs.mean()

        entropy_loss = -entropy
        return (
            pg_loss,
            entropy_loss,
            v_loss,
            clipfrac,
            approx_kl.item(),
            ratio.mean().item(),
            bc_loss,
            std.item(),
        )