"""
Implicit diffusion Q-learning (IDQL) for diffusion policy.
"""
import logging
import torch
import einops
import copy
import torch.nn.functional as F

log = logging.getLogger(__name__)

from model.diffusion.diffusion_rwr import RWRDiffusion


def expectile_loss(diff, expectile=0.8):
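    """Asymmetric squared loss: weights positive residuals by expectile and
    negative residuals by 1 - expectile (expectile regression)."""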
    weight = torch.where(diff > 0, expectile, (1 - expectile))
    return weight * (diff**2)


def soft_update(target, source, tau):
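    """Polyak-average the source network's parameters into the target network with rate tau."""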
    for target_param, param in zip(target.parameters(), source.parameters()):
        target_param.data.copy_(target_param.data * (1.0 - tau) + param.data * tau)


class IDQLDiffusion(RWRDiffusion):
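    """Diffusion actor with twin Q-critics and a V-critic, trained in the style of IDQL."""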

    def __init__(
        self,
        actor,
        critic_q,
        critic_v,
        **kwargs,
    ):
        super().__init__(network=actor, **kwargs)
        self.critic_q = critic_q.to(self.device)
        self.target_q = copy.deepcopy(critic_q)
        self.critic_v = critic_v.to(self.device)

        # assign actor
        self.actor = self.network

    # ---------- RL training ----------#

    def compute_advantages(self, obs, actions):
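        """Advantage A(s, a) = min(Q1, Q2)(s, a) - V(s), with Q from the target critic."""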
        # get current Q-function, stop gradient
        with torch.no_grad():
            current_q1, current_q2 = self.target_q(obs, actions)
        q = torch.min(current_q1, current_q2)

        # get the current V-function
        v = self.critic_v(obs).reshape(-1)

        # compute advantage
        adv = q - v
        return adv

    def loss_critic_v(self, obs, actions):
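        """Expectile regression loss fitting V(s) toward min(Q1, Q2)(s, a)."""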
        adv = self.compute_advantages(obs, actions)

        # get the value loss
        v_loss = expectile_loss(adv).mean()
        return v_loss

    def loss_critic_q(self, obs, next_obs, actions, rewards, dones, gamma):
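        """TD loss for both Q-heads toward r + gamma * V(s'), masked at terminal states."""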
        # get current Q-function
        current_q1, current_q2 = self.critic_q(obs, actions)

        # get the next V-function, stop gradient
        with torch.no_grad():
            next_v = self.critic_v(next_obs)

        # terminal state mask
        mask = 1 - dones

        # flatten
        rewards = rewards.view(-1)
        next_v = next_v.view(-1)
        mask = mask.view(-1)

        # target value
        discounted_q = rewards + gamma * next_v * mask

        # Update critic
        q_loss = torch.mean((current_q1 - discounted_q) ** 2) + torch.mean(
            (current_q2 - discounted_q) ** 2
        )
        return q_loss

    def update_target_critic(self, tau):
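        """Soft-update the target Q-network toward the current Q-network."""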
        soft_update(self.target_q, self.critic_q, tau)

    # override
    def p_losses(
        self,
        x_start,
        cond,
        t,
    ):
"""not reward-weighted, same as diffusion.py"""
        device = x_start.device

        # Forward process
        noise = torch.randn_like(x_start, device=device)
        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)

        # Predict
        x_recon = self.network(x_noisy, t, cond=cond)

        # Loss
        if self.predict_epsilon:
            loss = F.mse_loss(x_recon, noise, reduction="none")
        else:
            loss = F.mse_loss(x_recon, x_start, reduction="none")
        loss = einops.reduce(loss, "b h d -> b", "mean")
        return loss.mean()

    # ---------- Sampling ----------#

    # override
    @torch.no_grad()
    def forward(
        self,
        cond,
        deterministic=False,
        num_sample=10,
        critic_hyperparam=0.7,  # sampling weight for implicit policy
        use_expectile_exploration=True,
    ):
"""assume state-only, no rgb in cond"""
        # repeat obs num_sample times along dim 0
        cond_shape_repeat_dims = tuple(1 for _ in cond["state"].shape)
        B, T, D = cond["state"].shape
        S = num_sample
        cond_repeat = cond["state"][None].repeat(num_sample, *cond_shape_repeat_dims)
        cond_repeat = cond_repeat.view(-1, T, D)  # [B*S, T, D]
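        # note: after the repeat + view, the sample index varies slowest (row
        # s * B + b holds candidate s for batch element b), which is what makes
        # the later .view(S, B) reshapes of q and v line up correctly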
        # for eval, use less noisy samples -- there is still DDPM noise, but the final action uses a small min_sampling_std
        samples = super(IDQLDiffusion, self).forward(
            {"state": cond_repeat},
            deterministic=deterministic,
        )
        _, H, A = samples.shape

        # get current Q-function
        current_q1, current_q2 = self.target_q({"state": cond_repeat}, samples)
        q = torch.min(current_q1, current_q2)
        q = q.view(S, B)

        # Use argmax
        if deterministic or (not use_expectile_exploration):
            # gather the best sample -- filter out suboptimal Q during inference
            best_indices = q.argmax(0)
            samples_expanded = samples.view(S, B, H, A)

            # dummy dimension @ dim 0 for batched indexing
            sample_indices = best_indices[None, :, None, None]  # [1, B, 1, 1]
            sample_indices = sample_indices.repeat(S, 1, H, A)
            samples_best = torch.gather(samples_expanded, 0, sample_indices)

        # Sample as an implicit policy for exploration
        else:
            # get the current value function for probabilistic exploration
            current_v = self.critic_v({"state": cond_repeat})
            v = current_v.view(S, B)
            adv = q - v

            # Compute weights for sampling
            samples_expanded = samples.view(S, B, H, A)

            # expectile exploration policy
            tau_weights = torch.where(adv > 0, critic_hyperparam, 1 - critic_hyperparam)
            tau_weights = tau_weights / tau_weights.sum(0)  # normalize over samples

            # sample one index per batch element from the expectile weights
            sample_indices = torch.multinomial(tau_weights.T, 1)  # [B, 1]

            # dummy dimension @ dim 0 for batched indexing
            sample_indices = sample_indices[None, :, None]  # [1, B, 1, 1]
            sample_indices = sample_indices.repeat(S, 1, H, A)
            samples_best = torch.gather(samples_expanded, 0, sample_indices)

        # squeeze dummy dimension
        samples = samples_best[0]
        return samples
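

# Usage sketch for the training side, with hypothetical names (`agent`, `opt_v`,
# `opt_q`, `batch`) assumed for illustration; the actor itself is trained with the
# standard unweighted diffusion loss via p_losses above.
#
#   v_loss = agent.loss_critic_v(batch["obs"], batch["actions"])
#   opt_v.zero_grad(); v_loss.backward(); opt_v.step()
#
#   q_loss = agent.loss_critic_q(
#       batch["obs"], batch["next_obs"], batch["actions"],
#       batch["rewards"], batch["dones"], gamma=0.99,
#   )
#   opt_q.zero_grad(); q_loss.backward(); opt_q.step()
#   agent.update_target_critic(tau=0.005)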