Fix diffusion loss when predicting x_start instead of noise (predict_epsilon=False): regress against x_start, not x_noisy
This commit is contained in:
parent
4e14b8086d
commit
7b10df690d
@ -22,6 +22,7 @@ from model.diffusion.sampling import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
|
|
||||||
Sample = namedtuple("Sample", "trajectories chains")
|
Sample = namedtuple("Sample", "trajectories chains")
|
||||||
|
|
||||||
|
|
||||||
@ -45,7 +46,7 @@ class DiffusionModel(nn.Module):
|
|||||||
predict_epsilon=True,
|
predict_epsilon=True,
|
||||||
# DDIM sampling
|
# DDIM sampling
|
||||||
use_ddim=False,
|
use_ddim=False,
|
||||||
ddim_discretize='uniform',
|
ddim_discretize="uniform",
|
||||||
ddim_steps=None,
|
ddim_steps=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
@ -74,7 +75,9 @@ class DiffusionModel(nn.Module):
|
|||||||
# Set up models
|
# Set up models
|
||||||
self.network = network.to(device)
|
self.network = network.to(device)
|
||||||
if network_path is not None:
|
if network_path is not None:
|
||||||
checkpoint = torch.load(network_path, map_location=device, weights_only=True)
|
checkpoint = torch.load(
|
||||||
|
network_path, map_location=device, weights_only=True
|
||||||
|
)
|
||||||
if "ema" in checkpoint:
|
if "ema" in checkpoint:
|
||||||
self.load_state_dict(checkpoint["ema"], strict=False)
|
self.load_state_dict(checkpoint["ema"], strict=False)
|
||||||
logging.info("Loaded SL-trained policy from %s", network_path)
|
logging.info("Loaded SL-trained policy from %s", network_path)
|
||||||
@ -104,7 +107,9 @@ class DiffusionModel(nn.Module):
|
|||||||
"""
|
"""
|
||||||
α̅ₜ₋₁
|
α̅ₜ₋₁
|
||||||
"""
|
"""
|
||||||
self.alphas_cumprod_prev = torch.cat([torch.ones(1).to(device), self.alphas_cumprod[:-1]])
|
self.alphas_cumprod_prev = torch.cat(
|
||||||
|
[torch.ones(1).to(device), self.alphas_cumprod[:-1]]
|
||||||
|
)
|
||||||
"""
|
"""
|
||||||
√ α̅ₜ
|
√ α̅ₜ
|
||||||
"""
|
"""
|
||||||
@ -131,8 +136,16 @@ class DiffusionModel(nn.Module):
|
|||||||
"""
|
"""
|
||||||
μₜ = βₜ √α̅ₜ₋₁ / (1-α̅ₜ) · x₀ + √αₜ (1-α̅ₜ₋₁) / (1-α̅ₜ) · xₜ
|
μₜ = βₜ √α̅ₜ₋₁ / (1-α̅ₜ) · x₀ + √αₜ (1-α̅ₜ₋₁) / (1-α̅ₜ) · xₜ
|
||||||
"""
|
"""
|
||||||
self.ddpm_mu_coef1 = self.betas * torch.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
|
self.ddpm_mu_coef1 = (
|
||||||
self.ddpm_mu_coef2 = (1.0 - self.alphas_cumprod_prev) * torch.sqrt(self.alphas) / (1.0 - self.alphas_cumprod)
|
self.betas
|
||||||
|
* torch.sqrt(self.alphas_cumprod_prev)
|
||||||
|
/ (1.0 - self.alphas_cumprod)
|
||||||
|
)
|
||||||
|
self.ddpm_mu_coef2 = (
|
||||||
|
(1.0 - self.alphas_cumprod_prev)
|
||||||
|
* torch.sqrt(self.alphas)
|
||||||
|
/ (1.0 - self.alphas_cumprod)
|
||||||
|
)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
DDIM parameters
|
DDIM parameters
|
||||||
@ -141,30 +154,45 @@ class DiffusionModel(nn.Module):
|
|||||||
"""
|
"""
|
||||||
if use_ddim:
|
if use_ddim:
|
||||||
assert predict_epsilon, "DDIM requires predicting epsilon for now."
|
assert predict_epsilon, "DDIM requires predicting epsilon for now."
|
||||||
if ddim_discretize == 'uniform': # use the HF "leading" style
|
if ddim_discretize == "uniform": # use the HF "leading" style
|
||||||
step_ratio = self.denoising_steps // ddim_steps
|
step_ratio = self.denoising_steps // ddim_steps
|
||||||
self.ddim_t = torch.arange(0, ddim_steps, device=self.device) * step_ratio
|
self.ddim_t = (
|
||||||
|
torch.arange(0, ddim_steps, device=self.device) * step_ratio
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise 'Unknown discretization method for DDIM.'
|
raise "Unknown discretization method for DDIM."
|
||||||
self.ddim_alphas = self.alphas_cumprod[self.ddim_t].clone().to(torch.float32)
|
self.ddim_alphas = (
|
||||||
|
self.alphas_cumprod[self.ddim_t].clone().to(torch.float32)
|
||||||
|
)
|
||||||
self.ddim_alphas_sqrt = torch.sqrt(self.ddim_alphas)
|
self.ddim_alphas_sqrt = torch.sqrt(self.ddim_alphas)
|
||||||
self.ddim_alphas_prev = torch.cat([
|
self.ddim_alphas_prev = torch.cat(
|
||||||
torch.tensor([1.]).to(torch.float32).to(self.device),
|
[
|
||||||
self.alphas_cumprod[self.ddim_t[:-1]]])
|
torch.tensor([1.0]).to(torch.float32).to(self.device),
|
||||||
self.ddim_sqrt_one_minus_alphas = (1. - self.ddim_alphas) ** .5
|
self.alphas_cumprod[self.ddim_t[:-1]],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
self.ddim_sqrt_one_minus_alphas = (1.0 - self.ddim_alphas) ** 0.5
|
||||||
|
|
||||||
# Initialize fixed sigmas for inference - eta=0
|
# Initialize fixed sigmas for inference - eta=0
|
||||||
ddim_eta = 0
|
ddim_eta = 0
|
||||||
self.ddim_sigmas = (ddim_eta * \
|
self.ddim_sigmas = (
|
||||||
((1 - self.ddim_alphas_prev) / (1 - self.ddim_alphas) * \
|
ddim_eta
|
||||||
(1 - self.ddim_alphas / self.ddim_alphas_prev)) ** .5)
|
* (
|
||||||
|
(1 - self.ddim_alphas_prev)
|
||||||
|
/ (1 - self.ddim_alphas)
|
||||||
|
* (1 - self.ddim_alphas / self.ddim_alphas_prev)
|
||||||
|
)
|
||||||
|
** 0.5
|
||||||
|
)
|
||||||
|
|
||||||
# Flip all
|
# Flip all
|
||||||
self.ddim_t = torch.flip(self.ddim_t, [0])
|
self.ddim_t = torch.flip(self.ddim_t, [0])
|
||||||
self.ddim_alphas = torch.flip(self.ddim_alphas, [0])
|
self.ddim_alphas = torch.flip(self.ddim_alphas, [0])
|
||||||
self.ddim_alphas_sqrt = torch.flip(self.ddim_alphas_sqrt, [0])
|
self.ddim_alphas_sqrt = torch.flip(self.ddim_alphas_sqrt, [0])
|
||||||
self.ddim_alphas_prev = torch.flip(self.ddim_alphas_prev, [0])
|
self.ddim_alphas_prev = torch.flip(self.ddim_alphas_prev, [0])
|
||||||
self.ddim_sqrt_one_minus_alphas = torch.flip(self.ddim_sqrt_one_minus_alphas, [0])
|
self.ddim_sqrt_one_minus_alphas = torch.flip(
|
||||||
|
self.ddim_sqrt_one_minus_alphas, [0]
|
||||||
|
)
|
||||||
self.ddim_sigmas = torch.flip(self.ddim_sigmas, [0])
|
self.ddim_sigmas = torch.flip(self.ddim_sigmas, [0])
|
||||||
|
|
||||||
# ---------- Sampling ----------#
|
# ---------- Sampling ----------#
|
||||||
@ -183,8 +211,10 @@ class DiffusionModel(nn.Module):
|
|||||||
"""
|
"""
|
||||||
alpha = extract(self.ddim_alphas, index, x.shape)
|
alpha = extract(self.ddim_alphas, index, x.shape)
|
||||||
alpha_prev = extract(self.ddim_alphas_prev, index, x.shape)
|
alpha_prev = extract(self.ddim_alphas_prev, index, x.shape)
|
||||||
sqrt_one_minus_alpha = extract(self.ddim_sqrt_one_minus_alphas, index, x.shape)
|
sqrt_one_minus_alpha = extract(
|
||||||
x_recon = (x - sqrt_one_minus_alpha * noise) / (alpha ** 0.5)
|
self.ddim_sqrt_one_minus_alphas, index, x.shape
|
||||||
|
)
|
||||||
|
x_recon = (x - sqrt_one_minus_alpha * noise) / (alpha**0.5)
|
||||||
else:
|
else:
|
||||||
"""
|
"""
|
||||||
x₀ = √(1/α̅ₜ) xₜ - √(1/α̅ₜ - 1) ε
|
x₀ = √(1/α̅ₜ) xₜ - √(1/α̅ₜ - 1) ε
|
||||||
@ -213,9 +243,9 @@ class DiffusionModel(nn.Module):
|
|||||||
eta=0
|
eta=0
|
||||||
"""
|
"""
|
||||||
sigma = extract(self.ddim_sigmas, index, x.shape)
|
sigma = extract(self.ddim_sigmas, index, x.shape)
|
||||||
dir_xt = (1. - alpha_prev - sigma ** 2).sqrt() * noise
|
dir_xt = (1.0 - alpha_prev - sigma**2).sqrt() * noise
|
||||||
mu = (alpha_prev ** 0.5) * x_recon + dir_xt
|
mu = (alpha_prev**0.5) * x_recon + dir_xt
|
||||||
var = sigma ** 2
|
var = sigma**2
|
||||||
logvar = torch.log(var)
|
logvar = torch.log(var)
|
||||||
else:
|
else:
|
||||||
"""
|
"""
|
||||||
@ -225,9 +255,7 @@ class DiffusionModel(nn.Module):
|
|||||||
extract(self.ddpm_mu_coef1, t, x.shape) * x_recon
|
extract(self.ddpm_mu_coef1, t, x.shape) * x_recon
|
||||||
+ extract(self.ddpm_mu_coef2, t, x.shape) * x
|
+ extract(self.ddpm_mu_coef2, t, x.shape) * x
|
||||||
)
|
)
|
||||||
logvar = extract(
|
logvar = extract(self.ddpm_logvar_clipped, t, x.shape)
|
||||||
self.ddpm_logvar_clipped, t, x.shape
|
|
||||||
)
|
|
||||||
return mu, logvar
|
return mu, logvar
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
@ -279,7 +307,9 @@ class DiffusionModel(nn.Module):
|
|||||||
|
|
||||||
# clamp action at final step
|
# clamp action at final step
|
||||||
if self.final_action_clip_value is not None and i == len(t_all) - 1:
|
if self.final_action_clip_value is not None and i == len(t_all) - 1:
|
||||||
x = torch.clamp(x, -self.final_action_clip_value, self.final_action_clip_value)
|
x = torch.clamp(
|
||||||
|
x, -self.final_action_clip_value, self.final_action_clip_value
|
||||||
|
)
|
||||||
return Sample(x, None)
|
return Sample(x, None)
|
||||||
|
|
||||||
# ---------- Supervised training ----------#
|
# ---------- Supervised training ----------#
|
||||||
@ -316,7 +346,7 @@ class DiffusionModel(nn.Module):
|
|||||||
if self.predict_epsilon:
|
if self.predict_epsilon:
|
||||||
return F.mse_loss(x_recon, noise, reduction="mean")
|
return F.mse_loss(x_recon, noise, reduction="mean")
|
||||||
else:
|
else:
|
||||||
return F.mse_loss(x_recon, x_noisy, reduction="mean")
|
return F.mse_loss(x_recon, x_start, reduction="mean")
|
||||||
|
|
||||||
def q_sample(self, x_start, t, noise=None):
|
def q_sample(self, x_start, t, noise=None):
|
||||||
"""
|
"""
|
||||||
|
Loading…
Reference in New Issue
Block a user