dppo/model/diffusion/diffusion_ppo_exact.py
"""
Diffusion policy gradient with exact likelihood estimation.
Based on score_sde_pytorch https://github.com/yang-song/score_sde_pytorch
To: observation sequence length
Ta: action chunk size
Do: observation dimension
Da: action dimension
"""

import logging

import torch

from .diffusion_vpg import VPGDiffusion
from .exact_likelihood import get_likelihood_fn

log = logging.getLogger(__name__)
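

# PPOExactDiffusion plugs exact diffusion log-likelihoods of sampled action
# chunks (computed via ODE integration in get_exact_logprobs below) into the
# standard PPO clipped-surrogate and value losses.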
class PPOExactDiffusion(VPGDiffusion):

    def __init__(
        self,
        sde,
        clip_ploss_coef,
        clip_vloss_coef=None,
        norm_adv=True,
        sde_hutchinson_type="Rademacher",  # noise for Hutchinson's trace estimator
        sde_rtol=1e-4,  # ODE solver relative tolerance
        sde_atol=1e-4,  # ODE solver absolute tolerance
        sde_eps=1e-4,
        sde_step_size=1e-3,
        sde_method="RK23",  # ODE solver method
        sde_continuous=False,
        sde_probability_flow=False,
        sde_num_epsilon=1,  # number of Hutchinson probe vectors
        sde_min_beta=1e-2,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.sde = sde
        self.sde.set_betas(
            self.betas,
            sde_min_beta,
        )
        self.clip_ploss_coef = clip_ploss_coef
        self.clip_vloss_coef = clip_vloss_coef
        self.norm_adv = norm_adv

        # set up likelihood function
        self.likelihood_fn = get_likelihood_fn(
            sde,
            hutchinson_type=sde_hutchinson_type,
            rtol=sde_rtol,
            atol=sde_atol,
            eps=sde_eps,
            step_size=sde_step_size,
            method=sde_method,
            continuous=sde_continuous,
            probability_flow=sde_probability_flow,
            predict_epsilon=self.predict_epsilon,
            num_epsilon=sde_num_epsilon,
        )

    def get_exact_logprobs(self, cond, samples):
        """Compute exact log-likelihoods of the sampled action chunks by
        integrating the likelihood ODE with torchdiffeq.

        samples: (B, Ta, Da)
        """
        return self.likelihood_fn(
            self.actor,
            self.actor_ft,
            samples,
            self.denoising_steps,
            self.ft_denoising_steps,
            cond=cond,
        )
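
    # Following score_sde_pytorch (see module docstring), the exact likelihood
    # comes from the instantaneous change-of-variables formula for the
    # probability-flow ODE:
    #   log p_0(x(0)) = log p_T(x(T)) + \int_0^T div f(x(t), t) dt
    # with the divergence approximated by Hutchinson's trace estimator,
    # E[eps^T (df/dx) eps], using sde_num_epsilon probe vectors drawn from
    # the sde_hutchinson_type distribution.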

    def loss(
        self,
        obs,
        samples,
        returns,
        oldvalues,
        advantages,
        oldlogprobs,
        use_bc_loss=False,
        **kwargs,
    ):
"""
PPO loss
obs: dict with key state/rgb; more recent obs at the end
state: (B, To, Do)
samples: (B, Ta, Da)
returns: (B, )
values: (B, )
advantages: (B,)
oldlogprobs: (B, )
"""
        # Get new logprobs for the final denoised x; clamp both new and old
        # logprobs for numerical stability of the exponentiated ratio
        newlogprobs = self.get_exact_logprobs(obs, samples)
        newlogprobs = newlogprobs.clamp(min=-5, max=2)
        oldlogprobs = oldlogprobs.clamp(min=-5, max=2)

        bc_loss = 0
        if use_bc_loss:
            raise NotImplementedError

        # get the importance ratio
        logratio = newlogprobs - oldlogprobs
        ratio = logratio.exp()

        # approximate KL and fraction of clipped policy ratios, for logging
        with torch.no_grad():
            # old_approx_kl: the approximate Kullback-Leibler divergence,
            # measured by (-logratio).mean(), which corresponds to the k1
            # estimator in John Schulman's blog post on approximating KL
            # http://joschu.net/blog/kl-approx.html
            # approx_kl: better alternative to old_approx_kl, measured by
            # ((logratio.exp() - 1) - logratio).mean(), which corresponds to
            # the k3 estimator in the same blog post
            # old_approx_kl = (-logratio).mean()
            approx_kl = ((ratio - 1) - logratio).mean()
            clipfrac = (
                ((ratio - 1.0).abs() > self.clip_ploss_coef).float().mean().item()
            )
        # normalize advantages
        if self.norm_adv:
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
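
        # The clipped surrogate below follows PPO (Schulman et al., 2017):
        #   L_clip = -E[min(ratio * A, clip(ratio, 1 - eps, 1 + eps) * A)]
        # with eps = clip_ploss_coef; taking max() of the negated terms is
        # equivalent since we minimize the loss.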
        # Policy loss with clipping
        pg_loss1 = -advantages * ratio
        pg_loss2 = -advantages * torch.clamp(
            ratio, 1 - self.clip_ploss_coef, 1 + self.clip_ploss_coef
        )
        pg_loss = torch.max(pg_loss1, pg_loss2).mean()
        # Value loss optionally with clipping
        newvalues = self.critic(obs).view(-1)
        if self.clip_vloss_coef is not None:
            v_loss_unclipped = (newvalues - returns) ** 2
            v_clipped = oldvalues + torch.clamp(
                newvalues - oldvalues,
                -self.clip_vloss_coef,
                self.clip_vloss_coef,
            )
            v_loss_clipped = (v_clipped - returns) ** 2
            v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
            v_loss = 0.5 * v_loss_max.mean()
        else:
            v_loss = 0.5 * ((newvalues - returns) ** 2).mean()
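
        # Note: clipping the value update (clip_vloss_coef) bounds how far
        # newvalues can move from oldvalues in one update, mirroring the
        # policy-ratio clipping above; a common trick in PPO implementations.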
        # entropy is maximized - only effective if residual is learned
        return (
            pg_loss,
            v_loss,
            clipfrac,
            approx_kl.item(),
            ratio.mean().item(),
            bc_loss,
        )
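

# A minimal usage sketch (hypothetical setup; everything passed through
# **kwargs is configured by VPGDiffusion and is an assumption here):
#
#   policy = PPOExactDiffusion(
#       sde=sde,  # assumed: an SDE object exposing set_betas(betas, min_beta)
#       clip_ploss_coef=0.01,
#       **vpg_kwargs,  # assumed: actor/critic networks, denoising steps, etc.
#   )
#   logprobs = policy.get_exact_logprobs(cond=obs, samples=actions)  # (B,)
#   pg_loss, v_loss, clipfrac, approx_kl, ratio_mean, bc_loss = policy.loss(
#       obs, actions, returns, oldvalues, advantages, oldlogprobs
#   )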