add minor docs to diffusion classes and clean up some args

This commit is contained in:
allenzren 2024-09-17 16:26:25 -04:00
parent ef5b14f820
commit bc52beca1e
8 changed files with 33 additions and 98 deletions

View File

@ -18,6 +18,7 @@ class GaussianModel(torch.nn.Module):
horizon_steps,
network_path=None,
device="cuda:0",
randn_clip_value=10,
):
super().__init__()
self.device = device
@ -36,6 +37,9 @@ class GaussianModel(torch.nn.Module):
)
self.horizon_steps = horizon_steps
# Clip sampled randn (from standard deviation) such that the sampled action is not too far away from mean
self.randn_clip_value = randn_clip_value
def loss(
self,
true_action,
@ -75,7 +79,6 @@ class GaussianModel(torch.nn.Module):
self,
cond,
deterministic=False,
randn_clip_value=10,
network_override=None,
):
B = len(cond["state"]) if "state" in cond else len(cond["rgb"])
@ -87,7 +90,7 @@ class GaussianModel(torch.nn.Module):
)
sampled_action = dist.sample()
sampled_action.clamp_(
dist.loc - randn_clip_value * dist.scale,
dist.loc + randn_clip_value * dist.scale,
dist.loc - self.randn_clip_value * dist.scale,
dist.loc + self.randn_clip_value * dist.scale,
)
return sampled_action.view(B, T, -1)

View File

@ -19,8 +19,7 @@ class DIPODiffusion(DiffusionModel):
actor,
critic,
use_ddim=False,
randn_clip_value=10,
clamp_action=False,
# modifying denoising schedule
min_sampling_denoising_std=0.1,
**kwargs,
):
@ -34,12 +33,6 @@ class DIPODiffusion(DiffusionModel):
# Minimum std used in denoising process when sampling action - helps exploration
self.min_sampling_denoising_std = min_sampling_denoising_std
# For each denoising step, we clip sampled randn (from standard deviation) such that the sampled action is not too far away from mean
self.randn_clip_value = randn_clip_value
# Whether to clamp sampled action between [-1, 1]
self.clamp_action = clamp_action
# ---------- RL training ----------#
def loss_critic(self, obs, next_obs, actions, rewards, dones, gamma):
@ -110,6 +103,8 @@ class DIPODiffusion(DiffusionModel):
x = mean + std * noise
# clamp action at final step
if self.clamp_action and i == len(t_all) - 1:
x = torch.clamp(x, -1, 1)
if self.final_action_clip_value is not None and i == len(t_all) - 1:
x = torch.clamp(
x, -self.final_action_clip_value, self.final_action_clip_value
)
return x

View File

@ -20,8 +20,7 @@ class DQLDiffusion(DiffusionModel):
actor,
critic,
use_ddim=False,
randn_clip_value=10,
clamp_action=False,
# modifying denoising schedule
min_sampling_denoising_std=0.1,
**kwargs,
):
@ -35,12 +34,6 @@ class DQLDiffusion(DiffusionModel):
# Minimum std used in denoising process when sampling action - helps exploration
self.min_sampling_denoising_std = min_sampling_denoising_std
# For each denoising step, we clip sampled randn (from standard deviation) such that the sampled action is not too far away from mean
self.randn_clip_value = randn_clip_value
# Whether to clamp sampled action between [-1, 1]
self.clamp_action = clamp_action
# ---------- RL training ----------#
def loss_critic(self, obs, next_obs, actions, rewards, dones, gamma):
@ -110,7 +103,7 @@ class DQLDiffusion(DiffusionModel):
# Determine the noise level
if deterministic and t == 0:
std = torch.zeros_like(std)
elif deterministic: # For DDPM, sample with noise
elif deterministic:
std = torch.clip(std, min=1e-3)
else:
std = torch.clip(std, min=self.min_sampling_denoising_std)
@ -120,8 +113,10 @@ class DQLDiffusion(DiffusionModel):
x = mean + std * noise
# clamp action at final step
if self.clamp_action and i == len(t_all) - 1:
x = torch.clamp(x, -1, 1)
if self.final_action_clip_value is not None and i == len(t_all) - 1:
x = torch.clamp(
x, -self.final_action_clip_value, self.final_action_clip_value
)
return x
def forward_train(
@ -160,6 +155,6 @@ class DQLDiffusion(DiffusionModel):
x = mean + std * noise
# clamp action at final step
if self.clamp_action and i == len(t_all) - 1:
if self.final_action_clip_value and i == len(t_all) - 1:
x = torch.clamp(x, -1, 1)
return x

View File

@ -104,6 +104,7 @@ class IDQLDiffusion(RWRDiffusion):
cond,
t,
):
"""not reward-weighted, same as diffusion.py"""
device = x_start.device
# Forward process

View File

@ -11,7 +11,7 @@ log = logging.getLogger(__name__)
import torch.nn.functional as F
from model.diffusion.diffusion import DiffusionModel
from model.diffusion.sampling import make_timesteps, extract
from model.diffusion.sampling import make_timesteps
class RWRDiffusion(DiffusionModel):
@ -19,21 +19,13 @@ class RWRDiffusion(DiffusionModel):
def __init__(
self,
use_ddim=False,
# various clipping
randn_clip_value=10,
clamp_action=None,
# modifying denoising schedule
min_sampling_denoising_std=0.1,
**kwargs,
):
super().__init__(use_ddim=use_ddim, **kwargs)
assert not self.use_ddim, "RWR does not support DDIM"
# For each denoising step, we clip sampled randn (from standard deviation) such that the sampled action is not too far away from mean
self.randn_clip_value = randn_clip_value
# Action clamp range
self.clamp_action = clamp_action
# Minimum std used in denoising process when sampling action - helps exploration
self.min_sampling_denoising_std = min_sampling_denoising_std
@ -47,6 +39,7 @@ class RWRDiffusion(DiffusionModel):
rewards,
t,
):
"""reward-weighted"""
device = x_start.device
# Forward process
@ -67,40 +60,6 @@ class RWRDiffusion(DiffusionModel):
# ---------- Sampling ----------#
# override
def p_mean_var(
self,
x,
t,
cond,
):
noise = self.network(x, t, cond=cond)
# Predict x_0
if self.predict_epsilon:
"""
x₀ = 1\α̅ xₜ - 1\α̅-1 ε
"""
x_recon = (
extract(self.sqrt_recip_alphas_cumprod, t, x.shape) * x
- extract(self.sqrt_recipm1_alphas_cumprod, t, x.shape) * noise
)
else: # directly predicting x₀
x_recon = noise
if self.denoised_clip_value is not None:
x_recon.clamp_(-self.denoised_clip_value, self.denoised_clip_value)
# Get mu
"""
μₜ = β̃ α̅/(1-α̅)x₀ + αₜ (1-α̅)/(1-α̅)xₜ
"""
mu = (
extract(self.ddpm_mu_coef1, t, x.shape) * x_recon
+ extract(self.ddpm_mu_coef2, t, x.shape) * x
)
logvar = extract(self.ddpm_logvar_clipped, t, x.shape)
return mu, logvar
# override
@torch.no_grad()
def forward(
@ -108,6 +67,7 @@ class RWRDiffusion(DiffusionModel):
cond,
deterministic=False,
):
"""Modifying denoising schedule"""
device = self.betas.device
B = len(cond["state"])
@ -126,7 +86,7 @@ class RWRDiffusion(DiffusionModel):
# Determine noise level
if deterministic and t == 0:
std = torch.zeros_like(std)
elif deterministic: # For DDPM, sample with noise
elif deterministic:
std = torch.clip(std, min=1e-3)
else:
std = torch.clip(std, min=self.min_sampling_denoising_std)
@ -136,6 +96,8 @@ class RWRDiffusion(DiffusionModel):
x = mean + std * noise
# clamp action at final step
if self.clamp_action is not None and i == len(t_all) - 1:
x = torch.clamp(x, -self.clamp_action, self.clamp_action)
if self.final_action_clip_value is not None and i == len(t_all) - 1:
x = torch.clamp(
x, -self.final_action_clip_value, self.final_action_clip_value
)
return x

View File

@ -34,13 +34,10 @@ class VPGDiffusion(DiffusionModel):
ft_denoising_steps_d=0,
ft_denoising_steps_t=0,
network_path=None,
# various clipping
randn_clip_value=10,
clamp_action=False,
# modifying denoising schedule
min_sampling_denoising_std=0.1,
min_logprob_denoising_std=0.1, # or the scheduler class
eps_clip_value=None, # only used with DDIM
# DDIM related
min_logprob_denoising_std=0.1,
# eta in DDIM
eta=None,
learn_eta=False,
**kwargs,
@ -60,15 +57,6 @@ class VPGDiffusion(DiffusionModel):
self.ft_denoising_steps_t = ft_denoising_steps_t # annealing interval
self.ft_denoising_steps_cnt = 0
# Clip noise for numerical stability in policy gradient
self.eps_clip_value = eps_clip_value
# For each denoising step, we clip sampled randn (from standard deviation) such that the sampled action is not too far away from mean
self.randn_clip_value = randn_clip_value
# Whether to clamp sampled action between [-1, 1]
self.clamp_action = clamp_action
# Minimum std used in denoising process when sampling action - helps exploration
self.min_sampling_denoising_std = min_sampling_denoising_std
@ -214,7 +202,6 @@ class VPGDiffusion(DiffusionModel):
if deterministic:
etas = torch.zeros((x.shape[0], 1, 1)).to(x.device)
else:
# TODO: eta cond
etas = self.eta(cond).unsqueeze(1) # B x 1 x (Da or 1)
sigma = (
etas
@ -310,8 +297,8 @@ class VPGDiffusion(DiffusionModel):
x = mean + std * noise
# clamp action at final step
if self.clamp_action and i == len(t_all) - 1:
x = torch.clamp(x, -1, 1)
if self.final_action_clip_value is not None and i == len(t_all) - 1:
x = torch.clamp(x, -self.final_action_clip_value, self.final_action_clip_value)
if return_chain:
if not self.use_ddim and t <= self.ft_denoising_steps:

View File

@ -16,7 +16,6 @@ class RWR_Gaussian(GaussianModel):
def __init__(
self,
actor,
randn_clip_value=10,
**kwargs,
):
super().__init__(network=actor, **kwargs)
@ -24,9 +23,6 @@ class RWR_Gaussian(GaussianModel):
# assign actor
self.actor = self.network
# Clip sampled randn (from standard deviation) such that the sampled action is not too far away from mean
self.randn_clip_value = randn_clip_value
# override
def loss(self, actions, obs, reward_weights):
B = len(obs)
@ -44,6 +40,5 @@ class RWR_Gaussian(GaussianModel):
actions = super().forward(
cond=cond,
deterministic=deterministic,
randn_clip_value=self.randn_clip_value,
)
return actions

View File

@ -15,11 +15,9 @@ class VPG_Gaussian(GaussianModel):
self,
actor,
critic,
randn_clip_value=10,
**kwargs,
):
super().__init__(network=actor, **kwargs)
self.randn_clip_value = randn_clip_value
# Value function for obs - simple MLP
self.critic = critic.to(self.device)
@ -44,7 +42,6 @@ class VPG_Gaussian(GaussianModel):
return super().forward(
cond=cond,
deterministic=deterministic,
randn_clip_value=self.randn_clip_value,
network_override=self.actor if use_base_policy else None,
)