"""
|
|
Implicit diffusion Q-learning (IDQL) for diffusion policy.
|
|
|
|
"""
|
|
|
|

import logging
import copy

import torch
import torch.nn.functional as F
import einops

from model.diffusion.diffusion_rwr import RWRDiffusion

log = logging.getLogger(__name__)


def expectile_loss(diff, expectile=0.8):
    weight = torch.where(diff > 0, expectile, (1 - expectile))
    return weight * (diff**2)
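
# Note: this is the asymmetric L2 penalty used in IQL/IDQL,
#   |expectile - 1(diff < 0)| * diff^2.
# Illustrative values with the default expectile=0.8:
#   diff =  2.0  ->  0.8 * 4.0 = 3.2
#   diff = -2.0  ->  0.2 * 4.0 = 0.8
# so errors in one direction are penalized more, pushing the regression
# toward an upper expectile of the target rather than its mean.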


def soft_update(target, source, tau):
    for target_param, param in zip(target.parameters(), source.parameters()):
        target_param.data.copy_(target_param.data * (1.0 - tau) + param.data * tau)
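
# Note: this is the standard Polyak / exponential-moving-average update,
#   target <- (1 - tau) * target + tau * source,
# so a small tau makes the target critic track the online critic slowly.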


class IDQLDiffusion(RWRDiffusion):
    """Diffusion actor with twin Q critics and an expectile-trained V critic (IDQL)."""

    def __init__(
        self,
        actor,
        critic_q,
        critic_v,
        **kwargs,
    ):
        super().__init__(network=actor, **kwargs)
        self.critic_q = critic_q.to(self.device)
        self.target_q = copy.deepcopy(critic_q)
        self.critic_v = critic_v.to(self.device)

        # assign actor
        self.actor = self.network

    # ---------- RL training ----------#

    def compute_advantages(self, obs, actions):

        # get current Q-function, stop gradient
        with torch.no_grad():
            current_q1, current_q2 = self.target_q(obs, actions)
        q = torch.min(current_q1, current_q2)

        # get the current V-function
        v = self.critic_v(obs).reshape(-1)

        # compute advantage
        adv = q - v
        return adv
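
    # Note: the advantage is min(Q1, Q2) - V with gradients stopped through the
    # target Q networks, so losses built on it only update critic_v; forward()
    # recomputes the same quantity to re-weight sampled action candidates.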

    def loss_critic_v(self, obs, actions):
        adv = self.compute_advantages(obs, actions)

        # get the value loss
        v_loss = expectile_loss(adv).mean()
        return v_loss
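
    # Note: with the default expectile of 0.8 this regresses V toward an upper
    # expectile of min(Q1, Q2), i.e. L_V = E[ |0.8 - 1(Q - V < 0)| * (Q - V)^2 ],
    # following IQL-style value learning.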

    def loss_critic_q(self, obs, next_obs, actions, rewards, dones, gamma):

        # get current Q-function
        current_q1, current_q2 = self.critic_q(obs, actions)

        # get the next V-function, stop gradient
        with torch.no_grad():
            next_v = self.critic_v(next_obs)

        # terminal state mask
        mask = 1 - dones

        # flatten
        rewards = rewards.view(-1)
        next_v = next_v.view(-1)
        mask = mask.view(-1)

        # target value
        discounted_q = rewards + gamma * next_v * mask

        # Update critic
        q_loss = torch.mean((current_q1 - discounted_q) ** 2) + torch.mean(
            (current_q2 - discounted_q) ** 2
        )
        return q_loss
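
    # Note: both Q heads regress to the same TD target
    #   y = r + gamma * (1 - done) * V(s'),
    # so no next action needs to be sampled from the diffusion actor.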

    def update_target_critic(self, tau):
        soft_update(self.target_q, self.critic_q, tau)

    # override
    def p_losses(
        self,
        x_start,
        cond,
        t,
    ):
        device = x_start.device

        # Forward process
        noise = torch.randn_like(x_start, device=device)
        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)

        # Predict
        x_recon = self.network(x_noisy, t, cond=cond)

        # Loss with mask
        if self.predict_epsilon:
            loss = F.mse_loss(x_recon, noise, reduction="none")
        else:
            loss = F.mse_loss(x_recon, x_start, reduction="none")
        loss = einops.reduce(loss, "b h d -> b", "mean")
        return loss.mean()
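
    # Note: this override keeps the plain (unweighted) DDPM denoising loss --
    # the network regresses either the injected noise or the clean action
    # sequence, depending on predict_epsilon; the RL signal enters through the
    # critics and the Q-based sample selection in forward() instead.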

    # ---------- Sampling ----------#

    # override
    @torch.no_grad()
    def forward(
        self,
        cond,
        deterministic=False,
        num_sample=10,
        critic_hyperparam=0.7,  # sampling weight for implicit policy
        use_expectile_exploration=True,
    ):
        """assume state-only, no rgb in cond"""
        # repeat obs num_sample times along dim 0
        cond_shape_repeat_dims = tuple(1 for _ in cond["state"].shape)
        B, T, D = cond["state"].shape
        S = num_sample
        cond_repeat = cond["state"][None].repeat(num_sample, *cond_shape_repeat_dims)
        cond_repeat = cond_repeat.view(-1, T, D)  # [B*S, T, D]

        # for eval, use less noisy samples --- there is still DDPM noise, but final action uses small min_sampling_std
        samples = super(IDQLDiffusion, self).forward(
            {"state": cond_repeat},
            deterministic=deterministic,
        )
        _, H, A = samples.shape

        # get current Q-function
        current_q1, current_q2 = self.target_q({"state": cond_repeat}, samples)
        q = torch.min(current_q1, current_q2)
        q = q.view(S, B)
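
        # Candidate selection over the S sampled action sequences per observation:
        # either take the argmax-Q candidate (evaluation), or draw one candidate
        # with probability proportional to critic_hyperparam when its advantage is
        # positive and (1 - critic_hyperparam) otherwise -- the implicit policy
        # used for exploration.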

        # Use argmax
        if deterministic or (not use_expectile_exploration):
            # gather the best sample -- filter out suboptimal Q during inference
            best_indices = q.argmax(0)
            samples_expanded = samples.view(S, B, H, A)

            # dummy dimension @ dim 0 for batched indexing
            sample_indices = best_indices[None, :, None, None]  # [1, B, 1, 1]
            sample_indices = sample_indices.repeat(S, 1, H, A)

            samples_best = torch.gather(samples_expanded, 0, sample_indices)
        # Sample as an implicit policy for exploration
        else:
            # get the current value function for probabilistic exploration
            current_v = self.critic_v({"state": cond_repeat})
            v = current_v.view(S, B)
            adv = q - v

            # Compute weights for sampling
            samples_expanded = samples.view(S, B, H, A)

            # expectile exploration policy
            tau_weights = torch.where(adv > 0, critic_hyperparam, 1 - critic_hyperparam)
            tau_weights = tau_weights / tau_weights.sum(0)  # normalize

            # select a sample from DP probabilistically -- sample index per batch and compile
            sample_indices = torch.multinomial(tau_weights.T, 1)  # [B, 1]

            # dummy dimension @ dim 0 for batched indexing
            sample_indices = sample_indices[None, :, None]  # [1, B, 1, 1]
            sample_indices = sample_indices.repeat(S, 1, H, A)

            samples_best = torch.gather(samples_expanded, 0, sample_indices)

        # squeeze dummy dimension
        samples = samples_best[0]
        return samples
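

# Example usage (illustrative sketch; `policy` and `obs` are assumed to be an
# IDQLDiffusion instance and a [B, cond_steps, obs_dim] state tensor built elsewhere):
#   actions = policy({"state": obs}, deterministic=True, num_sample=10)
#   # actions: [B, horizon, action_dim], the highest-Q candidate per observation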