"""
|
||
Diffusion policy gradient with exact likelihood estimation.
|
||
|
||
Based on score_sde_pytorch https://github.com/yang-song/score_sde_pytorch
|
||
|
||
To: observation sequence length
|
||
Ta: action chunk size
|
||
Do: observation dimension
|
||
Da: action dimension
|
||
|
||
"""
|
||
|
||
import torch
|
||
import logging
|
||
|
||
log = logging.getLogger(__name__)
|
||
from .diffusion_ppo import PPODiffusion
|
||
from .exact_likelihood import get_likelihood_fn
|
||
|
||
|
||
class PPOExactDiffusion(PPODiffusion):

    def __init__(
        self,
        sde,
        sde_hutchinson_type="Rademacher",
        sde_rtol=1e-4,
        sde_atol=1e-4,
        sde_eps=1e-4,
        sde_step_size=1e-3,
        sde_method="RK23",
        sde_continuous=False,
        sde_probability_flow=False,
        sde_num_epsilon=1,
        sde_min_beta=1e-2,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.sde = sde
        self.sde.set_betas(
            self.betas,
            sde_min_beta,
        )
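
        # Note: judging from the argument names (an assumption, not verified
        # against exact_likelihood), set_betas aligns the continuous-time SDE
        # with the discrete DDPM beta schedule from the base class (the VP-SDE
        # is the continuous limit of DDPM), with sde_min_beta as a floor.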

        # set up likelihood function
        self.likelihood_fn = get_likelihood_fn(
            sde,
            hutchinson_type=sde_hutchinson_type,
            rtol=sde_rtol,
            atol=sde_atol,
            eps=sde_eps,
            step_size=sde_step_size,
            method=sde_method,
            continuous=sde_continuous,
            probability_flow=sde_probability_flow,
            predict_epsilon=self.predict_epsilon,
            num_epsilon=sde_num_epsilon,
        )
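
        # Note (assumption based on score_sde_pytorch, not verified here):
        # get_likelihood_fn computes log-likelihoods by integrating the
        # probability-flow ODE associated with the SDE and estimating the
        # divergence of the drift with Hutchinson's trace estimator, using
        # Rademacher (or Gaussian) probe vectors; num_epsilon presumably sets
        # how many probes are averaged, while rtol/atol/step_size/method
        # configure the underlying ODE solver.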

    def get_exact_logprobs(self, cond, samples):
        """Compute exact log-likelihoods of sampled action chunks via torchdiffeq.

        samples: (B, Ta, Da)
        """
        # TODO: image input
        return self.likelihood_fn(
            self.actor,
            self.actor_ft,
            samples,
            self.denoising_steps,
            self.ft_denoising_steps,
            cond=cond,
        )
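
        # Note: inferred from its use in loss() below, the returned tensor
        # holds one exact log-likelihood per batch element, i.e. shape (B,).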

    def loss(
        self,
        obs,
        samples,
        returns,
        oldvalues,
        advantages,
        oldlogprobs,
        use_bc_loss=False,
        **kwargs,
    ):
        """
        PPO loss

        obs: dict with key state/rgb; more recent obs at the end
            state: (B, To, Do)
        samples: (B, Ta, Da)
        returns: (B,)
        oldvalues: (B,)
        advantages: (B,)
        oldlogprobs: (B,)
        """
        # Get new logprobs for the final x; clamp both old and new logprobs to
        # the same range, presumably to keep the importance ratio stable since
        # exact likelihoods can take extreme values
        newlogprobs = self.get_exact_logprobs(obs, samples)
        newlogprobs = newlogprobs.clamp(min=-5, max=2)
        oldlogprobs = oldlogprobs.clamp(min=-5, max=2)

        bc_loss = 0
        if use_bc_loss:
            raise NotImplementedError

        # Importance ratio between the new and old policies
        logratio = newlogprobs - oldlogprobs
        ratio = logratio.exp()

        # Diagnostics: approximate KL divergence and fraction of clipped ratios
        with torch.no_grad():
            # old_approx_kl: (-logratio).mean(), the k1 estimator in John
            # Schulman's blog post on approximating KL:
            # http://joschu.net/blog/kl-approx.html
            # approx_kl: ((ratio - 1) - logratio).mean(), the lower-variance
            # k3 estimator from the same post
            # old_approx_kl = (-logratio).mean()
            approx_kl = ((ratio - 1) - logratio).mean()
            clipfrac = (
                ((ratio - 1.0).abs() > self.clip_ploss_coef).float().mean().item()
            )

        # normalize advantages
        if self.norm_adv:
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        # Policy loss with clipping
        pg_loss1 = -advantages * ratio
        pg_loss2 = -advantages * torch.clamp(
            ratio, 1 - self.clip_ploss_coef, 1 + self.clip_ploss_coef
        )
        pg_loss = torch.max(pg_loss1, pg_loss2).mean()
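
        # This is the standard PPO clipped surrogate (Schulman et al., 2017):
        # pg_loss1/pg_loss2 are the negated objectives, so
        # max(pg_loss1, pg_loss2) = -min(r * A, clip(r, 1 - eps, 1 + eps) * A),
        # and minimizing pg_loss maximizes the clipped objective.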

        # Value loss, optionally with clipping around the old value estimates
        # (as in common PPO implementations, which limit how far new values
        # can move from the old ones per update epoch)
        newvalues = self.critic(obs).view(-1)
        if self.clip_vloss_coef is not None:
            v_loss_unclipped = (newvalues - returns) ** 2
            v_clipped = oldvalues + torch.clamp(
                newvalues - oldvalues,
                -self.clip_vloss_coef,
                self.clip_vloss_coef,
            )
            v_loss_clipped = (v_clipped - returns) ** 2
            v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
            v_loss = 0.5 * v_loss_max.mean()
        else:
            v_loss = 0.5 * ((newvalues - returns) ** 2).mean()

        return (
            pg_loss,
            v_loss,
            clipfrac,
            approx_kl.item(),
            ratio.mean().item(),
            bc_loss,
        )
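
# Minimal usage sketch (illustrative only; `make_sde`, `ppo_diffusion_kwargs`,
# and the batch tensors below are hypothetical stand-ins for this repo's
# actual config and rollout buffers):
#
#   sde = make_sde()  # an SDE object exposing set_betas(), per __init__ above
#   policy = PPOExactDiffusion(sde=sde, **ppo_diffusion_kwargs)
#   logp = policy.get_exact_logprobs(cond=obs, samples=actions)  # (B,)
#   pg_loss, v_loss, clipfrac, approx_kl, ratio, bc_loss = policy.loss(
#       obs, actions, returns, oldvalues, advantages, oldlogprobs
#   )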