dppo/model/diffusion/diffusion_ppo_exact.py
2024-09-11 21:09:17 -04:00

155 lines
4.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Diffusion policy gradient with exact likelihood estimation.
Based on score_sde_pytorch https://github.com/yang-song/score_sde_pytorch
To: observation sequence length
Ta: action chunk size
Do: observation dimension
Da: action dimension
"""
import torch
import logging
log = logging.getLogger(__name__)
from .diffusion_ppo import PPODiffusion
from .exact_likelihood import get_likelihood_fn
class PPOExactDiffusion(PPODiffusion):
    """PPO for diffusion policies using exact action log-likelihoods.

    Unlike the parent class, which accumulates per-denoising-step Gaussian
    log-probs, this variant evaluates log p(action | obs) exactly with an
    ODE-based likelihood estimator (probability-flow ODE, as in
    score_sde_pytorch).
    """

    def __init__(
        self,
        sde,
        sde_hutchinson_type="Rademacher",
        sde_rtol=1e-4,
        sde_atol=1e-4,
        sde_eps=1e-4,
        sde_step_size=1e-3,
        sde_method="RK23",
        sde_continuous=False,
        sde_probability_flow=False,
        sde_num_epsilon=1,
        sde_min_beta=1e-2,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Attach the discrete beta schedule from the parent diffusion model
        # to the SDE, flooring betas at sde_min_beta.
        self.sde = sde
        self.sde.set_betas(self.betas, sde_min_beta)

        # Configure the exact-likelihood estimator. Hutchinson trace
        # estimation (with sde_num_epsilon probe vectors) is used inside the
        # ODE solver; rtol/atol/step_size/method are solver settings.
        estimator_cfg = dict(
            hutchinson_type=sde_hutchinson_type,
            rtol=sde_rtol,
            atol=sde_atol,
            eps=sde_eps,
            step_size=sde_step_size,
            method=sde_method,
            continuous=sde_continuous,
            probability_flow=sde_probability_flow,
            predict_epsilon=self.predict_epsilon,
            num_epsilon=sde_num_epsilon,
        )
        self.likelihood_fn = get_likelihood_fn(sde, **estimator_cfg)

    def get_exact_logprobs(self, cond, samples):
        """Compute exact log-probs of `samples` via torchdiffeq ODE integration.

        samples: (B, Ta, Da)
        """
        # TODO: image input
        networks = (self.actor, self.actor_ft)
        steps = (self.denoising_steps, self.ft_denoising_steps)
        return self.likelihood_fn(*networks, samples, *steps, cond=cond)

    def loss(
        self,
        obs,
        samples,
        returns,
        oldvalues,
        advantages,
        oldlogprobs,
        use_bc_loss=False,
        **kwargs,
    ):
        """
        PPO loss

        obs: dict with key state/rgb; more recent obs at the end
            state: (B, To, Do)
        samples: (B, Ta, Da)
        returns: (B, )
        oldvalues: (B, )
        advantages: (B,)
        oldlogprobs: (B, )
        """
        # Exact log-probs of the sampled actions under the current policy,
        # clamped to the same range as the stored (old) log-probs for
        # numerical stability of the ratio.
        newlogprobs = self.get_exact_logprobs(obs, samples).clamp(min=-5, max=2)
        oldlogprobs = oldlogprobs.clamp(min=-5, max=2)

        bc_loss = 0
        if use_bc_loss:
            raise NotImplementedError

        # Importance ratio between current and behavior policy.
        log_ratio = newlogprobs - oldlogprobs
        ratio = log_ratio.exp()

        with torch.no_grad():
            # k3 estimator of the KL divergence, (r - 1) - log r, per
            # John Schulman's note "Approximating KL Divergence"
            # (http://joschu.net/blog/kl-approx.html). Lower-variance than
            # the k1 estimator (-log_ratio).mean().
            approx_kl = ((ratio - 1) - log_ratio).mean()
            # Fraction of samples whose ratio left the trust region.
            clipfrac = (
                ((ratio - 1.0).abs() > self.clip_ploss_coef).float().mean().item()
            )

        # Batch-normalize advantages if configured.
        if self.norm_adv:
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        # Clipped surrogate objective: pessimistic max over the unclipped and
        # clipped terms (losses are negated advantages, so max = worst case).
        surrogate = -advantages * ratio
        surrogate_clipped = -advantages * ratio.clamp(
            1 - self.clip_ploss_coef, 1 + self.clip_ploss_coef
        )
        pg_loss = torch.max(surrogate, surrogate_clipped).mean()

        # Value loss, optionally with PPO-style clipping around old values.
        newvalues = self.critic(obs).view(-1)
        if self.clip_vloss_coef is None:
            v_loss = 0.5 * ((newvalues - returns) ** 2).mean()
        else:
            v_clipped = oldvalues + (newvalues - oldvalues).clamp(
                -self.clip_vloss_coef, self.clip_vloss_coef
            )
            v_loss = 0.5 * torch.max(
                (newvalues - returns) ** 2,
                (v_clipped - returns) ** 2,
            ).mean()

        # Entropy is maximized - only effective if residual is learned.
        return (
            pg_loss,
            v_loss,
            clipfrac,
            approx_kl.item(),
            ratio.mean().item(),
            bc_loss,
        )