"""
|
|
Policy gradient with diffusion policy. VPG: vanilla policy gradient
|
|
|
|
K: number of denoising steps
|
|
To: observation sequence length
|
|
Ta: action chunk size
|
|
Do: observation dimension
|
|
Da: action dimension
|
|
|
|
C: image channels
|
|
H, W: image height and width
|
|
|
|
"""
|
|
|
|
import copy
|
|
import torch
|
|
import logging
|
|
|
|
log = logging.getLogger(__name__)
|
|
import torch.nn.functional as F
|
|
|
|
from model.diffusion.diffusion import DiffusionModel, Sample
|
|
from model.diffusion.sampling import make_timesteps, extract
|
|
from torch.distributions import Normal


class VPGDiffusion(DiffusionModel):

    def __init__(
        self,
        actor,
        critic,
        ft_denoising_steps,
        ft_denoising_steps_d=0,
        ft_denoising_steps_t=0,
        network_path=None,
        # modifying denoising schedule
        min_sampling_denoising_std=0.1,
        min_logprob_denoising_std=0.1,
        # eta in DDIM
        eta=None,
        learn_eta=False,
        **kwargs,
    ):
        super().__init__(
            network=actor,
            network_path=network_path,
            **kwargs,
        )
        assert ft_denoising_steps <= self.denoising_steps
        assert ft_denoising_steps <= self.ddim_steps if self.use_ddim else True
        assert not (learn_eta and not self.use_ddim), "Cannot learn eta with DDPM."

        # Number of denoising steps to use with the fine-tuned model. Thus
        # denoising_steps - ft_denoising_steps is the number of denoising steps
        # to use with the original model.
        self.ft_denoising_steps = ft_denoising_steps
        self.ft_denoising_steps_d = ft_denoising_steps_d  # annealing step size
        self.ft_denoising_steps_t = ft_denoising_steps_t  # annealing interval
        self.ft_denoising_steps_cnt = 0

        # Minimum std used in denoising process when sampling action - helps exploration
        self.min_sampling_denoising_std = min_sampling_denoising_std

        # Minimum std used in calculating denoising logprobs - for stability
        self.min_logprob_denoising_std = min_logprob_denoising_std

        # Learnable eta
        self.learn_eta = learn_eta
        if eta is not None:
            self.eta = eta.to(self.device)
            if not learn_eta:
                for param in self.eta.parameters():
                    param.requires_grad = False
                logging.info("Turned off gradients for eta")

        # Re-name network to actor
        self.actor = self.network

        # Make a copy of the original model
        self.actor_ft = copy.deepcopy(self.actor)
        logging.info("Cloned model for fine-tuning")

        # Turn off gradients for original model
        for param in self.actor.parameters():
            param.requires_grad = False
        logging.info("Turned off gradients of the pretrained network")
        logging.info(
            f"Number of finetuned parameters: {sum(p.numel() for p in self.actor_ft.parameters() if p.requires_grad)}"
        )

        # Value function
        self.critic = critic.to(self.device)
        if network_path is not None:
            checkpoint = torch.load(
                network_path, map_location=self.device, weights_only=True
            )
            if "ema" not in checkpoint:  # load trained RL model
                self.load_state_dict(checkpoint["model"], strict=False)
                logging.info("Loaded critic from %s", network_path)

    # ---------- Sampling ----------#

    def step(self):
        """
        Anneal min_sampling_denoising_std and fine-tuning denoising steps.

        Current configs do not apply annealing.
        """
        # anneal min_sampling_denoising_std
        if type(self.min_sampling_denoising_std) is not float:
            self.min_sampling_denoising_std.step()

        # anneal denoising steps
        self.ft_denoising_steps_cnt += 1
        if (
            self.ft_denoising_steps_d > 0
            and self.ft_denoising_steps_t > 0
            and self.ft_denoising_steps_cnt % self.ft_denoising_steps_t == 0
        ):
            self.ft_denoising_steps = max(
                0, self.ft_denoising_steps - self.ft_denoising_steps_d
            )

            # update actor
            self.actor = self.actor_ft
            self.actor_ft = copy.deepcopy(self.actor)
            for param in self.actor.parameters():
                param.requires_grad = False
            logging.info(
                f"Finished annealing fine-tuning denoising steps to {self.ft_denoising_steps}"
            )
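
    # Annealing sketch (illustrative numbers, not from any config): with
    # ft_denoising_steps=10, ft_denoising_steps_d=2, ft_denoising_steps_t=5,
    # every 5th call to step() shrinks the fine-tuned window 10 -> 8 -> 6 -> ... -> 0,
    # each time freezing the current actor_ft as the new base actor.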

    def get_min_sampling_denoising_std(self):
        if type(self.min_sampling_denoising_std) is float:
            return self.min_sampling_denoising_std
        else:
            return self.min_sampling_denoising_std()
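
    # Note: min_sampling_denoising_std is duck-typed - either a plain float or
    # (by assumption, mirroring its use in step() above) a schedule object
    # exposing step() and __call__().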

    # override
    def p_mean_var(
        self,
        x,
        t,
        cond,
        index=None,
        use_base_policy=False,
        deterministic=False,
    ):
        noise = self.actor(x, t, cond=cond)
        if self.use_ddim:
            ft_indices = torch.where(
                index >= (self.ddim_steps - self.ft_denoising_steps)
            )[0]
        else:
            ft_indices = torch.where(t < self.ft_denoising_steps)[0]
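        # e.g. (illustrative): with DDPM, denoising_steps=20 and ft_denoising_steps=10,
        # batch entries at t < 10 are re-denoised by actor_ft below, while entries at
        # t >= 10 keep the frozen pretrained prediction.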

        # Use base policy to query expert model, e.g. for imitation loss
        actor = self.actor if use_base_policy else self.actor_ft

        # overwrite noise for fine-tuning steps
        if len(ft_indices) > 0:
            cond_ft = {key: cond[key][ft_indices] for key in cond}
            noise_ft = actor(x[ft_indices], t[ft_indices], cond=cond_ft)
            noise[ft_indices] = noise_ft

        # Predict x_0
        if self.predict_epsilon:
            if self.use_ddim:
                """
                x₀ = (xₜ - √(1-αₜ) ε) / √αₜ
                """
                alpha = extract(self.ddim_alphas, index, x.shape)
                alpha_prev = extract(self.ddim_alphas_prev, index, x.shape)
                sqrt_one_minus_alpha = extract(
                    self.ddim_sqrt_one_minus_alphas, index, x.shape
                )
                x_recon = (x - sqrt_one_minus_alpha * noise) / (alpha**0.5)
            else:
                """
                x₀ = √(1/α̅ₜ) xₜ - √(1/α̅ₜ - 1) ε
                """
                x_recon = (
                    extract(self.sqrt_recip_alphas_cumprod, t, x.shape) * x
                    - extract(self.sqrt_recipm1_alphas_cumprod, t, x.shape) * noise
                )
        else:  # directly predicting x₀
            x_recon = noise

        if self.denoised_clip_value is not None:
            x_recon.clamp_(-self.denoised_clip_value, self.denoised_clip_value)
            if self.use_ddim:
                # re-calculate noise based on clamped x_recon - default to false in HF, but let's use it here
                noise = (x - alpha ** (0.5) * x_recon) / sqrt_one_minus_alpha

        # Clip epsilon for numerical stability in policy gradient - not sure if this
        # is helpful yet, but the value can be huge sometimes. This has no effect if
        # DDPM is used
        if self.use_ddim and self.eps_clip_value is not None:
            noise.clamp_(-self.eps_clip_value, self.eps_clip_value)

        # Get mu
        if self.use_ddim:
            """
            μ = √αₜ₋₁ x₀ + √(1-αₜ₋₁-σₜ²) ε
            """
            if deterministic:
                etas = torch.zeros((x.shape[0], 1, 1)).to(x.device)
            else:
                etas = self.eta(cond).unsqueeze(1)  # B x 1 x (Da or 1)
            sigma = (
                etas
                * ((1 - alpha_prev) / (1 - alpha) * (1 - alpha / alpha_prev)) ** 0.5
            ).clamp_(min=1e-10)
            dir_xt_coef = (1.0 - alpha_prev - sigma**2).clamp_(min=0).sqrt()
            mu = (alpha_prev**0.5) * x_recon + dir_xt_coef * noise
            var = sigma**2
            logvar = torch.log(var)
        else:
            """
            μₜ = (βₜ √α̅ₜ₋₁ / (1-α̅ₜ)) x₀ + (√αₜ (1-α̅ₜ₋₁) / (1-α̅ₜ)) xₜ
            """
            mu = (
                extract(self.ddpm_mu_coef1, t, x.shape) * x_recon
                + extract(self.ddpm_mu_coef2, t, x.shape) * x
            )
            logvar = extract(self.ddpm_logvar_clipped, t, x.shape)
            etas = torch.ones_like(mu).to(mu.device)  # always one for DDPM
        return mu, logvar, etas
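
    # For reference: sigma above is the DDIM posterior std of Song et al. (2021),
    # σₜ(η) = η √((1-αₜ₋₁)/(1-αₜ)) √(1-αₜ/αₜ₋₁), with η predicted per sample when
    # learn_eta is enabled.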

    # override
    @torch.no_grad()
    def forward(
        self,
        cond,
        deterministic=False,
        return_chain=True,
        use_base_policy=False,
    ):
        """
        Forward pass for sampling actions.

        Args:
            cond: dict with keys state/rgb; more recent obs at the end
                state: (B, To, Do)
                rgb: (B, To, C, H, W)
            deterministic: if True, use std=0 with DDIM; with DDPM, use the normal
                schedule (instead of clipping std at a higher value)
            return_chain: whether to return the entire chain of denoised actions
            use_base_policy: whether to use the frozen pre-trained policy instead
        Return:
            Sample: namedtuple with fields:
                trajectories: (B, Ta, Da)
                chain: (B, K + 1, Ta, Da)
        """
        device = self.betas.device
        sample_data = cond["state"] if "state" in cond else cond["rgb"]
        B = len(sample_data)

        # Get updated minimum sampling denoising std
        min_sampling_denoising_std = self.get_min_sampling_denoising_std()

        # Loop
        x = torch.randn((B, self.horizon_steps, self.action_dim), device=device)
        if self.use_ddim:
            t_all = self.ddim_t
        else:
            t_all = list(reversed(range(self.denoising_steps)))
        chain = [] if return_chain else None
        if not self.use_ddim and self.ft_denoising_steps == self.denoising_steps:
            chain.append(x)
        if self.use_ddim and self.ft_denoising_steps == self.ddim_steps:
            chain.append(x)
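
        # Note: the chain only needs to store inputs to the fine-tuned denoising
        # steps, so the initial noise x_K is recorded above only when every step
        # is fine-tuned.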
        for i, t in enumerate(t_all):
            t_b = make_timesteps(B, t, device)
            index_b = make_timesteps(B, i, device)
            mean, logvar, _ = self.p_mean_var(
                x=x,
                t=t_b,
                cond=cond,
                index=index_b,
                use_base_policy=use_base_policy,
                deterministic=deterministic,
            )
            std = torch.exp(0.5 * logvar)

            # Determine noise level
            if self.use_ddim:
                if deterministic:
                    std = torch.zeros_like(std)
                else:
                    std = torch.clip(std, min=min_sampling_denoising_std)
            else:
                if deterministic and t == 0:
                    std = torch.zeros_like(std)
                elif deterministic:  # still keep the original noise
                    std = torch.clip(std, min=1e-3)
                else:  # use higher minimum noise
                    std = torch.clip(std, min=min_sampling_denoising_std)
            noise = torch.randn_like(x).clamp_(
                -self.randn_clip_value, self.randn_clip_value
            )
            x = mean + std * noise

            # clamp action at final step
            if self.final_action_clip_value is not None and i == len(t_all) - 1:
                x = torch.clamp(
                    x, -self.final_action_clip_value, self.final_action_clip_value
                )

            if return_chain:
                if not self.use_ddim and t <= self.ft_denoising_steps:
                    chain.append(x)
                elif self.use_ddim and i >= (
                    self.ddim_steps - self.ft_denoising_steps - 1
                ):
                    chain.append(x)

        if return_chain:
            chain = torch.stack(chain, dim=1)
        return Sample(x, chain)
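
    # Usage sketch (illustrative; shapes follow this class's conventions):
    #   sample = model(cond={"state": obs})   # obs: (B, To, Do)
    #   action = sample.trajectories          # (B, Ta, Da), executed in the env
    #   chains = sample.chain                 # (B, K'+1, Ta, Da), K' = ft_denoising_steps,
    #                                         # stored for get_logprobs below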

    # ---------- RL training ----------#

    def get_logprobs(
        self,
        cond,
        chains,
        get_ent: bool = False,
        use_base_policy: bool = False,
    ):
        """
        Calculate the logprobs of the entire chain of denoised actions.

        Args:
            cond: dict with keys state/rgb; more recent obs at the end
                state: (B, To, Do)
                rgb: (B, To, C, H, W)
            chains: (B, K+1, Ta, Da)
            get_ent: flag for returning entropy
            use_base_policy: flag for using base policy

        Returns:
            logprobs: (B x K, Ta, Da)
            entropy (if get_ent=True): (B x K, Ta)
        """
        # Repeat cond for denoising_steps, flatten batch and time dimensions
        cond = {
            key: cond[key]
            .unsqueeze(1)
            .repeat(1, self.ft_denoising_steps, *(1,) * (cond[key].ndim - 1))
            .flatten(start_dim=0, end_dim=1)
            for key in cond
        }  # less memory usage than einops?
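        # Shape sketch (illustrative numbers): with B=64 and K = ft_denoising_steps = 10,
        # state (64, To, Do) -> unsqueeze (64, 1, To, Do) -> repeat (64, 10, To, Do)
        # -> flatten (640, To, Do), aligning cond with the flattened chains below.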

        # Repeat t for batch dim, keep it 1-dim
        if self.use_ddim:
            t_single = self.ddim_t[-self.ft_denoising_steps :]
        else:
            t_single = torch.arange(
                start=self.ft_denoising_steps - 1,
                end=-1,
                step=-1,
                device=self.device,
            )
            # e.g., 4,3,2,1,0 with ft_denoising_steps=5; t_all below repeats this
            # per batch element: 4,3,2,1,0,4,3,2,1,0,...,4,3,2,1,0
        t_all = t_single.repeat(chains.shape[0], 1).flatten()
        if self.use_ddim:
            indices_single = torch.arange(
                start=self.ddim_steps - self.ft_denoising_steps,
                end=self.ddim_steps,
                device=self.device,
            )  # only used for DDIM
            indices = indices_single.repeat(chains.shape[0])
        else:
            indices = None

        # Split chains
        chains_prev = chains[:, :-1]
        chains_next = chains[:, 1:]

        # Flatten first two dimensions
        chains_prev = chains_prev.reshape(-1, self.horizon_steps, self.action_dim)
        chains_next = chains_next.reshape(-1, self.horizon_steps, self.action_dim)

        # Forward pass with previous chains
        next_mean, logvar, eta = self.p_mean_var(
            chains_prev,
            t_all,
            cond=cond,
            index=indices,
            use_base_policy=use_base_policy,
        )
        std = torch.exp(0.5 * logvar)
        std = torch.clip(std, min=self.min_logprob_denoising_std)
        dist = Normal(next_mean, std)

        # Get logprobs with gaussian
        log_prob = dist.log_prob(chains_next)
        if get_ent:
            return log_prob, eta
        return log_prob

    def get_logprobs_subsample(
        self,
        cond,
        chains_prev,
        chains_next,
        denoising_inds,
        get_ent: bool = False,
        use_base_policy: bool = False,
    ):
        """
        Calculate the logprobs of random samples of denoised chains.

        Args:
            cond: dict with keys state/rgb; more recent obs at the end
                state: (B, To, Do)
                rgb: (B, To, C, H, W)
            chains_prev: (B, Ta, Da)
            chains_next: (B, Ta, Da)
            denoising_inds: (B, )
            get_ent: flag for returning entropy
            use_base_policy: flag for using base policy

        Returns:
            logprobs: (B, Ta, Da)
            entropy (if get_ent=True): (B, Ta)
        """
        # Sample t for batch dim, keep it 1-dim
        if self.use_ddim:
            t_single = self.ddim_t[-self.ft_denoising_steps :]
        else:
            t_single = torch.arange(
                start=self.ft_denoising_steps - 1,
                end=-1,
                step=-1,
                device=self.device,
            )
            # e.g., 4,3,2,1,0 with ft_denoising_steps=5, indexed by denoising_inds below
        t_all = t_single[denoising_inds]
        if self.use_ddim:
            ddim_indices_single = torch.arange(
                start=self.ddim_steps - self.ft_denoising_steps,
                end=self.ddim_steps,
                device=self.device,
            )  # only used for DDIM
            ddim_indices = ddim_indices_single[denoising_inds]
        else:
            ddim_indices = None

        # Forward pass with previous chains
        next_mean, logvar, eta = self.p_mean_var(
            chains_prev,
            t_all,
            cond=cond,
            index=ddim_indices,
            use_base_policy=use_base_policy,
        )
        std = torch.exp(0.5 * logvar)
        std = torch.clip(std, min=self.min_logprob_denoising_std)
        dist = Normal(next_mean, std)

        # Get logprobs with gaussian
        log_prob = dist.log_prob(chains_next)
        if get_ent:
            return log_prob, eta
        return log_prob
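
    # Callers (e.g., the PPO fine-tuning agent) are expected to subsample the
    # denoising steps themselves; a minimal sketch (hypothetical, not part of
    # this class), given chains: (B, K'+1, Ta, Da):
    #   B = chains.shape[0]
    #   inds = torch.randint(0, model.ft_denoising_steps, (B,))
    #   logp = model.get_logprobs_subsample(
    #       cond,
    #       chains[torch.arange(B), inds],      # x_k      -> chains_prev
    #       chains[torch.arange(B), inds + 1],  # x_{k+1}  -> chains_next
    #       inds,
    #   )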

    def loss(self, cond, chains, reward):
        """
        REINFORCE loss. Not used right now.

        Args:
            cond: dict with keys state/rgb; more recent obs at the end
                state: (B, To, Do)
                rgb: (B, To, C, H, W)
            chains: (B, K+1, Ta, Da)
            reward (to go): (B, )
        """
        # Get advantage
        with torch.no_grad():
            value = self.critic(cond).squeeze()
        advantage = reward - value

        # Get logprobs for denoising steps from T-1 to 0
        logprobs, eta = self.get_logprobs(cond, chains, get_ent=True)
        # (n_steps x n_envs x K) x Ta x Da

        # Sum over action dimension (the slice is a safeguard; logprobs here only
        # covers the action dims)
        logprobs = logprobs[:, :, : self.action_dim].sum(-1)
        # -> (n_steps x n_envs x K) x Ta

        # -> (n_steps x n_envs) x K x Ta
        logprobs = logprobs.reshape((-1, self.denoising_steps, self.horizon_steps))

        # Sum/avg over denoising steps
        logprobs = logprobs.mean(-2)  # -> (n_steps x n_envs) x Ta

        # Sum/avg over horizon steps
        logprobs = logprobs.mean(-1)  # -> (n_steps x n_envs)

        # Get REINFORCE loss
        loss_actor = torch.mean(-logprobs * advantage)

        # Train critic to predict state value
        pred = self.critic(cond).squeeze()
        loss_critic = F.mse_loss(pred, reward)
        return loss_actor, loss_critic, eta
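

# Minimal REINFORCE-style update sketch (illustrative only; the optimizers, data,
# and variable names below are hypothetical and not part of this file):
#   sample = model(cond)                                   # collect a rollout batch
#   loss_actor, loss_critic, _ = model.loss(cond, sample.chain, reward_to_go)
#   actor_opt.zero_grad(); loss_actor.backward(); actor_opt.step()
#   critic_opt.zero_grad(); loss_critic.backward(); critic_opt.step()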