113 lines
3.0 KiB
Python
113 lines
3.0 KiB
Python
"""
PPO for GMM policy.

To: observation sequence length
Ta: action chunk size
Do: observation dimension
Da: action dimension

C: image channels
H, W: image height and width
"""
|
|
|
|
from typing import Optional
|
|
import torch
|
|
from model.rl.gmm_vpg import VPG_GMM
|
|
|
|
|
|
class PPO_GMM(VPG_GMM):
    """PPO (clipped surrogate objective) for the GMM policy provided by ``VPG_GMM``.

    Adds policy-ratio clipping, optional value-loss clipping and optional
    per-batch advantage normalization on top of the parent class, which
    supplies ``get_logprobs`` and ``critic``.
    """

    def __init__(
        self,
        clip_ploss_coef: float,
        clip_vloss_coef: Optional[float] = None,
        norm_adv: Optional[bool] = True,
        **kwargs,
    ):
        """
        Args:
            clip_ploss_coef: epsilon of the clipped surrogate; the probability
                ratio is clamped to [1 - eps, 1 + eps].
            clip_vloss_coef: if not None, the new value prediction is kept
                within +/- this amount of ``oldvalues`` in the clipped value
                loss; if None, a plain MSE value loss is used.
            norm_adv: whether to standardize advantages within each batch
                (a falsy value disables it).
            **kwargs: forwarded unchanged to ``VPG_GMM.__init__``.
        """
        super().__init__(**kwargs)

        # Whether to normalize advantages within batch
        self.norm_adv = norm_adv

        # Clipping value for policy loss
        self.clip_ploss_coef = clip_ploss_coef

        # Clipping value for value loss
        self.clip_vloss_coef = clip_vloss_coef

    def loss(
        self,
        obs,
        actions,
        returns,
        oldvalues,
        advantages,
        oldlogprobs,
        **kwargs,
    ):
        """
        PPO loss

        obs: dict with key state/rgb; more recent obs at the end
            state: (B, To, Do)
            rgb: (B, To, C, H, W)
        actions: (B, Ta, Da)
        returns: (B, )
        oldvalues: (B, )  value predictions recorded at rollout time
        advantages: (B,)
        oldlogprobs: (B, )
        **kwargs: accepted and ignored.

        Returns:
            (pg_loss, entropy_loss, v_loss, clipfrac, approx_kl, ratio_mean,
            bc_loss, std)
            - pg_loss, entropy_loss, v_loss: scalar tensors (differentiable).
            - clipfrac, approx_kl, ratio_mean, std: Python numbers obtained via
              ``.item()``; ``std`` therefore must be a single-element tensor
              as returned by ``get_logprobs``.
            - bc_loss: always 0 here.
        """
        newlogprobs, entropy, std = self.get_logprobs(obs, actions)
        # Bound both logprobs to [-5, 2] so logratio lies in [-7, 7] and the
        # ratio cannot explode. Note: clamp gives zero gradient to newlogprobs
        # that fall outside this range.
        newlogprobs = newlogprobs.clamp(min=-5, max=2)
        oldlogprobs = oldlogprobs.clamp(min=-5, max=2)
        entropy_loss = -entropy.mean()

        # get ratio
        logratio = newlogprobs - oldlogprobs
        ratio = logratio.exp()

        # Diagnostics only: KL estimate (ratio - 1) - log(ratio), and the
        # fraction of samples whose ratio falls outside the clip range.
        with torch.no_grad():
            approx_kl = ((ratio - 1) - logratio).nanmean()
            clipfrac = (
                ((ratio - 1.0).abs() > self.clip_ploss_coef).float().mean().item()
            )

        # Normalize advantages. torch's default std uses Bessel's correction,
        # which is NaN for a single element and would poison the whole loss;
        # normalization is ill-defined there, so it is skipped in that case.
        if self.norm_adv and advantages.numel() > 1:
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        # Policy loss with clipping (pessimistic max of the two surrogates)
        pg_loss1 = -advantages * ratio
        pg_loss2 = -advantages * torch.clamp(
            ratio, 1 - self.clip_ploss_coef, 1 + self.clip_ploss_coef
        )
        pg_loss = torch.max(pg_loss1, pg_loss2).mean()

        # Value loss optionally with clipping
        newvalues = self.critic(obs).view(-1)
        if self.clip_vloss_coef is not None:
            v_loss_unclipped = (newvalues - returns) ** 2
            # Limit how far the new value may move from the rollout-time value.
            v_clipped = oldvalues + torch.clamp(
                newvalues - oldvalues,
                -self.clip_vloss_coef,
                self.clip_vloss_coef,
            )
            v_loss_clipped = (v_clipped - returns) ** 2
            v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
            v_loss = 0.5 * v_loss_max.mean()
        else:
            v_loss = 0.5 * ((newvalues - returns) ** 2).mean()

        bc_loss = 0
        return (
            pg_loss,
            entropy_loss,
            v_loss,
            clipfrac,
            approx_kl.item(),
            ratio.mean().item(),
            bc_loss,
            std.item(),
        )
|