From bc52beca1ef632a3db1c083f7323e1f0be759f28 Mon Sep 17 00:00:00 2001 From: allenzren Date: Tue, 17 Sep 2024 16:26:25 -0400 Subject: [PATCH] add minor docs to diffusion classes and clean up some args --- model/common/gaussian.py | 9 +++-- model/diffusion/diffusion_dipo.py | 15 +++------ model/diffusion/diffusion_dql.py | 19 ++++------- model/diffusion/diffusion_idql.py | 1 + model/diffusion/diffusion_rwr.py | 56 +++++-------------------------- model/diffusion/diffusion_vpg.py | 23 +++---------- model/rl/gaussian_rwr.py | 5 --- model/rl/gaussian_vpg.py | 3 -- 8 files changed, 33 insertions(+), 98 deletions(-) diff --git a/model/common/gaussian.py b/model/common/gaussian.py index 42bbc9e..3fea827 100644 --- a/model/common/gaussian.py +++ b/model/common/gaussian.py @@ -18,6 +18,7 @@ class GaussianModel(torch.nn.Module): horizon_steps, network_path=None, device="cuda:0", + randn_clip_value=10, ): super().__init__() self.device = device @@ -36,6 +37,9 @@ class GaussianModel(torch.nn.Module): ) self.horizon_steps = horizon_steps + # Clip sampled randn (from standard deviation) such that the sampled action is not too far away from mean + self.randn_clip_value = randn_clip_value + def loss( self, true_action, @@ -75,7 +79,6 @@ class GaussianModel(torch.nn.Module): self, cond, deterministic=False, - randn_clip_value=10, network_override=None, ): B = len(cond["state"]) if "state" in cond else len(cond["rgb"]) @@ -87,7 +90,7 @@ class GaussianModel(torch.nn.Module): ) sampled_action = dist.sample() sampled_action.clamp_( - dist.loc - randn_clip_value * dist.scale, - dist.loc + randn_clip_value * dist.scale, + dist.loc - self.randn_clip_value * dist.scale, + dist.loc + self.randn_clip_value * dist.scale, ) return sampled_action.view(B, T, -1) diff --git a/model/diffusion/diffusion_dipo.py b/model/diffusion/diffusion_dipo.py index a051ff4..e4459dd 100644 --- a/model/diffusion/diffusion_dipo.py +++ b/model/diffusion/diffusion_dipo.py @@ -19,8 +19,7 @@ class DIPODiffusion(DiffusionModel): actor, critic, use_ddim=False, - randn_clip_value=10, - clamp_action=False, + # modifying denoising schedule min_sampling_denoising_std=0.1, **kwargs, ): @@ -34,12 +33,6 @@ class DIPODiffusion(DiffusionModel): # Minimum std used in denoising process when sampling action - helps exploration self.min_sampling_denoising_std = min_sampling_denoising_std - # For each denoising step, we clip sampled randn (from standard deviation) such that the sampled action is not too far away from mean - self.randn_clip_value = randn_clip_value - - # Whether to clamp sampled action between [-1, 1] - self.clamp_action = clamp_action - # ---------- RL training ----------# def loss_critic(self, obs, next_obs, actions, rewards, dones, gamma): @@ -110,6 +103,8 @@ class DIPODiffusion(DiffusionModel): x = mean + std * noise # clamp action at final step - if self.clamp_action and i == len(t_all) - 1: - x = torch.clamp(x, -1, 1) + if self.final_action_clip_value is not None and i == len(t_all) - 1: + x = torch.clamp( + x, -self.final_action_clip_value, self.final_action_clip_value + ) return x diff --git a/model/diffusion/diffusion_dql.py b/model/diffusion/diffusion_dql.py index 4b53cc2..f8c315c 100644 --- a/model/diffusion/diffusion_dql.py +++ b/model/diffusion/diffusion_dql.py @@ -20,8 +20,7 @@ class DQLDiffusion(DiffusionModel): actor, critic, use_ddim=False, - randn_clip_value=10, - clamp_action=False, + # modifying denoising schedule min_sampling_denoising_std=0.1, **kwargs, ): @@ -35,12 +34,6 @@ class DQLDiffusion(DiffusionModel): # Minimum std used in denoising process when sampling action - helps exploration self.min_sampling_denoising_std = min_sampling_denoising_std - # For each denoising step, we clip sampled randn (from standard deviation) such that the sampled action is not too far away from mean - self.randn_clip_value = randn_clip_value - - # Whether to clamp sampled action between [-1, 1] - self.clamp_action = clamp_action - # ---------- RL training ----------# def loss_critic(self, obs, next_obs, actions, rewards, dones, gamma): @@ -110,7 +103,7 @@ class DQLDiffusion(DiffusionModel): # Determine the noise level if deterministic and t == 0: std = torch.zeros_like(std) - elif deterministic: # For DDPM, sample with noise + elif deterministic: std = torch.clip(std, min=1e-3) else: std = torch.clip(std, min=self.min_sampling_denoising_std) @@ -120,8 +113,10 @@ class DQLDiffusion(DiffusionModel): x = mean + std * noise # clamp action at final step - if self.clamp_action and i == len(t_all) - 1: - x = torch.clamp(x, -1, 1) + if self.final_action_clip_value is not None and i == len(t_all) - 1: + x = torch.clamp( + x, -self.final_action_clip_value, self.final_action_clip_value + ) return x def forward_train( @@ -160,6 +155,6 @@ class DQLDiffusion(DiffusionModel): x = mean + std * noise # clamp action at final step - if self.clamp_action and i == len(t_all) - 1: + if self.final_action_clip_value and i == len(t_all) - 1: x = torch.clamp(x, -1, 1) return x diff --git a/model/diffusion/diffusion_idql.py b/model/diffusion/diffusion_idql.py index 8a5b917..3b8fd28 100644 --- a/model/diffusion/diffusion_idql.py +++ b/model/diffusion/diffusion_idql.py @@ -104,6 +104,7 @@ class IDQLDiffusion(RWRDiffusion): cond, t, ): + """not reward-weighted, same as diffusion.py""" device = x_start.device # Forward process diff --git a/model/diffusion/diffusion_rwr.py b/model/diffusion/diffusion_rwr.py index 8bad230..c08b442 100644 --- a/model/diffusion/diffusion_rwr.py +++ b/model/diffusion/diffusion_rwr.py @@ -11,7 +11,7 @@ log = logging.getLogger(__name__) import torch.nn.functional as F from model.diffusion.diffusion import DiffusionModel -from model.diffusion.sampling import make_timesteps, extract +from model.diffusion.sampling import make_timesteps class RWRDiffusion(DiffusionModel): @@ -19,21 +19,13 @@ class RWRDiffusion(DiffusionModel): def __init__( self, use_ddim=False, - # various clipping - randn_clip_value=10, - clamp_action=None, + # modifying denoising schedule min_sampling_denoising_std=0.1, **kwargs, ): super().__init__(use_ddim=use_ddim, **kwargs) assert not self.use_ddim, "RWR does not support DDIM" - # For each denoising step, we clip sampled randn (from standard deviation) such that the sampled action is not too far away from mean - self.randn_clip_value = randn_clip_value - - # Action clamp range - self.clamp_action = clamp_action - # Minimum std used in denoising process when sampling action - helps exploration self.min_sampling_denoising_std = min_sampling_denoising_std @@ -47,6 +39,7 @@ class RWRDiffusion(DiffusionModel): rewards, t, ): + """reward-weighted""" device = x_start.device # Forward process @@ -67,40 +60,6 @@ class RWRDiffusion(DiffusionModel): # ---------- Sampling ----------# - # override - def p_mean_var( - self, - x, - t, - cond, - ): - noise = self.network(x, t, cond=cond) - - # Predict x_0 - if self.predict_epsilon: - """ - x₀ = √ 1\α̅ₜ xₜ - √ 1\α̅ₜ-1 ε - """ - x_recon = ( - extract(self.sqrt_recip_alphas_cumprod, t, x.shape) * x - - extract(self.sqrt_recipm1_alphas_cumprod, t, x.shape) * noise - ) - else: # directly predicting x₀ - x_recon = noise - if self.denoised_clip_value is not None: - x_recon.clamp_(-self.denoised_clip_value, self.denoised_clip_value) - - # Get mu - """ - μₜ = β̃ₜ √ α̅ₜ₋₁/(1-α̅ₜ)x₀ + √ αₜ (1-α̅ₜ₋₁)/(1-α̅ₜ)xₜ - """ - mu = ( - extract(self.ddpm_mu_coef1, t, x.shape) * x_recon - + extract(self.ddpm_mu_coef2, t, x.shape) * x - ) - logvar = extract(self.ddpm_logvar_clipped, t, x.shape) - return mu, logvar - # override @torch.no_grad() def forward( @@ -108,6 +67,7 @@ class RWRDiffusion(DiffusionModel): cond, deterministic=False, ): + """Modifying denoising schedule""" device = self.betas.device B = len(cond["state"]) @@ -126,7 +86,7 @@ class RWRDiffusion(DiffusionModel): # Determine noise level if deterministic and t == 0: std = torch.zeros_like(std) - elif deterministic: # For DDPM, sample with noise + elif deterministic: std = torch.clip(std, min=1e-3) else: std = torch.clip(std, min=self.min_sampling_denoising_std) @@ -136,6 +96,8 @@ class RWRDiffusion(DiffusionModel): x = mean + std * noise # clamp action at final step - if self.clamp_action is not None and i == len(t_all) - 1: - x = torch.clamp(x, -self.clamp_action, self.clamp_action) + if self.final_action_clip_value is not None and i == len(t_all) - 1: + x = torch.clamp( + x, -self.final_action_clip_value, self.final_action_clip_value + ) return x diff --git a/model/diffusion/diffusion_vpg.py b/model/diffusion/diffusion_vpg.py index 809148c..de7f5e7 100644 --- a/model/diffusion/diffusion_vpg.py +++ b/model/diffusion/diffusion_vpg.py @@ -34,13 +34,10 @@ class VPGDiffusion(DiffusionModel): ft_denoising_steps_d=0, ft_denoising_steps_t=0, network_path=None, - # various clipping - randn_clip_value=10, - clamp_action=False, + # modifying denoising schedule min_sampling_denoising_std=0.1, - min_logprob_denoising_std=0.1, # or the scheduler class - eps_clip_value=None, # only used with DDIM - # DDIM related + min_logprob_denoising_std=0.1, + # eta in DDIM eta=None, learn_eta=False, **kwargs, @@ -60,15 +57,6 @@ class VPGDiffusion(DiffusionModel): self.ft_denoising_steps_t = ft_denoising_steps_t # annealing interval self.ft_denoising_steps_cnt = 0 - # Clip noise for numerical stability in policy gradient - self.eps_clip_value = eps_clip_value - - # For each denoising step, we clip sampled randn (from standard deviation) such that the sampled action is not too far away from mean - self.randn_clip_value = randn_clip_value - - # Whether to clamp sampled action between [-1, 1] - self.clamp_action = clamp_action - # Minimum std used in denoising process when sampling action - helps exploration self.min_sampling_denoising_std = min_sampling_denoising_std @@ -214,7 +202,6 @@ class VPGDiffusion(DiffusionModel): if deterministic: etas = torch.zeros((x.shape[0], 1, 1)).to(x.device) else: - # TODO: eta cond etas = self.eta(cond).unsqueeze(1) # B x 1 x (Da or 1) sigma = ( etas @@ -310,8 +297,8 @@ class VPGDiffusion(DiffusionModel): x = mean + std * noise # clamp action at final step - if self.clamp_action and i == len(t_all) - 1: - x = torch.clamp(x, -1, 1) + if self.final_action_clip_value is not None and i == len(t_all) - 1: + x = torch.clamp(x, -self.final_action_clip_value, self.final_action_clip_value) if return_chain: if not self.use_ddim and t <= self.ft_denoising_steps: diff --git a/model/rl/gaussian_rwr.py b/model/rl/gaussian_rwr.py index 2d0a1ad..55e2314 100644 --- a/model/rl/gaussian_rwr.py +++ b/model/rl/gaussian_rwr.py @@ -16,7 +16,6 @@ class RWR_Gaussian(GaussianModel): def __init__( self, actor, - randn_clip_value=10, **kwargs, ): super().__init__(network=actor, **kwargs) @@ -24,9 +23,6 @@ class RWR_Gaussian(GaussianModel): # assign actor self.actor = self.network - # Clip sampled randn (from standard deviation) such that the sampled action is not too far away from mean - self.randn_clip_value = randn_clip_value - # override def loss(self, actions, obs, reward_weights): B = len(obs) @@ -44,6 +40,5 @@ class RWR_Gaussian(GaussianModel): actions = super().forward( cond=cond, deterministic=deterministic, - randn_clip_value=self.randn_clip_value, ) return actions diff --git a/model/rl/gaussian_vpg.py b/model/rl/gaussian_vpg.py index 34b6d98..d503192 100644 --- a/model/rl/gaussian_vpg.py +++ b/model/rl/gaussian_vpg.py @@ -15,11 +15,9 @@ class VPG_Gaussian(GaussianModel): self, actor, critic, - randn_clip_value=10, **kwargs, ): super().__init__(network=actor, **kwargs) - self.randn_clip_value = randn_clip_value # Value function for obs - simple MLP self.critic = critic.to(self.device) @@ -44,7 +42,6 @@ class VPG_Gaussian(GaussianModel): return super().forward( cond=cond, deterministic=deterministic, - randn_clip_value=self.randn_clip_value, network_override=self.actor if use_base_policy else None, )