fix diffusion loss when predicting initial noise
This commit is contained in:
parent
4e14b8086d
commit
7b10df690d
@ -22,6 +22,7 @@ from model.diffusion.sampling import (
|
||||
)
|
||||
|
||||
from collections import namedtuple
|
||||
|
||||
Sample = namedtuple("Sample", "trajectories chains")
|
||||
|
||||
|
||||
@ -45,7 +46,7 @@ class DiffusionModel(nn.Module):
|
||||
predict_epsilon=True,
|
||||
# DDIM sampling
|
||||
use_ddim=False,
|
||||
ddim_discretize='uniform',
|
||||
ddim_discretize="uniform",
|
||||
ddim_steps=None,
|
||||
**kwargs,
|
||||
):
|
||||
@ -74,7 +75,9 @@ class DiffusionModel(nn.Module):
|
||||
# Set up models
|
||||
self.network = network.to(device)
|
||||
if network_path is not None:
|
||||
checkpoint = torch.load(network_path, map_location=device, weights_only=True)
|
||||
checkpoint = torch.load(
|
||||
network_path, map_location=device, weights_only=True
|
||||
)
|
||||
if "ema" in checkpoint:
|
||||
self.load_state_dict(checkpoint["ema"], strict=False)
|
||||
logging.info("Loaded SL-trained policy from %s", network_path)
|
||||
@ -104,7 +107,9 @@ class DiffusionModel(nn.Module):
|
||||
"""
|
||||
α̅ₜ₋₁
|
||||
"""
|
||||
self.alphas_cumprod_prev = torch.cat([torch.ones(1).to(device), self.alphas_cumprod[:-1]])
|
||||
self.alphas_cumprod_prev = torch.cat(
|
||||
[torch.ones(1).to(device), self.alphas_cumprod[:-1]]
|
||||
)
|
||||
"""
|
||||
√ α̅ₜ
|
||||
"""
|
||||
@ -131,8 +136,16 @@ class DiffusionModel(nn.Module):
|
||||
"""
|
||||
μₜ = β̃ₜ √ α̅ₜ₋₁/(1-α̅ₜ)x₀ + √ αₜ (1-α̅ₜ₋₁)/(1-α̅ₜ)xₜ
|
||||
"""
|
||||
self.ddpm_mu_coef1 = self.betas * torch.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
|
||||
self.ddpm_mu_coef2 = (1.0 - self.alphas_cumprod_prev) * torch.sqrt(self.alphas) / (1.0 - self.alphas_cumprod)
|
||||
self.ddpm_mu_coef1 = (
|
||||
self.betas
|
||||
* torch.sqrt(self.alphas_cumprod_prev)
|
||||
/ (1.0 - self.alphas_cumprod)
|
||||
)
|
||||
self.ddpm_mu_coef2 = (
|
||||
(1.0 - self.alphas_cumprod_prev)
|
||||
* torch.sqrt(self.alphas)
|
||||
/ (1.0 - self.alphas_cumprod)
|
||||
)
|
||||
|
||||
"""
|
||||
DDIM parameters
|
||||
@ -141,30 +154,45 @@ class DiffusionModel(nn.Module):
|
||||
"""
|
||||
if use_ddim:
|
||||
assert predict_epsilon, "DDIM requires predicting epsilon for now."
|
||||
if ddim_discretize == 'uniform': # use the HF "leading" style
|
||||
if ddim_discretize == "uniform": # use the HF "leading" style
|
||||
step_ratio = self.denoising_steps // ddim_steps
|
||||
self.ddim_t = torch.arange(0, ddim_steps, device=self.device) * step_ratio
|
||||
self.ddim_t = (
|
||||
torch.arange(0, ddim_steps, device=self.device) * step_ratio
|
||||
)
|
||||
else:
|
||||
raise 'Unknown discretization method for DDIM.'
|
||||
self.ddim_alphas = self.alphas_cumprod[self.ddim_t].clone().to(torch.float32)
|
||||
raise "Unknown discretization method for DDIM."
|
||||
self.ddim_alphas = (
|
||||
self.alphas_cumprod[self.ddim_t].clone().to(torch.float32)
|
||||
)
|
||||
self.ddim_alphas_sqrt = torch.sqrt(self.ddim_alphas)
|
||||
self.ddim_alphas_prev = torch.cat([
|
||||
torch.tensor([1.]).to(torch.float32).to(self.device),
|
||||
self.alphas_cumprod[self.ddim_t[:-1]]])
|
||||
self.ddim_sqrt_one_minus_alphas = (1. - self.ddim_alphas) ** .5
|
||||
self.ddim_alphas_prev = torch.cat(
|
||||
[
|
||||
torch.tensor([1.0]).to(torch.float32).to(self.device),
|
||||
self.alphas_cumprod[self.ddim_t[:-1]],
|
||||
]
|
||||
)
|
||||
self.ddim_sqrt_one_minus_alphas = (1.0 - self.ddim_alphas) ** 0.5
|
||||
|
||||
# Initialize fixed sigmas for inference - eta=0
|
||||
ddim_eta = 0
|
||||
self.ddim_sigmas = (ddim_eta * \
|
||||
((1 - self.ddim_alphas_prev) / (1 - self.ddim_alphas) * \
|
||||
(1 - self.ddim_alphas / self.ddim_alphas_prev)) ** .5)
|
||||
self.ddim_sigmas = (
|
||||
ddim_eta
|
||||
* (
|
||||
(1 - self.ddim_alphas_prev)
|
||||
/ (1 - self.ddim_alphas)
|
||||
* (1 - self.ddim_alphas / self.ddim_alphas_prev)
|
||||
)
|
||||
** 0.5
|
||||
)
|
||||
|
||||
# Flip all
|
||||
self.ddim_t = torch.flip(self.ddim_t, [0])
|
||||
self.ddim_alphas = torch.flip(self.ddim_alphas, [0])
|
||||
self.ddim_alphas_sqrt = torch.flip(self.ddim_alphas_sqrt, [0])
|
||||
self.ddim_alphas_prev = torch.flip(self.ddim_alphas_prev, [0])
|
||||
self.ddim_sqrt_one_minus_alphas = torch.flip(self.ddim_sqrt_one_minus_alphas, [0])
|
||||
self.ddim_sqrt_one_minus_alphas = torch.flip(
|
||||
self.ddim_sqrt_one_minus_alphas, [0]
|
||||
)
|
||||
self.ddim_sigmas = torch.flip(self.ddim_sigmas, [0])
|
||||
|
||||
# ---------- Sampling ----------#
|
||||
@ -183,8 +211,10 @@ class DiffusionModel(nn.Module):
|
||||
"""
|
||||
alpha = extract(self.ddim_alphas, index, x.shape)
|
||||
alpha_prev = extract(self.ddim_alphas_prev, index, x.shape)
|
||||
sqrt_one_minus_alpha = extract(self.ddim_sqrt_one_minus_alphas, index, x.shape)
|
||||
x_recon = (x - sqrt_one_minus_alpha * noise) / (alpha ** 0.5)
|
||||
sqrt_one_minus_alpha = extract(
|
||||
self.ddim_sqrt_one_minus_alphas, index, x.shape
|
||||
)
|
||||
x_recon = (x - sqrt_one_minus_alpha * noise) / (alpha**0.5)
|
||||
else:
|
||||
"""
|
||||
x₀ = √ 1\α̅ₜ xₜ - √ 1\α̅ₜ-1 ε
|
||||
@ -193,7 +223,7 @@ class DiffusionModel(nn.Module):
|
||||
extract(self.sqrt_recip_alphas_cumprod, t, x.shape) * x
|
||||
- extract(self.sqrt_recipm1_alphas_cumprod, t, x.shape) * noise
|
||||
)
|
||||
else: # directly predicting x₀
|
||||
else: # directly predicting x₀
|
||||
x_recon = noise
|
||||
if self.denoised_clip_value is not None:
|
||||
x_recon.clamp_(-self.denoised_clip_value, self.denoised_clip_value)
|
||||
@ -208,14 +238,14 @@ class DiffusionModel(nn.Module):
|
||||
# Get mu
|
||||
if self.use_ddim:
|
||||
"""
|
||||
μ = √ αₜ₋₁ x₀ + √(1-αₜ₋₁ - σₜ²) ε
|
||||
|
||||
μ = √ αₜ₋₁ x₀ + √(1-αₜ₋₁ - σₜ²) ε
|
||||
|
||||
eta=0
|
||||
"""
|
||||
sigma = extract(self.ddim_sigmas, index, x.shape)
|
||||
dir_xt = (1. - alpha_prev - sigma ** 2).sqrt() * noise
|
||||
mu = (alpha_prev ** 0.5) * x_recon + dir_xt
|
||||
var = sigma ** 2
|
||||
dir_xt = (1.0 - alpha_prev - sigma**2).sqrt() * noise
|
||||
mu = (alpha_prev**0.5) * x_recon + dir_xt
|
||||
var = sigma**2
|
||||
logvar = torch.log(var)
|
||||
else:
|
||||
"""
|
||||
@ -225,9 +255,7 @@ class DiffusionModel(nn.Module):
|
||||
extract(self.ddpm_mu_coef1, t, x.shape) * x_recon
|
||||
+ extract(self.ddpm_mu_coef2, t, x.shape) * x
|
||||
)
|
||||
logvar = extract(
|
||||
self.ddpm_logvar_clipped, t, x.shape
|
||||
)
|
||||
logvar = extract(self.ddpm_logvar_clipped, t, x.shape)
|
||||
return mu, logvar
|
||||
|
||||
@torch.no_grad()
|
||||
@ -279,7 +307,9 @@ class DiffusionModel(nn.Module):
|
||||
|
||||
# clamp action at final step
|
||||
if self.final_action_clip_value is not None and i == len(t_all) - 1:
|
||||
x = torch.clamp(x, -self.final_action_clip_value, self.final_action_clip_value)
|
||||
x = torch.clamp(
|
||||
x, -self.final_action_clip_value, self.final_action_clip_value
|
||||
)
|
||||
return Sample(x, None)
|
||||
|
||||
# ---------- Supervised training ----------#
|
||||
@ -314,9 +344,9 @@ class DiffusionModel(nn.Module):
|
||||
# Predict
|
||||
x_recon = self.network(x_noisy, t, cond=cond)
|
||||
if self.predict_epsilon:
|
||||
return F.mse_loss(x_recon, noise, reduction="mean")
|
||||
return F.mse_loss(x_recon, noise, reduction="mean")
|
||||
else:
|
||||
return F.mse_loss(x_recon, x_noisy, reduction="mean")
|
||||
return F.mse_loss(x_recon, x_start, reduction="mean")
|
||||
|
||||
def q_sample(self, x_start, t, noise=None):
|
||||
"""
|
||||
|
Loading…
Reference in New Issue
Block a user