fix diffusion loss when predicting initial noise

This commit is contained in:
allenzren 2024-10-13 11:19:10 -04:00
parent 4e14b8086d
commit 7b10df690d

View File

@ -22,6 +22,7 @@ from model.diffusion.sampling import (
) )
from collections import namedtuple from collections import namedtuple
Sample = namedtuple("Sample", "trajectories chains") Sample = namedtuple("Sample", "trajectories chains")
@ -45,7 +46,7 @@ class DiffusionModel(nn.Module):
predict_epsilon=True, predict_epsilon=True,
# DDIM sampling # DDIM sampling
use_ddim=False, use_ddim=False,
ddim_discretize='uniform', ddim_discretize="uniform",
ddim_steps=None, ddim_steps=None,
**kwargs, **kwargs,
): ):
@ -74,7 +75,9 @@ class DiffusionModel(nn.Module):
# Set up models # Set up models
self.network = network.to(device) self.network = network.to(device)
if network_path is not None: if network_path is not None:
checkpoint = torch.load(network_path, map_location=device, weights_only=True) checkpoint = torch.load(
network_path, map_location=device, weights_only=True
)
if "ema" in checkpoint: if "ema" in checkpoint:
self.load_state_dict(checkpoint["ema"], strict=False) self.load_state_dict(checkpoint["ema"], strict=False)
logging.info("Loaded SL-trained policy from %s", network_path) logging.info("Loaded SL-trained policy from %s", network_path)
@ -104,7 +107,9 @@ class DiffusionModel(nn.Module):
""" """
α̅ α̅
""" """
self.alphas_cumprod_prev = torch.cat([torch.ones(1).to(device), self.alphas_cumprod[:-1]]) self.alphas_cumprod_prev = torch.cat(
[torch.ones(1).to(device), self.alphas_cumprod[:-1]]
)
""" """
α̅ α̅
""" """
@ -131,8 +136,16 @@ class DiffusionModel(nn.Module):
""" """
μₜ = β̃ α̅/(1-α̅)x₀ + αₜ (1-α̅)/(1-α̅)xₜ μₜ = β̃ α̅/(1-α̅)x₀ + αₜ (1-α̅)/(1-α̅)xₜ
""" """
self.ddpm_mu_coef1 = self.betas * torch.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) self.ddpm_mu_coef1 = (
self.ddpm_mu_coef2 = (1.0 - self.alphas_cumprod_prev) * torch.sqrt(self.alphas) / (1.0 - self.alphas_cumprod) self.betas
* torch.sqrt(self.alphas_cumprod_prev)
/ (1.0 - self.alphas_cumprod)
)
self.ddpm_mu_coef2 = (
(1.0 - self.alphas_cumprod_prev)
* torch.sqrt(self.alphas)
/ (1.0 - self.alphas_cumprod)
)
""" """
DDIM parameters DDIM parameters
@ -141,30 +154,45 @@ class DiffusionModel(nn.Module):
""" """
if use_ddim: if use_ddim:
assert predict_epsilon, "DDIM requires predicting epsilon for now." assert predict_epsilon, "DDIM requires predicting epsilon for now."
if ddim_discretize == 'uniform': # use the HF "leading" style if ddim_discretize == "uniform": # use the HF "leading" style
step_ratio = self.denoising_steps // ddim_steps step_ratio = self.denoising_steps // ddim_steps
self.ddim_t = torch.arange(0, ddim_steps, device=self.device) * step_ratio self.ddim_t = (
torch.arange(0, ddim_steps, device=self.device) * step_ratio
)
else: else:
raise 'Unknown discretization method for DDIM.' raise "Unknown discretization method for DDIM."
self.ddim_alphas = self.alphas_cumprod[self.ddim_t].clone().to(torch.float32) self.ddim_alphas = (
self.alphas_cumprod[self.ddim_t].clone().to(torch.float32)
)
self.ddim_alphas_sqrt = torch.sqrt(self.ddim_alphas) self.ddim_alphas_sqrt = torch.sqrt(self.ddim_alphas)
self.ddim_alphas_prev = torch.cat([ self.ddim_alphas_prev = torch.cat(
torch.tensor([1.]).to(torch.float32).to(self.device), [
self.alphas_cumprod[self.ddim_t[:-1]]]) torch.tensor([1.0]).to(torch.float32).to(self.device),
self.ddim_sqrt_one_minus_alphas = (1. - self.ddim_alphas) ** .5 self.alphas_cumprod[self.ddim_t[:-1]],
]
)
self.ddim_sqrt_one_minus_alphas = (1.0 - self.ddim_alphas) ** 0.5
# Initialize fixed sigmas for inference - eta=0 # Initialize fixed sigmas for inference - eta=0
ddim_eta = 0 ddim_eta = 0
self.ddim_sigmas = (ddim_eta * \ self.ddim_sigmas = (
((1 - self.ddim_alphas_prev) / (1 - self.ddim_alphas) * \ ddim_eta
(1 - self.ddim_alphas / self.ddim_alphas_prev)) ** .5) * (
(1 - self.ddim_alphas_prev)
/ (1 - self.ddim_alphas)
* (1 - self.ddim_alphas / self.ddim_alphas_prev)
)
** 0.5
)
# Flip all # Flip all
self.ddim_t = torch.flip(self.ddim_t, [0]) self.ddim_t = torch.flip(self.ddim_t, [0])
self.ddim_alphas = torch.flip(self.ddim_alphas, [0]) self.ddim_alphas = torch.flip(self.ddim_alphas, [0])
self.ddim_alphas_sqrt = torch.flip(self.ddim_alphas_sqrt, [0]) self.ddim_alphas_sqrt = torch.flip(self.ddim_alphas_sqrt, [0])
self.ddim_alphas_prev = torch.flip(self.ddim_alphas_prev, [0]) self.ddim_alphas_prev = torch.flip(self.ddim_alphas_prev, [0])
self.ddim_sqrt_one_minus_alphas = torch.flip(self.ddim_sqrt_one_minus_alphas, [0]) self.ddim_sqrt_one_minus_alphas = torch.flip(
self.ddim_sqrt_one_minus_alphas, [0]
)
self.ddim_sigmas = torch.flip(self.ddim_sigmas, [0]) self.ddim_sigmas = torch.flip(self.ddim_sigmas, [0])
# ---------- Sampling ----------# # ---------- Sampling ----------#
@ -183,8 +211,10 @@ class DiffusionModel(nn.Module):
""" """
alpha = extract(self.ddim_alphas, index, x.shape) alpha = extract(self.ddim_alphas, index, x.shape)
alpha_prev = extract(self.ddim_alphas_prev, index, x.shape) alpha_prev = extract(self.ddim_alphas_prev, index, x.shape)
sqrt_one_minus_alpha = extract(self.ddim_sqrt_one_minus_alphas, index, x.shape) sqrt_one_minus_alpha = extract(
x_recon = (x - sqrt_one_minus_alpha * noise) / (alpha ** 0.5) self.ddim_sqrt_one_minus_alphas, index, x.shape
)
x_recon = (x - sqrt_one_minus_alpha * noise) / (alpha**0.5)
else: else:
""" """
x₀ = 1\α̅ xₜ - 1\α̅-1 ε x₀ = 1\α̅ xₜ - 1\α̅-1 ε
@ -213,9 +243,9 @@ class DiffusionModel(nn.Module):
eta=0 eta=0
""" """
sigma = extract(self.ddim_sigmas, index, x.shape) sigma = extract(self.ddim_sigmas, index, x.shape)
dir_xt = (1. - alpha_prev - sigma ** 2).sqrt() * noise dir_xt = (1.0 - alpha_prev - sigma**2).sqrt() * noise
mu = (alpha_prev ** 0.5) * x_recon + dir_xt mu = (alpha_prev**0.5) * x_recon + dir_xt
var = sigma ** 2 var = sigma**2
logvar = torch.log(var) logvar = torch.log(var)
else: else:
""" """
@ -225,9 +255,7 @@ class DiffusionModel(nn.Module):
extract(self.ddpm_mu_coef1, t, x.shape) * x_recon extract(self.ddpm_mu_coef1, t, x.shape) * x_recon
+ extract(self.ddpm_mu_coef2, t, x.shape) * x + extract(self.ddpm_mu_coef2, t, x.shape) * x
) )
logvar = extract( logvar = extract(self.ddpm_logvar_clipped, t, x.shape)
self.ddpm_logvar_clipped, t, x.shape
)
return mu, logvar return mu, logvar
@torch.no_grad() @torch.no_grad()
@ -279,7 +307,9 @@ class DiffusionModel(nn.Module):
# clamp action at final step # clamp action at final step
if self.final_action_clip_value is not None and i == len(t_all) - 1: if self.final_action_clip_value is not None and i == len(t_all) - 1:
x = torch.clamp(x, -self.final_action_clip_value, self.final_action_clip_value) x = torch.clamp(
x, -self.final_action_clip_value, self.final_action_clip_value
)
return Sample(x, None) return Sample(x, None)
# ---------- Supervised training ----------# # ---------- Supervised training ----------#
@ -316,7 +346,7 @@ class DiffusionModel(nn.Module):
if self.predict_epsilon: if self.predict_epsilon:
return F.mse_loss(x_recon, noise, reduction="mean") return F.mse_loss(x_recon, noise, reduction="mean")
else: else:
return F.mse_loss(x_recon, x_noisy, reduction="mean") return F.mse_loss(x_recon, x_start, reduction="mean")
def q_sample(self, x_start, t, noise=None): def q_sample(self, x_start, t, noise=None):
""" """