# dppo/model/diffusion/diffusion.py
"""
Gaussian diffusion with DDPM and optionally DDIM sampling.
References:
Diffuser: https://github.com/jannerm/diffuser
Diffusion Policy: https://github.com/columbia-ai-robotics/diffusion_policy/blob/main/diffusion_policy/policy/diffusion_unet_lowdim_policy.py
Annotated DDIM/DDPM: https://nn.labml.ai/diffusion/stable_diffusion/sampler/ddpm.html
"""
import logging
import torch
from torch import nn
import torch.nn.functional as F
log = logging.getLogger(__name__)
from model.diffusion.sampling import (
extract,
cosine_beta_schedule,
make_timesteps,
)
from collections import namedtuple
Sample = namedtuple("Sample", "trajectories chains")
class DiffusionModel(nn.Module):
    """Gaussian diffusion over action trajectories with DDPM/DDIM sampling.

    Denoises trajectories of shape (B, horizon_steps, action_dim) conditioned
    on observations. The network may predict either the noise epsilon
    (predict_epsilon=True) or x0 directly. Sampling is DDPM ancestral sampling
    by default, or deterministic DDIM (eta = 0) when use_ddim=True.
    """

    def __init__(
        self,
        network,
        horizon_steps,
        obs_dim,
        action_dim,
        network_path=None,
        device="cuda:0",
        # Various clipping
        denoised_clip_value=1.0,
        randn_clip_value=10,
        final_action_clip_value=None,
        eps_clip_value=None,  # DDIM only
        # DDPM parameters
        denoising_steps=100,
        predict_epsilon=True,
        # DDIM sampling
        use_ddim=False,
        ddim_discretize="uniform",
        ddim_steps=None,
        **kwargs,
    ):
        """
        Args:
            network: prediction network, called as network(x, t, cond=cond)
            horizon_steps: number of action steps per sampled trajectory
            obs_dim: observation dimension (stored for reference)
            action_dim: per-step action dimension
            network_path: optional checkpoint path; loads "ema" weights if
                present, else "model" weights (strict=False either way)
            device: torch device for the network and schedule tensors
            denoised_clip_value: clamp predicted x0 into [-v, v] each step
            randn_clip_value: clamp the sampled standard-normal noise
            final_action_clip_value: optional clamp of the final action
            eps_clip_value: optional clamp of predicted epsilon (DDIM only)
            denoising_steps: number of DDPM diffusion steps T
            predict_epsilon: network predicts epsilon if True, else x0
            use_ddim: use DDIM sampling instead of DDPM
            ddim_discretize: DDIM timestep discretization ("uniform" only)
            ddim_steps: number of DDIM sampling steps (required if use_ddim)

        Raises:
            ValueError: if ddim_discretize is not a supported method.
        """
        super().__init__()
        self.device = device
        self.horizon_steps = horizon_steps
        self.obs_dim = obs_dim
        self.action_dim = action_dim
        self.denoising_steps = int(denoising_steps)
        self.predict_epsilon = predict_epsilon
        self.use_ddim = use_ddim
        self.ddim_steps = ddim_steps

        # Clip predicted x0 at each denoising step
        self.denoised_clip_value = denoised_clip_value

        # Whether to clamp the final sampled action between [-1, 1]
        self.final_action_clip_value = final_action_clip_value

        # For each denoising step, clip sampled randn (from standard deviation)
        # such that the sampled action is not too far away from the mean
        self.randn_clip_value = randn_clip_value

        # Clip epsilon for numerical stability
        self.eps_clip_value = eps_clip_value

        # Set up models
        self.network = network.to(device)
        if network_path is not None:
            checkpoint = torch.load(
                network_path, map_location=device, weights_only=True
            )
            if "ema" in checkpoint:
                self.load_state_dict(checkpoint["ema"], strict=False)
                # Use the module-level logger (was the root logger) for
                # consistency with `log = logging.getLogger(__name__)`
                log.info("Loaded SL-trained policy from %s", network_path)
            else:
                self.load_state_dict(checkpoint["model"], strict=False)
                log.info("Loaded RL-trained policy from %s", network_path)
        log.info(
            "Number of network parameters: %d",
            sum(p.numel() for p in self.parameters()),
        )

        # ---------- DDPM parameters ----------#
        # beta_t: noise schedule
        self.betas = cosine_beta_schedule(denoising_steps).to(device)
        # alpha_t = 1 - beta_t
        self.alphas = 1.0 - self.betas
        # alpha_bar_t = prod_{s=1..t} alpha_s
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        # alpha_bar_{t-1}, with alpha_bar_0 = 1
        self.alphas_cumprod_prev = torch.cat(
            [torch.ones(1).to(device), self.alphas_cumprod[:-1]]
        )
        # sqrt(alpha_bar_t)
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        # sqrt(1 - alpha_bar_t)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod)
        # sqrt(1 / alpha_bar_t)
        self.sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod)
        # sqrt(1 / alpha_bar_t - 1)
        self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod - 1)
        # Posterior variance:
        # beta_tilde_t = sigma_t^2 = beta_t (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t)
        self.ddpm_var = (
            self.betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        )
        # Clamp before log: ddpm_var is 0 at t = 0
        self.ddpm_logvar_clipped = torch.log(torch.clamp(self.ddpm_var, min=1e-20))
        # Posterior mean mu_t = coef1 * x0 + coef2 * x_t, where
        #   coef1 = beta_t sqrt(alpha_bar_{t-1}) / (1 - alpha_bar_t)
        #   coef2 = sqrt(alpha_t) (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t)
        self.ddpm_mu_coef1 = (
            self.betas
            * torch.sqrt(self.alphas_cumprod_prev)
            / (1.0 - self.alphas_cumprod)
        )
        self.ddpm_mu_coef2 = (
            (1.0 - self.alphas_cumprod_prev)
            * torch.sqrt(self.alphas)
            / (1.0 - self.alphas_cumprod)
        )

        # ---------- DDIM parameters ----------#
        # In the DDIM paper https://arxiv.org/pdf/2010.02502, "alpha" denotes
        # alpha_cumprod of DDPM https://arxiv.org/pdf/2102.09672
        if use_ddim:
            assert predict_epsilon, "DDIM requires predicting epsilon for now."
            if ddim_discretize == "uniform":  # use the HF "leading" style
                step_ratio = self.denoising_steps // ddim_steps
                self.ddim_t = (
                    torch.arange(0, ddim_steps, device=self.device) * step_ratio
                )
            else:
                # Fix: `raise "..."` (raising a str) is a TypeError in Python 3
                raise ValueError("Unknown discretization method for DDIM.")
            self.ddim_alphas = (
                self.alphas_cumprod[self.ddim_t].clone().to(torch.float32)
            )
            self.ddim_alphas_sqrt = torch.sqrt(self.ddim_alphas)
            self.ddim_alphas_prev = torch.cat(
                [
                    torch.tensor([1.0]).to(torch.float32).to(self.device),
                    self.alphas_cumprod[self.ddim_t[:-1]],
                ]
            )
            self.ddim_sqrt_one_minus_alphas = (1.0 - self.ddim_alphas) ** 0.5

            # Initialize fixed sigmas for inference - eta=0 (deterministic DDIM)
            ddim_eta = 0
            self.ddim_sigmas = (
                ddim_eta
                * (
                    (1 - self.ddim_alphas_prev)
                    / (1 - self.ddim_alphas)
                    * (1 - self.ddim_alphas / self.ddim_alphas_prev)
                )
                ** 0.5
            )

            # Flip all schedules so index 0 is the largest (first) timestep
            self.ddim_t = torch.flip(self.ddim_t, [0])
            self.ddim_alphas = torch.flip(self.ddim_alphas, [0])
            self.ddim_alphas_sqrt = torch.flip(self.ddim_alphas_sqrt, [0])
            self.ddim_alphas_prev = torch.flip(self.ddim_alphas_prev, [0])
            self.ddim_sqrt_one_minus_alphas = torch.flip(
                self.ddim_sqrt_one_minus_alphas, [0]
            )
            self.ddim_sigmas = torch.flip(self.ddim_sigmas, [0])

    # ---------- Sampling ----------#

    def p_mean_var(self, x, t, cond, index=None, network_override=None):
        """Compute mean and log-variance of the reverse step p(x_{t-1} | x_t).

        Args:
            x: (B, Ta, Da) current noisy action trajectory
            t: (B,) DDPM timesteps
            cond: conditioning dict, passed through to the network
            index: (B,) indices into the flipped DDIM schedules
                (required when use_ddim)
            network_override: optional alternative network to query

        Returns:
            (mu, logvar) tensors broadcastable to x's shape.
        """
        if network_override is not None:
            noise = network_override(x, t, cond=cond)
        else:
            noise = self.network(x, t, cond=cond)

        # Predict x_0
        if self.predict_epsilon:
            if self.use_ddim:
                # x0 = (x_t - sqrt(1 - alpha_t) * eps) / sqrt(alpha_t)
                alpha = extract(self.ddim_alphas, index, x.shape)
                alpha_prev = extract(self.ddim_alphas_prev, index, x.shape)
                sqrt_one_minus_alpha = extract(
                    self.ddim_sqrt_one_minus_alphas, index, x.shape
                )
                x_recon = (x - sqrt_one_minus_alpha * noise) / (alpha**0.5)
            else:
                # x0 = sqrt(1/alpha_bar_t) x_t - sqrt(1/alpha_bar_t - 1) eps
                x_recon = (
                    extract(self.sqrt_recip_alphas_cumprod, t, x.shape) * x
                    - extract(self.sqrt_recipm1_alphas_cumprod, t, x.shape) * noise
                )
        else:  # directly predicting x_0
            x_recon = noise
        if self.denoised_clip_value is not None:
            x_recon.clamp_(-self.denoised_clip_value, self.denoised_clip_value)
            if self.use_ddim:
                # re-calculate noise based on clamped x_recon - default to
                # false in HF, but let's use it here
                noise = (x - alpha ** (0.5) * x_recon) / sqrt_one_minus_alpha

        # Clip epsilon for numerical stability in policy gradient - not sure if
        # this is helpful yet, but the value can be huge sometimes. This has no
        # effect if DDPM is used
        if self.use_ddim and self.eps_clip_value is not None:
            noise.clamp_(-self.eps_clip_value, self.eps_clip_value)

        # Get mu
        if self.use_ddim:
            # mu = sqrt(alpha_{t-1}) x0 + sqrt(1 - alpha_{t-1} - sigma_t^2) eps
            # with eta = 0, sigma is 0 and the step is deterministic
            sigma = extract(self.ddim_sigmas, index, x.shape)
            dir_xt = (1.0 - alpha_prev - sigma**2).sqrt() * noise
            mu = (alpha_prev**0.5) * x_recon + dir_xt
            var = sigma**2
            logvar = torch.log(var)
        else:
            # mu_t = coef1 * x0 + coef2 * x_t (DDPM posterior mean)
            mu = (
                extract(self.ddpm_mu_coef1, t, x.shape) * x_recon
                + extract(self.ddpm_mu_coef2, t, x.shape) * x
            )
            logvar = extract(self.ddpm_logvar_clipped, t, x.shape)
        return mu, logvar

    @torch.no_grad()
    def forward(self, cond, deterministic=True):
        """
        Forward pass for sampling actions. Used in evaluating
        pre-trained/fine-tuned policy. Not modifying diffusion clipping.

        Args:
            cond: dict with key state/rgb; more recent obs at the end
                state: (B, To, Do)
                rgb: (B, To, C, H, W)
            deterministic: kept for interface compatibility; unused here
                (DDIM is already eta=0 and DDPM noise handling is explicit)

        Return:
            Sample: namedtuple with fields:
                trajectories: (B, Ta, Da)
        """
        device = self.betas.device
        sample_data = cond["state"] if "state" in cond else cond["rgb"]
        B = len(sample_data)

        # Start from standard normal noise
        x = torch.randn((B, self.horizon_steps, self.action_dim), device=device)

        # Denoising loop
        if self.use_ddim:
            t_all = self.ddim_t
        else:
            t_all = list(reversed(range(self.denoising_steps)))
        for i, t in enumerate(t_all):
            t_b = make_timesteps(B, t, device)
            index_b = make_timesteps(B, i, device)
            # Fix: the original passed `deterministic=deterministic`, but
            # p_mean_var takes no such parameter (TypeError on every call)
            mean, logvar = self.p_mean_var(
                x=x,
                t=t_b,
                cond=cond,
                index=index_b,
            )
            std = torch.exp(0.5 * logvar)

            # Determine noise level
            if self.use_ddim:
                std = torch.zeros_like(std)  # eta = 0: deterministic
            else:
                if t == 0:
                    # no noise added at the final denoising step
                    std = torch.zeros_like(std)
                else:
                    std = torch.clip(std, min=1e-3)
            noise = torch.randn_like(x).clamp_(
                -self.randn_clip_value, self.randn_clip_value
            )
            x = mean + std * noise

            # clamp action at final step
            if self.final_action_clip_value is not None and i == len(t_all) - 1:
                x = torch.clamp(
                    x, -self.final_action_clip_value, self.final_action_clip_value
                )
        return Sample(x, None)

    # ---------- Supervised training ----------#

    def loss(self, x, *args):
        """Training loss with timesteps sampled uniformly from [0, T).

        Args:
            x: (batch_size, horizon_steps, action_dim) clean action batch
            *args: forwarded to p_losses (the conditioning dict)
        """
        batch_size = len(x)
        t = torch.randint(
            0, self.denoising_steps, (batch_size,), device=x.device
        ).long()
        return self.p_losses(x, *args, t)

    def p_losses(
        self,
        x_start,
        cond: dict,
        t,
    ):
        """
        If predicting epsilon:
        E_{t, x0, eps} [||eps - eps_theta(sqrt(alpha_bar_t) x0 + sqrt(1 - alpha_bar_t) eps, t)||^2]

        Args:
            x_start: (batch_size, horizon_steps, action_dim)
            cond: dict with keys as step and value as observation
            t: batch of integers
        """
        device = x_start.device

        # Forward process: corrupt x_start to x_t
        noise = torch.randn_like(x_start, device=device)
        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)

        # Predict (epsilon or x0, depending on self.predict_epsilon)
        x_recon = self.network(x_noisy, t, cond=cond)
        if self.predict_epsilon:
            return F.mse_loss(x_recon, noise, reduction="mean")
        else:
            return F.mse_loss(x_recon, x_start, reduction="mean")

    def q_sample(self, x_start, t, noise=None):
        """
        q(x_t | x_0) = N(x_t; sqrt(alpha_bar_t) x_0, (1 - alpha_bar_t) I)
        x_t = sqrt(alpha_bar_t) x_0 + sqrt(1 - alpha_bar_t) eps

        Args:
            x_start: (batch_size, horizon_steps, action_dim) clean sample
            t: batch of integer timesteps
            noise: optional pre-sampled standard-normal noise; drawn here
                if omitted
        """
        if noise is None:
            device = x_start.device
            noise = torch.randn_like(x_start, device=device)
        return (
            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
            + extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
        )