fix diffusion loss when predicting initial noise

This commit is contained in:
allenzren 2024-10-13 11:19:10 -04:00
parent 4e14b8086d
commit 7b10df690d

View File

@ -22,6 +22,7 @@ from model.diffusion.sampling import (
)
from collections import namedtuple
Sample = namedtuple("Sample", "trajectories chains")
@ -45,7 +46,7 @@ class DiffusionModel(nn.Module):
predict_epsilon=True,
# DDIM sampling
use_ddim=False,
ddim_discretize='uniform',
ddim_discretize="uniform",
ddim_steps=None,
**kwargs,
):
@ -74,7 +75,9 @@ class DiffusionModel(nn.Module):
# Set up models
self.network = network.to(device)
if network_path is not None:
checkpoint = torch.load(network_path, map_location=device, weights_only=True)
checkpoint = torch.load(
network_path, map_location=device, weights_only=True
)
if "ema" in checkpoint:
self.load_state_dict(checkpoint["ema"], strict=False)
logging.info("Loaded SL-trained policy from %s", network_path)
@ -104,7 +107,9 @@ class DiffusionModel(nn.Module):
"""
α̅
"""
self.alphas_cumprod_prev = torch.cat([torch.ones(1).to(device), self.alphas_cumprod[:-1]])
self.alphas_cumprod_prev = torch.cat(
[torch.ones(1).to(device), self.alphas_cumprod[:-1]]
)
"""
α̅
"""
@ -131,8 +136,16 @@ class DiffusionModel(nn.Module):
"""
μₜ = β̃ α̅/(1-α̅)x₀ + αₜ (1-α̅)/(1-α̅)xₜ
"""
self.ddpm_mu_coef1 = self.betas * torch.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
self.ddpm_mu_coef2 = (1.0 - self.alphas_cumprod_prev) * torch.sqrt(self.alphas) / (1.0 - self.alphas_cumprod)
self.ddpm_mu_coef1 = (
self.betas
* torch.sqrt(self.alphas_cumprod_prev)
/ (1.0 - self.alphas_cumprod)
)
self.ddpm_mu_coef2 = (
(1.0 - self.alphas_cumprod_prev)
* torch.sqrt(self.alphas)
/ (1.0 - self.alphas_cumprod)
)
"""
DDIM parameters
@ -141,30 +154,45 @@ class DiffusionModel(nn.Module):
"""
if use_ddim:
assert predict_epsilon, "DDIM requires predicting epsilon for now."
if ddim_discretize == 'uniform': # use the HF "leading" style
if ddim_discretize == "uniform": # use the HF "leading" style
step_ratio = self.denoising_steps // ddim_steps
self.ddim_t = torch.arange(0, ddim_steps, device=self.device) * step_ratio
self.ddim_t = (
torch.arange(0, ddim_steps, device=self.device) * step_ratio
)
else:
raise 'Unknown discretization method for DDIM.'
self.ddim_alphas = self.alphas_cumprod[self.ddim_t].clone().to(torch.float32)
raise "Unknown discretization method for DDIM."
self.ddim_alphas = (
self.alphas_cumprod[self.ddim_t].clone().to(torch.float32)
)
self.ddim_alphas_sqrt = torch.sqrt(self.ddim_alphas)
self.ddim_alphas_prev = torch.cat([
torch.tensor([1.]).to(torch.float32).to(self.device),
self.alphas_cumprod[self.ddim_t[:-1]]])
self.ddim_sqrt_one_minus_alphas = (1. - self.ddim_alphas) ** .5
self.ddim_alphas_prev = torch.cat(
[
torch.tensor([1.0]).to(torch.float32).to(self.device),
self.alphas_cumprod[self.ddim_t[:-1]],
]
)
self.ddim_sqrt_one_minus_alphas = (1.0 - self.ddim_alphas) ** 0.5
# Initialize fixed sigmas for inference - eta=0
ddim_eta = 0
self.ddim_sigmas = (ddim_eta * \
((1 - self.ddim_alphas_prev) / (1 - self.ddim_alphas) * \
(1 - self.ddim_alphas / self.ddim_alphas_prev)) ** .5)
self.ddim_sigmas = (
ddim_eta
* (
(1 - self.ddim_alphas_prev)
/ (1 - self.ddim_alphas)
* (1 - self.ddim_alphas / self.ddim_alphas_prev)
)
** 0.5
)
# Flip all
self.ddim_t = torch.flip(self.ddim_t, [0])
self.ddim_alphas = torch.flip(self.ddim_alphas, [0])
self.ddim_alphas_sqrt = torch.flip(self.ddim_alphas_sqrt, [0])
self.ddim_alphas_prev = torch.flip(self.ddim_alphas_prev, [0])
self.ddim_sqrt_one_minus_alphas = torch.flip(self.ddim_sqrt_one_minus_alphas, [0])
self.ddim_sqrt_one_minus_alphas = torch.flip(
self.ddim_sqrt_one_minus_alphas, [0]
)
self.ddim_sigmas = torch.flip(self.ddim_sigmas, [0])
# ---------- Sampling ----------#
@ -183,8 +211,10 @@ class DiffusionModel(nn.Module):
"""
alpha = extract(self.ddim_alphas, index, x.shape)
alpha_prev = extract(self.ddim_alphas_prev, index, x.shape)
sqrt_one_minus_alpha = extract(self.ddim_sqrt_one_minus_alphas, index, x.shape)
x_recon = (x - sqrt_one_minus_alpha * noise) / (alpha ** 0.5)
sqrt_one_minus_alpha = extract(
self.ddim_sqrt_one_minus_alphas, index, x.shape
)
x_recon = (x - sqrt_one_minus_alpha * noise) / (alpha**0.5)
else:
"""
x₀ = 1\α̅ xₜ - 1\α̅-1 ε
@ -193,7 +223,7 @@ class DiffusionModel(nn.Module):
extract(self.sqrt_recip_alphas_cumprod, t, x.shape) * x
- extract(self.sqrt_recipm1_alphas_cumprod, t, x.shape) * noise
)
else: # directly predicting x₀
else: # directly predicting x₀
x_recon = noise
if self.denoised_clip_value is not None:
x_recon.clamp_(-self.denoised_clip_value, self.denoised_clip_value)
@ -208,14 +238,14 @@ class DiffusionModel(nn.Module):
# Get mu
if self.use_ddim:
"""
μ = αₜ x₀ + (1-αₜ - σₜ²) ε
μ = αₜ x₀ + (1-αₜ - σₜ²) ε
eta=0
"""
sigma = extract(self.ddim_sigmas, index, x.shape)
dir_xt = (1. - alpha_prev - sigma ** 2).sqrt() * noise
mu = (alpha_prev ** 0.5) * x_recon + dir_xt
var = sigma ** 2
dir_xt = (1.0 - alpha_prev - sigma**2).sqrt() * noise
mu = (alpha_prev**0.5) * x_recon + dir_xt
var = sigma**2
logvar = torch.log(var)
else:
"""
@ -225,9 +255,7 @@ class DiffusionModel(nn.Module):
extract(self.ddpm_mu_coef1, t, x.shape) * x_recon
+ extract(self.ddpm_mu_coef2, t, x.shape) * x
)
logvar = extract(
self.ddpm_logvar_clipped, t, x.shape
)
logvar = extract(self.ddpm_logvar_clipped, t, x.shape)
return mu, logvar
@torch.no_grad()
@ -279,7 +307,9 @@ class DiffusionModel(nn.Module):
# clamp action at final step
if self.final_action_clip_value is not None and i == len(t_all) - 1:
x = torch.clamp(x, -self.final_action_clip_value, self.final_action_clip_value)
x = torch.clamp(
x, -self.final_action_clip_value, self.final_action_clip_value
)
return Sample(x, None)
# ---------- Supervised training ----------#
@ -314,9 +344,9 @@ class DiffusionModel(nn.Module):
# Predict
x_recon = self.network(x_noisy, t, cond=cond)
if self.predict_epsilon:
return F.mse_loss(x_recon, noise, reduction="mean")
return F.mse_loss(x_recon, noise, reduction="mean")
else:
return F.mse_loss(x_recon, x_noisy, reduction="mean")
return F.mse_loss(x_recon, x_start, reduction="mean")
def q_sample(self, x_start, t, noise=None):
"""