diff --git a/model/diffusion/diffusion.py b/model/diffusion/diffusion.py
index 9d08c30..f2ad1aa 100644
--- a/model/diffusion/diffusion.py
+++ b/model/diffusion/diffusion.py
@@ -22,6 +22,7 @@ from model.diffusion.sampling import (
 )
 from collections import namedtuple
+
 
 Sample = namedtuple("Sample", "trajectories chains")
 
@@ -45,7 +46,7 @@ class DiffusionModel(nn.Module):
         predict_epsilon=True,
         # DDIM sampling
         use_ddim=False,
-        ddim_discretize='uniform',
+        ddim_discretize="uniform",
         ddim_steps=None,
         **kwargs,
     ):
@@ -74,7 +75,9 @@ class DiffusionModel(nn.Module):
         # Set up models
         self.network = network.to(device)
         if network_path is not None:
-            checkpoint = torch.load(network_path, map_location=device, weights_only=True)
+            checkpoint = torch.load(
+                network_path, map_location=device, weights_only=True
+            )
             if "ema" in checkpoint:
                 self.load_state_dict(checkpoint["ema"], strict=False)
                 logging.info("Loaded SL-trained policy from %s", network_path)
@@ -104,7 +107,9 @@ class DiffusionModel(nn.Module):
         """
         α̅ₜ₋₁
         """
-        self.alphas_cumprod_prev = torch.cat([torch.ones(1).to(device), self.alphas_cumprod[:-1]])
+        self.alphas_cumprod_prev = torch.cat(
+            [torch.ones(1).to(device), self.alphas_cumprod[:-1]]
+        )
         """
         √ α̅ₜ
         """
@@ -131,8 +136,16 @@ class DiffusionModel(nn.Module):
         """
         μₜ = βₜ √ α̅ₜ₋₁/(1-α̅ₜ) x₀ + √ αₜ (1-α̅ₜ₋₁)/(1-α̅ₜ) xₜ
         """
-        self.ddpm_mu_coef1 = self.betas * torch.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
-        self.ddpm_mu_coef2 = (1.0 - self.alphas_cumprod_prev) * torch.sqrt(self.alphas) / (1.0 - self.alphas_cumprod)
+        self.ddpm_mu_coef1 = (
+            self.betas
+            * torch.sqrt(self.alphas_cumprod_prev)
+            / (1.0 - self.alphas_cumprod)
+        )
+        self.ddpm_mu_coef2 = (
+            (1.0 - self.alphas_cumprod_prev)
+            * torch.sqrt(self.alphas)
+            / (1.0 - self.alphas_cumprod)
+        )
 
         """
         DDIM parameters
@@ -141,30 +154,45 @@ class DiffusionModel(nn.Module):
         """
         if use_ddim:
             assert predict_epsilon, "DDIM requires predicting epsilon for now."
-            if ddim_discretize == 'uniform': # use the HF "leading" style
+            if ddim_discretize == "uniform":  # use the HF "leading" style
                 step_ratio = self.denoising_steps // ddim_steps
-                self.ddim_t = torch.arange(0, ddim_steps, device=self.device) * step_ratio
+                self.ddim_t = (
+                    torch.arange(0, ddim_steps, device=self.device) * step_ratio
+                )
             else:
-                raise 'Unknown discretization method for DDIM.'
+                raise ValueError("Unknown discretization method for DDIM.")
-            self.ddim_alphas = self.alphas_cumprod[self.ddim_t].clone().to(torch.float32)
+            self.ddim_alphas = (
+                self.alphas_cumprod[self.ddim_t].clone().to(torch.float32)
+            )
             self.ddim_alphas_sqrt = torch.sqrt(self.ddim_alphas)
-            self.ddim_alphas_prev = torch.cat([
-                torch.tensor([1.]).to(torch.float32).to(self.device),
-                self.alphas_cumprod[self.ddim_t[:-1]]])
-            self.ddim_sqrt_one_minus_alphas = (1. - self.ddim_alphas) ** .5
+            self.ddim_alphas_prev = torch.cat(
+                [
+                    torch.tensor([1.0]).to(torch.float32).to(self.device),
+                    self.alphas_cumprod[self.ddim_t[:-1]],
+                ]
+            )
+            self.ddim_sqrt_one_minus_alphas = (1.0 - self.ddim_alphas) ** 0.5
 
             # Initialize fixed sigmas for inference - eta=0
             ddim_eta = 0
-            self.ddim_sigmas = (ddim_eta * \
-                ((1 - self.ddim_alphas_prev) / (1 - self.ddim_alphas) * \
-                (1 - self.ddim_alphas / self.ddim_alphas_prev)) ** .5)
+            self.ddim_sigmas = (
+                ddim_eta
+                * (
+                    (1 - self.ddim_alphas_prev)
+                    / (1 - self.ddim_alphas)
+                    * (1 - self.ddim_alphas / self.ddim_alphas_prev)
+                )
+                ** 0.5
+            )
 
             # Flip all
             self.ddim_t = torch.flip(self.ddim_t, [0])
             self.ddim_alphas = torch.flip(self.ddim_alphas, [0])
             self.ddim_alphas_sqrt = torch.flip(self.ddim_alphas_sqrt, [0])
             self.ddim_alphas_prev = torch.flip(self.ddim_alphas_prev, [0])
-            self.ddim_sqrt_one_minus_alphas = torch.flip(self.ddim_sqrt_one_minus_alphas, [0])
+            self.ddim_sqrt_one_minus_alphas = torch.flip(
+                self.ddim_sqrt_one_minus_alphas, [0]
+            )
             self.ddim_sigmas = torch.flip(self.ddim_sigmas, [0])
 
     # ---------- Sampling ----------#
@@ -183,8 +211,10 @@ class DiffusionModel(nn.Module):
             """
             alpha = extract(self.ddim_alphas, index, x.shape)
             alpha_prev = extract(self.ddim_alphas_prev, index, x.shape)
-            sqrt_one_minus_alpha = extract(self.ddim_sqrt_one_minus_alphas, index, x.shape)
-            x_recon = (x - sqrt_one_minus_alpha * noise) / (alpha ** 0.5)
+            sqrt_one_minus_alpha = extract(
+                self.ddim_sqrt_one_minus_alphas, index, x.shape
+            )
+            x_recon = (x - sqrt_one_minus_alpha * noise) / (alpha**0.5)
         else:
             """
             x₀ = √(1/α̅ₜ) xₜ - √(1/α̅ₜ - 1) ε
             """
             x_recon = (
@@ -193,7 +223,7 @@ class DiffusionModel(nn.Module):
                 extract(self.sqrt_recip_alphas_cumprod, t, x.shape) * x
                 - extract(self.sqrt_recipm1_alphas_cumprod, t, x.shape) * noise
             )
-        else: # directly predicting x₀
+        else:  # directly predicting x₀
             x_recon = noise
         if self.denoised_clip_value is not None:
             x_recon.clamp_(-self.denoised_clip_value, self.denoised_clip_value)
@@ -208,14 +238,14 @@ class DiffusionModel(nn.Module):
         # Get mu
         if self.use_ddim:
             """
-            μ = √ αₜ₋₁ x₀ + √(1-αₜ₋₁ - σₜ²) ε
-
+            μ = √ αₜ₋₁ x₀ + √(1-αₜ₋₁ - σₜ²) ε
+            eta=0
             """
             sigma = extract(self.ddim_sigmas, index, x.shape)
-            dir_xt = (1. - alpha_prev - sigma ** 2).sqrt() * noise
-            mu = (alpha_prev ** 0.5) * x_recon + dir_xt
-            var = sigma ** 2
+            dir_xt = (1.0 - alpha_prev - sigma**2).sqrt() * noise
+            mu = (alpha_prev**0.5) * x_recon + dir_xt
+            var = sigma**2
             logvar = torch.log(var)
         else:
             """
@@ -225,9 +255,7 @@ class DiffusionModel(nn.Module):
                 extract(self.ddpm_mu_coef1, t, x.shape) * x_recon
                 + extract(self.ddpm_mu_coef2, t, x.shape) * x
             )
-            logvar = extract(
-                self.ddpm_logvar_clipped, t, x.shape
-            )
+            logvar = extract(self.ddpm_logvar_clipped, t, x.shape)
         return mu, logvar
 
     @torch.no_grad()
@@ -279,7 +307,9 @@ class DiffusionModel(nn.Module):
 
             # clamp action at final step
             if self.final_action_clip_value is not None and i == len(t_all) - 1:
-                x = torch.clamp(x, -self.final_action_clip_value, self.final_action_clip_value)
+                x = torch.clamp(
+                    x, -self.final_action_clip_value, self.final_action_clip_value
+                )
         return Sample(x, None)
 
     # ---------- Supervised training ----------#
@@ -314,9 +344,9 @@ class DiffusionModel(nn.Module):
         # Predict
         x_recon = self.network(x_noisy, t, cond=cond)
         if self.predict_epsilon:
-            return F.mse_loss(x_recon, noise, reduction="mean")
+            return F.mse_loss(x_recon, noise, reduction="mean")
         else:
-            return F.mse_loss(x_recon, x_noisy, reduction="mean")
+            return F.mse_loss(x_recon, x_start, reduction="mean")
 
     def q_sample(self, x_start, t, noise=None):
         """
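
Reviewer notes (not part of the patch):

The DDIM block above is only reformatted, so it is easy to sanity-check the schedule standalone. The sketch below recomputes the "uniform" (HF "leading") discretization and the eta=0 sigmas outside the class; denoising_steps, ddim_steps, and the linear beta range are illustrative assumptions, not values taken from the repo config.

import torch

denoising_steps = 100   # assumed number of DDPM training steps
ddim_steps = 10         # assumed number of DDIM inference steps

betas = torch.linspace(1e-4, 2e-2, denoising_steps)  # assumed linear schedule
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)

# "uniform" / leading discretization: t = 0, 10, 20, ..., 90 (flipped later)
step_ratio = denoising_steps // ddim_steps
ddim_t = torch.arange(0, ddim_steps) * step_ratio
assert ddim_t.tolist() == [0, 10, 20, 30, 40, 50, 60, 70, 80, 90]

ddim_alphas = alphas_cumprod[ddim_t]
ddim_alphas_prev = torch.cat([torch.ones(1), alphas_cumprod[ddim_t[:-1]]])

# With eta = 0 the DDIM update is deterministic: every sigma is exactly 0.
ddim_eta = 0
ddim_sigmas = (
    ddim_eta
    * (
        (1 - ddim_alphas_prev)
        / (1 - ddim_alphas)
        * (1 - ddim_alphas / ddim_alphas_prev)
    )
    ** 0.5
)
assert torch.all(ddim_sigmas == 0)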
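The ddpm_mu_coef1/ddpm_mu_coef2 hunk is likewise reformat-only, and the coefficients can be checked against the closed form from Ho et al. (2020): substituting the ε-derived x₀ into the posterior mean should give μ = 1/√αₜ (xₜ - βₜ/√(1-α̅ₜ) ε). A numeric check under the same assumed schedule as above:

import torch

T = 100
betas = torch.linspace(1e-4, 2e-2, T)   # assumed linear schedule
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
alphas_cumprod_prev = torch.cat([torch.ones(1), alphas_cumprod[:-1]])

# Same definitions as in the diff
mu_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)
mu_coef2 = (1.0 - alphas_cumprod_prev) * torch.sqrt(alphas) / (1.0 - alphas_cumprod)

t = 42                                  # arbitrary timestep
x_t = torch.randn(5)
eps = torch.randn(5)

# x₀ recovered from the ε-prediction, as in p_mean_variance
x0 = (x_t - torch.sqrt(1 - alphas_cumprod[t]) * eps) / torch.sqrt(alphas_cumprod[t])

mu = mu_coef1[t] * x0 + mu_coef2[t] * x_t
mu_ref = (x_t - betas[t] / torch.sqrt(1 - alphas_cumprod[t]) * eps) / torch.sqrt(alphas[t])
assert torch.allclose(mu, mu_ref, atol=1e-6)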
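The last hunk is the one behavioral change: with predict_epsilon=False the network outputs a reconstruction of the clean trajectory, so the MSE target must be x_start, not the noised input x_noisy. A minimal sketch of why the old target could never reach zero loss; the shapes and the α̅ₜ value are illustrative assumptions.

import torch
import torch.nn.functional as F

B, T, D = 8, 4, 3                       # assumed batch / horizon / action dims
alpha_bar = 0.9                         # assumed α̅ₜ at one timestep
x_start = torch.randn(B, T, D)          # clean trajectories
noise = torch.randn_like(x_start)

# forward noising, as in q_sample: xₜ = √α̅ₜ x₀ + √(1-α̅ₜ) ε
x_noisy = alpha_bar**0.5 * x_start + (1 - alpha_bar) ** 0.5 * noise

# a perfect x₀-predicting network would output x_start itself
x_recon = x_start.clone()
assert F.mse_loss(x_recon, x_start, reduction="mean") == 0  # new target: loss can vanish
assert F.mse_loss(x_recon, x_noisy, reduction="mean") > 0   # old target: never zero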