"""
Gaussian diffusion with DDPM and optionally DDIM sampling.

References:
- Diffuser: https://github.com/jannerm/diffuser
- Diffusion Policy: https://github.com/columbia-ai-robotics/diffusion_policy/blob/main/diffusion_policy/policy/diffusion_unet_lowdim_policy.py
- Annotated DDIM/DDPM: https://nn.labml.ai/diffusion/stable_diffusion/sampler/ddpm.html
"""

import logging
|
||
import torch
|
||
from torch import nn
|
||
import torch.nn.functional as F
|
||
|
||
# Logger named after this module.
log = logging.getLogger(name=__name__)

from model.diffusion.sampling import (
|
||
extract,
|
||
cosine_beta_schedule,
|
||
make_timesteps,
|
||
)
|
||
|
||
from collections import namedtuple
|
||
|
||
# Result of DiffusionModel.forward: the final samples in `trajectories`, and
# `chains` (presumably the intermediate denoising chain; forward() sets it to None).
Sample = namedtuple("Sample", ["trajectories", "chains"])


class DiffusionModel(nn.Module):
    """Gaussian diffusion model over action chunks, sampled with DDPM or DDIM.

    ``network`` is called as ``network(x, t, cond=cond)`` and predicts either the
    noise ε (``predict_epsilon=True``) or the clean sample x₀ directly.
    """

    def __init__(
        self,
        network,
        horizon_steps,
        obs_dim,
        action_dim,
        network_path=None,
        device="cuda:0",
        # Various clipping
        denoised_clip_value=1.0,
        randn_clip_value=10,
        final_action_clip_value=None,
        eps_clip_value=None,  # DDIM only
        # DDPM parameters
        denoising_steps=100,
        predict_epsilon=True,
        # DDIM sampling
        use_ddim=False,
        ddim_discretize="uniform",
        ddim_steps=None,
        **kwargs,
    ):
        """Move the network to ``device`` and precompute the DDPM/DDIM schedules.

        Args:
            network: denoiser module, called as ``network(x, t, cond=cond)``.
            horizon_steps: number of action steps per sample (Ta).
            obs_dim: observation dimension; stored, not used by this class.
            action_dim: per-step action dimension (Da).
            network_path: optional checkpoint path. Its ``"ema"`` entry is
                loaded if present, otherwise its ``"model"`` entry; both loads
                are non-strict.
            device: device for the network and all schedule tensors.
            denoised_clip_value: clamp range for the predicted x₀ at every
                denoising step; ``None`` disables the clamp.
            randn_clip_value: clamp range for the standard-normal draw used at
                each sampling step.
            final_action_clip_value: clamp range applied after the last
                denoising step; ``None`` disables the clamp.
            eps_clip_value: clamp range for ε, used only with DDIM; ``None``
                disables it.
            denoising_steps: number of diffusion steps T.
            predict_epsilon: whether the network predicts ε (otherwise x₀).
            use_ddim: sample with DDIM (η = 0) instead of DDPM.
            ddim_discretize: DDIM timestep spacing; only ``"uniform"`` is
                supported.
            ddim_steps: number of DDIM steps; required when ``use_ddim`` is set
                and must lie in ``[1, denoising_steps]``.
            **kwargs: accepted and ignored by this class.

        Raises:
            AssertionError: if ``use_ddim`` is set without ``predict_epsilon``.
            ValueError: if ``ddim_steps`` or ``ddim_discretize`` is invalid.
        """
        super().__init__()
        self.device = device
        self.horizon_steps = horizon_steps
        self.obs_dim = obs_dim
        self.action_dim = action_dim
        self.denoising_steps = int(denoising_steps)
        self.predict_epsilon = predict_epsilon
        self.use_ddim = use_ddim
        self.ddim_steps = ddim_steps

        # Clamp range for the predicted x₀ at each denoising step
        self.denoised_clip_value = denoised_clip_value
        # Clamp range for the sample after the final denoising step
        self.final_action_clip_value = final_action_clip_value
        # At each denoising step, clip the standard-normal draw so the sampled
        # action does not stray too far from the mean
        self.randn_clip_value = randn_clip_value
        # Clip epsilon for numerical stability (DDIM only)
        self.eps_clip_value = eps_clip_value

        # Set up models. Log through the module-named logger rather than the
        # root logger.
        logger = logging.getLogger(__name__)
        self.network = network.to(device)
        if network_path is not None:
            checkpoint = torch.load(
                network_path, map_location=device, weights_only=True
            )
            if "ema" in checkpoint:
                self.load_state_dict(checkpoint["ema"], strict=False)
                logger.info("Loaded SL-trained policy from %s", network_path)
            else:
                self.load_state_dict(checkpoint["model"], strict=False)
                logger.info("Loaded RL-trained policy from %s", network_path)
        logger.info(
            "Number of network parameters: %d",
            sum(p.numel() for p in self.parameters()),
        )

        # ---------- DDPM parameters ----------
        # βₜ (use the int-cast step count so a float config value still works)
        self.betas = cosine_beta_schedule(self.denoising_steps).to(device)
        # αₜ = 1 - βₜ
        self.alphas = 1.0 - self.betas
        # α̅ₜ = ∏ₛ₌₁ᵗ αₛ
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        # α̅ₜ₋₁, with 1 prepended for the first step
        self.alphas_cumprod_prev = torch.cat(
            [torch.ones(1, device=device), self.alphas_cumprod[:-1]]
        )
        # √α̅ₜ
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        # √(1 - α̅ₜ)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod)
        # √(1/α̅ₜ)
        self.sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod)
        # √(1/α̅ₜ - 1)
        self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod - 1)
        # Posterior variance β̃ₜ = σₜ² = βₜ (1 - α̅ₜ₋₁) / (1 - α̅ₜ)
        self.ddpm_var = (
            self.betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        )
        # The first entry is 0 (α̅ₜ₋₁ = 1 there), so clamp before the log
        self.ddpm_logvar_clipped = torch.log(torch.clamp(self.ddpm_var, min=1e-20))
        # Posterior mean μ̃ₜ = coef1 · x₀ + coef2 · xₜ, where
        #   coef1 = βₜ √α̅ₜ₋₁ / (1 - α̅ₜ)
        #   coef2 = √αₜ (1 - α̅ₜ₋₁) / (1 - α̅ₜ)
        self.ddpm_mu_coef1 = (
            self.betas
            * torch.sqrt(self.alphas_cumprod_prev)
            / (1.0 - self.alphas_cumprod)
        )
        self.ddpm_mu_coef2 = (
            (1.0 - self.alphas_cumprod_prev)
            * torch.sqrt(self.alphas)
            / (1.0 - self.alphas_cumprod)
        )

        # ---------- DDIM parameters ----------
        # In the DDIM paper (https://arxiv.org/pdf/2010.02502), "alpha" denotes
        # what is alpha_cumprod in DDPM (https://arxiv.org/pdf/2102.09672).
        if use_ddim:
            assert predict_epsilon, "DDIM requires predicting epsilon for now."
            if ddim_steps is None:
                raise ValueError("ddim_steps must be set when use_ddim=True.")
            n_ddim = int(ddim_steps)
            if not 1 <= n_ddim <= self.denoising_steps:
                # A step ratio of 0 would map every DDIM step to t = 0
                raise ValueError(
                    f"ddim_steps must be in [1, {self.denoising_steps}], "
                    f"got {ddim_steps}."
                )
            if ddim_discretize == "uniform":  # use the HF "leading" style
                step_ratio = self.denoising_steps // n_ddim
                self.ddim_t = (
                    torch.arange(0, n_ddim, device=self.device) * step_ratio
                )
            else:
                raise ValueError(
                    f"Unknown discretization method for DDIM: {ddim_discretize!r}"
                )
            self.ddim_alphas = (
                self.alphas_cumprod[self.ddim_t].clone().to(torch.float32)
            )
            self.ddim_alphas_sqrt = torch.sqrt(self.ddim_alphas)
            self.ddim_alphas_prev = torch.cat(
                [
                    torch.tensor([1.0]).to(torch.float32).to(self.device),
                    self.alphas_cumprod[self.ddim_t[:-1]],
                ]
            )
            self.ddim_sqrt_one_minus_alphas = (1.0 - self.ddim_alphas) ** 0.5

            # Fixed inference sigmas with eta = 0 (deterministic DDIM):
            #   σₜ = η √((1 - αₜ₋₁)/(1 - αₜ) · (1 - αₜ/αₜ₋₁))
            ddim_eta = 0
            self.ddim_sigmas = (
                ddim_eta
                * (
                    (1 - self.ddim_alphas_prev)
                    / (1 - self.ddim_alphas)
                    * (1 - self.ddim_alphas / self.ddim_alphas_prev)
                )
                ** 0.5
            )

            # Flip everything so index 0 is the first (noisiest) sampling step
            self.ddim_t = torch.flip(self.ddim_t, [0])
            self.ddim_alphas = torch.flip(self.ddim_alphas, [0])
            self.ddim_alphas_sqrt = torch.flip(self.ddim_alphas_sqrt, [0])
            self.ddim_alphas_prev = torch.flip(self.ddim_alphas_prev, [0])
            self.ddim_sqrt_one_minus_alphas = torch.flip(
                self.ddim_sqrt_one_minus_alphas, [0]
            )
            self.ddim_sigmas = torch.flip(self.ddim_sigmas, [0])

    # ---------- Sampling ----------#

    def p_mean_var(
        self, x, t, cond, index=None, network_override=None, deterministic=False
    ):
        """Return the mean and log-variance of p(xₜ₋₁ | xₜ).

        Args:
            x: current noisy sample xₜ.
            t: diffusion timesteps, batched.
            cond: conditioning passed to the network as ``cond=``.
            index: position in the flipped DDIM schedule (DDIM only).
            network_override: optional module used instead of ``self.network``.
            deterministic: not used by this implementation. It is accepted so
                that the keyword passed by ``forward`` is valid.

        Returns:
            ``(mu, logvar)``. For DDIM with η = 0, ``logvar`` is ``-inf``.
        """
        net = self.network if network_override is None else network_override
        noise = net(x, t, cond=cond)

        # Predict x₀
        if self.predict_epsilon:
            if self.use_ddim:
                # x₀ = (xₜ - √(1 - αₜ) ε) / √αₜ   (DDIM notation: αₜ ≡ α̅ₜ)
                alpha = extract(self.ddim_alphas, index, x.shape)
                alpha_prev = extract(self.ddim_alphas_prev, index, x.shape)
                sqrt_one_minus_alpha = extract(
                    self.ddim_sqrt_one_minus_alphas, index, x.shape
                )
                x_recon = (x - sqrt_one_minus_alpha * noise) / (alpha**0.5)
            else:
                # x₀ = √(1/α̅ₜ) xₜ - √(1/α̅ₜ - 1) ε
                x_recon = (
                    extract(self.sqrt_recip_alphas_cumprod, t, x.shape) * x
                    - extract(self.sqrt_recipm1_alphas_cumprod, t, x.shape) * noise
                )
        else:  # the network predicts x₀ directly
            x_recon = noise

        # Out-of-place clamps: in-place ops on network outputs can corrupt
        # tensors autograd saved for backward, or mutate aliased buffers.
        if self.denoised_clip_value is not None:
            x_recon = x_recon.clamp(
                -self.denoised_clip_value, self.denoised_clip_value
            )
            if self.use_ddim:
                # Re-derive ε from the clamped x₀. HF defaults this to off; it is
                # used here.
                noise = (x - alpha ** (0.5) * x_recon) / sqrt_one_minus_alpha

        # Clip ε for numerical stability in policy gradient (it can be huge).
        # This has no effect when DDPM is used.
        if self.use_ddim and self.eps_clip_value is not None:
            noise = noise.clamp(-self.eps_clip_value, self.eps_clip_value)

        # Get mu
        if self.use_ddim:
            # μ = √αₜ₋₁ x₀ + √(1 - αₜ₋₁ - σₜ²) ε, with eta = 0
            sigma = extract(self.ddim_sigmas, index, x.shape)
            dir_xt = (1.0 - alpha_prev - sigma**2).sqrt() * noise
            mu = (alpha_prev**0.5) * x_recon + dir_xt
            var = sigma**2
            # σ = 0 when eta = 0, so this is -inf; forward() zeroes the DDIM std anyway
            logvar = torch.log(var)
        else:
            # μ̃ₜ = βₜ √α̅ₜ₋₁/(1 - α̅ₜ) x₀ + √αₜ (1 - α̅ₜ₋₁)/(1 - α̅ₜ) xₜ
            mu = (
                extract(self.ddpm_mu_coef1, t, x.shape) * x_recon
                + extract(self.ddpm_mu_coef2, t, x.shape) * x
            )
            logvar = extract(self.ddpm_logvar_clipped, t, x.shape)
        return mu, logvar

    @torch.no_grad()
    def forward(self, cond, deterministic=True):
        """
        Sample actions by running the full reverse diffusion chain.

        Used to evaluate pre-trained or fine-tuned policies. Diffusion clipping
        settings are not modified.

        Args:
            cond: dict with key state/rgb; more recent obs at the end
                state: (B, To, Do)
                rgb: (B, To, C, H, W)
            deterministic: passed through to ``p_mean_var``. The base
                implementation ignores it, so DDPM sampling still injects noise.

        Return:
            Sample: namedtuple with fields:
                trajectories: (B, Ta, Da), i.e. (B, horizon_steps, action_dim)
                chains: None
        """
        device = self.betas.device
        sample_data = cond["state"] if "state" in cond else cond["rgb"]
        B = len(sample_data)

        # Start from pure Gaussian noise
        x = torch.randn((B, self.horizon_steps, self.action_dim), device=device)
        if self.use_ddim:
            t_all = self.ddim_t
        else:
            t_all = list(reversed(range(self.denoising_steps)))
        last_step = len(t_all) - 1
        for i, t in enumerate(t_all):
            t_b = make_timesteps(B, t, device)
            index_b = make_timesteps(B, i, device)
            mean, logvar = self.p_mean_var(
                x=x,
                t=t_b,
                cond=cond,
                index=index_b,
                deterministic=deterministic,
            )
            std = torch.exp(0.5 * logvar)

            # Noise level: DDIM (eta = 0) and the final DDPM step (t = 0) add no
            # noise. Other DDPM steps floor std at 1e-3.
            if self.use_ddim or t == 0:
                std = torch.zeros_like(std)
            else:
                std = torch.clip(std, min=1e-3)
            # Always draw noise, even when std is 0, to keep the RNG stream unchanged
            noise = torch.randn_like(x).clamp_(
                -self.randn_clip_value, self.randn_clip_value
            )
            x = mean + std * noise

            # Clamp the action at the final step
            if self.final_action_clip_value is not None and i == last_step:
                x = torch.clamp(
                    x, -self.final_action_clip_value, self.final_action_clip_value
                )
        return Sample(x, None)

    # ---------- Supervised training ----------#

    def loss(self, x, *args):
        """Draw a uniform random timestep per sample and return the denoising loss.

        Args:
            x: clean samples x₀; the first dimension is the batch.
            *args: passed to ``p_losses`` ahead of ``t`` (i.e. ``cond``).
        """
        batch_size = len(x)
        t = torch.randint(
            0, self.denoising_steps, (batch_size,), device=x.device
        ).long()
        return self.p_losses(x, *args, t)

    def p_losses(
        self,
        x_start,
        cond: dict,
        t,
    ):
        """
        Mean-squared denoising loss at timesteps ``t``.

        If predicting epsilon: E_{t, x0, ε} [||ε - ε_θ(√α̅ₜx0 + √(1-α̅ₜ)ε, t)||²].
        Otherwise the regression target is x0 itself.

        Args:
            x_start: (batch_size, horizon_steps, action_dim)
            cond: observation dict passed to the network as ``cond=``
            t: batch of integers
        """
        # Forward process
        noise = torch.randn_like(x_start)
        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)

        # Predict
        x_recon = self.network(x_noisy, t, cond=cond)
        target = noise if self.predict_epsilon else x_start
        return F.mse_loss(x_recon, target, reduction="mean")

    def q_sample(self, x_start, t, noise=None):
        """
        Sample from the forward process.

        q(xₜ | x₀) = 𝒩(xₜ; √α̅ₜ x₀, (1 - α̅ₜ)I)
        xₜ = √α̅ₜ x₀ + √(1 - α̅ₜ) ε

        ``noise`` is drawn from a standard normal if not given.
        """
        if noise is None:
            noise = torch.randn_like(x_start)
        return (
            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
            + extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
        )