add minor docs to diffusion classes and clean up some args
This commit is contained in:
parent
ef5b14f820
commit
bc52beca1e
@ -18,6 +18,7 @@ class GaussianModel(torch.nn.Module):
|
||||
horizon_steps,
|
||||
network_path=None,
|
||||
device="cuda:0",
|
||||
randn_clip_value=10,
|
||||
):
|
||||
super().__init__()
|
||||
self.device = device
|
||||
@ -36,6 +37,9 @@ class GaussianModel(torch.nn.Module):
|
||||
)
|
||||
self.horizon_steps = horizon_steps
|
||||
|
||||
# Clamp the sampled action to within randn_clip_value standard deviations of the mean, so that it does not stray too far from the mean
|
||||
self.randn_clip_value = randn_clip_value
|
||||
|
||||
def loss(
|
||||
self,
|
||||
true_action,
|
||||
@ -75,7 +79,6 @@ class GaussianModel(torch.nn.Module):
|
||||
self,
|
||||
cond,
|
||||
deterministic=False,
|
||||
randn_clip_value=10,
|
||||
network_override=None,
|
||||
):
|
||||
B = len(cond["state"]) if "state" in cond else len(cond["rgb"])
|
||||
@ -87,7 +90,7 @@ class GaussianModel(torch.nn.Module):
|
||||
)
|
||||
sampled_action = dist.sample()
|
||||
sampled_action.clamp_(
|
||||
dist.loc - randn_clip_value * dist.scale,
|
||||
dist.loc + randn_clip_value * dist.scale,
|
||||
dist.loc - self.randn_clip_value * dist.scale,
|
||||
dist.loc + self.randn_clip_value * dist.scale,
|
||||
)
|
||||
return sampled_action.view(B, T, -1)
|
||||
|
@ -19,8 +19,7 @@ class DIPODiffusion(DiffusionModel):
|
||||
actor,
|
||||
critic,
|
||||
use_ddim=False,
|
||||
randn_clip_value=10,
|
||||
clamp_action=False,
|
||||
# modifying denoising schedule
|
||||
min_sampling_denoising_std=0.1,
|
||||
**kwargs,
|
||||
):
|
||||
@ -34,12 +33,6 @@ class DIPODiffusion(DiffusionModel):
|
||||
# Minimum std used in denoising process when sampling action - helps exploration
|
||||
self.min_sampling_denoising_std = min_sampling_denoising_std
|
||||
|
||||
# For each denoising step, we clip sampled randn (from standard deviation) such that the sampled action is not too far away from mean
|
||||
self.randn_clip_value = randn_clip_value
|
||||
|
||||
# Whether to clamp sampled action between [-1, 1]
|
||||
self.clamp_action = clamp_action
|
||||
|
||||
# ---------- RL training ----------#
|
||||
|
||||
def loss_critic(self, obs, next_obs, actions, rewards, dones, gamma):
|
||||
@ -110,6 +103,8 @@ class DIPODiffusion(DiffusionModel):
|
||||
x = mean + std * noise
|
||||
|
||||
# clamp action at final step
|
||||
if self.clamp_action and i == len(t_all) - 1:
|
||||
x = torch.clamp(x, -1, 1)
|
||||
if self.final_action_clip_value is not None and i == len(t_all) - 1:
|
||||
x = torch.clamp(
|
||||
x, -self.final_action_clip_value, self.final_action_clip_value
|
||||
)
|
||||
return x
|
||||
|
@ -20,8 +20,7 @@ class DQLDiffusion(DiffusionModel):
|
||||
actor,
|
||||
critic,
|
||||
use_ddim=False,
|
||||
randn_clip_value=10,
|
||||
clamp_action=False,
|
||||
# modifying denoising schedule
|
||||
min_sampling_denoising_std=0.1,
|
||||
**kwargs,
|
||||
):
|
||||
@ -35,12 +34,6 @@ class DQLDiffusion(DiffusionModel):
|
||||
# Minimum std used in denoising process when sampling action - helps exploration
|
||||
self.min_sampling_denoising_std = min_sampling_denoising_std
|
||||
|
||||
# For each denoising step, we clip sampled randn (from standard deviation) such that the sampled action is not too far away from mean
|
||||
self.randn_clip_value = randn_clip_value
|
||||
|
||||
# Whether to clamp sampled action between [-1, 1]
|
||||
self.clamp_action = clamp_action
|
||||
|
||||
# ---------- RL training ----------#
|
||||
|
||||
def loss_critic(self, obs, next_obs, actions, rewards, dones, gamma):
|
||||
@ -110,7 +103,7 @@ class DQLDiffusion(DiffusionModel):
|
||||
# Determine the noise level
|
||||
if deterministic and t == 0:
|
||||
std = torch.zeros_like(std)
|
||||
elif deterministic: # For DDPM, sample with noise
|
||||
elif deterministic:
|
||||
std = torch.clip(std, min=1e-3)
|
||||
else:
|
||||
std = torch.clip(std, min=self.min_sampling_denoising_std)
|
||||
@ -120,8 +113,10 @@ class DQLDiffusion(DiffusionModel):
|
||||
x = mean + std * noise
|
||||
|
||||
# clamp action at final step
|
||||
if self.clamp_action and i == len(t_all) - 1:
|
||||
x = torch.clamp(x, -1, 1)
|
||||
if self.final_action_clip_value is not None and i == len(t_all) - 1:
|
||||
x = torch.clamp(
|
||||
x, -self.final_action_clip_value, self.final_action_clip_value
|
||||
)
|
||||
return x
|
||||
|
||||
def forward_train(
|
||||
@ -160,6 +155,6 @@ class DQLDiffusion(DiffusionModel):
|
||||
x = mean + std * noise
|
||||
|
||||
# clamp action at final step
|
||||
if self.clamp_action and i == len(t_all) - 1:
|
||||
if self.final_action_clip_value and i == len(t_all) - 1:
|
||||
x = torch.clamp(x, -1, 1)
|
||||
return x
|
||||
|
@ -104,6 +104,7 @@ class IDQLDiffusion(RWRDiffusion):
|
||||
cond,
|
||||
t,
|
||||
):
|
||||
"""not reward-weighted, same as diffusion.py"""
|
||||
device = x_start.device
|
||||
|
||||
# Forward process
|
||||
|
@ -11,7 +11,7 @@ log = logging.getLogger(__name__)
|
||||
import torch.nn.functional as F
|
||||
|
||||
from model.diffusion.diffusion import DiffusionModel
|
||||
from model.diffusion.sampling import make_timesteps, extract
|
||||
from model.diffusion.sampling import make_timesteps
|
||||
|
||||
|
||||
class RWRDiffusion(DiffusionModel):
|
||||
@ -19,21 +19,13 @@ class RWRDiffusion(DiffusionModel):
|
||||
def __init__(
|
||||
self,
|
||||
use_ddim=False,
|
||||
# various clipping
|
||||
randn_clip_value=10,
|
||||
clamp_action=None,
|
||||
# modifying denoising schedule
|
||||
min_sampling_denoising_std=0.1,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(use_ddim=use_ddim, **kwargs)
|
||||
assert not self.use_ddim, "RWR does not support DDIM"
|
||||
|
||||
# For each denoising step, we clip sampled randn (from standard deviation) such that the sampled action is not too far away from mean
|
||||
self.randn_clip_value = randn_clip_value
|
||||
|
||||
# Action clamp range
|
||||
self.clamp_action = clamp_action
|
||||
|
||||
# Minimum std used in denoising process when sampling action - helps exploration
|
||||
self.min_sampling_denoising_std = min_sampling_denoising_std
|
||||
|
||||
@ -47,6 +39,7 @@ class RWRDiffusion(DiffusionModel):
|
||||
rewards,
|
||||
t,
|
||||
):
|
||||
"""reward-weighted"""
|
||||
device = x_start.device
|
||||
|
||||
# Forward process
|
||||
@ -67,40 +60,6 @@ class RWRDiffusion(DiffusionModel):
|
||||
|
||||
# ---------- Sampling ----------#
|
||||
|
||||
# override
|
||||
def p_mean_var(
|
||||
self,
|
||||
x,
|
||||
t,
|
||||
cond,
|
||||
):
|
||||
noise = self.network(x, t, cond=cond)
|
||||
|
||||
# Predict x_0
|
||||
if self.predict_epsilon:
|
||||
"""
|
||||
x₀ = √(1/α̅ₜ) xₜ - √(1/α̅ₜ - 1) ε
|
||||
"""
|
||||
x_recon = (
|
||||
extract(self.sqrt_recip_alphas_cumprod, t, x.shape) * x
|
||||
- extract(self.sqrt_recipm1_alphas_cumprod, t, x.shape) * noise
|
||||
)
|
||||
else: # directly predicting x₀
|
||||
x_recon = noise
|
||||
if self.denoised_clip_value is not None:
|
||||
x_recon.clamp_(-self.denoised_clip_value, self.denoised_clip_value)
|
||||
|
||||
# Get mu
|
||||
"""
|
||||
μₜ = [βₜ √α̅ₜ₋₁ / (1-α̅ₜ)] x₀ + [√αₜ (1-α̅ₜ₋₁) / (1-α̅ₜ)] xₜ
|
||||
"""
|
||||
mu = (
|
||||
extract(self.ddpm_mu_coef1, t, x.shape) * x_recon
|
||||
+ extract(self.ddpm_mu_coef2, t, x.shape) * x
|
||||
)
|
||||
logvar = extract(self.ddpm_logvar_clipped, t, x.shape)
|
||||
return mu, logvar
|
||||
|
||||
# override
|
||||
@torch.no_grad()
|
||||
def forward(
|
||||
@ -108,6 +67,7 @@ class RWRDiffusion(DiffusionModel):
|
||||
cond,
|
||||
deterministic=False,
|
||||
):
|
||||
"""Modifying denoising schedule"""
|
||||
device = self.betas.device
|
||||
B = len(cond["state"])
|
||||
|
||||
@ -126,7 +86,7 @@ class RWRDiffusion(DiffusionModel):
|
||||
# Determine noise level
|
||||
if deterministic and t == 0:
|
||||
std = torch.zeros_like(std)
|
||||
elif deterministic: # For DDPM, sample with noise
|
||||
elif deterministic:
|
||||
std = torch.clip(std, min=1e-3)
|
||||
else:
|
||||
std = torch.clip(std, min=self.min_sampling_denoising_std)
|
||||
@ -136,6 +96,8 @@ class RWRDiffusion(DiffusionModel):
|
||||
x = mean + std * noise
|
||||
|
||||
# clamp action at final step
|
||||
if self.clamp_action is not None and i == len(t_all) - 1:
|
||||
x = torch.clamp(x, -self.clamp_action, self.clamp_action)
|
||||
if self.final_action_clip_value is not None and i == len(t_all) - 1:
|
||||
x = torch.clamp(
|
||||
x, -self.final_action_clip_value, self.final_action_clip_value
|
||||
)
|
||||
return x
|
||||
|
@ -34,13 +34,10 @@ class VPGDiffusion(DiffusionModel):
|
||||
ft_denoising_steps_d=0,
|
||||
ft_denoising_steps_t=0,
|
||||
network_path=None,
|
||||
# various clipping
|
||||
randn_clip_value=10,
|
||||
clamp_action=False,
|
||||
# modifying denoising schedule
|
||||
min_sampling_denoising_std=0.1,
|
||||
min_logprob_denoising_std=0.1, # or the scheduler class
|
||||
eps_clip_value=None, # only used with DDIM
|
||||
# DDIM related
|
||||
min_logprob_denoising_std=0.1,
|
||||
# eta in DDIM
|
||||
eta=None,
|
||||
learn_eta=False,
|
||||
**kwargs,
|
||||
@ -60,15 +57,6 @@ class VPGDiffusion(DiffusionModel):
|
||||
self.ft_denoising_steps_t = ft_denoising_steps_t # annealing interval
|
||||
self.ft_denoising_steps_cnt = 0
|
||||
|
||||
# Clip noise for numerical stability in policy gradient
|
||||
self.eps_clip_value = eps_clip_value
|
||||
|
||||
# For each denoising step, we clip sampled randn (from standard deviation) such that the sampled action is not too far away from mean
|
||||
self.randn_clip_value = randn_clip_value
|
||||
|
||||
# Whether to clamp sampled action between [-1, 1]
|
||||
self.clamp_action = clamp_action
|
||||
|
||||
# Minimum std used in denoising process when sampling action - helps exploration
|
||||
self.min_sampling_denoising_std = min_sampling_denoising_std
|
||||
|
||||
@ -214,7 +202,6 @@ class VPGDiffusion(DiffusionModel):
|
||||
if deterministic:
|
||||
etas = torch.zeros((x.shape[0], 1, 1)).to(x.device)
|
||||
else:
|
||||
# TODO: eta cond
|
||||
etas = self.eta(cond).unsqueeze(1) # B x 1 x (Da or 1)
|
||||
sigma = (
|
||||
etas
|
||||
@ -310,8 +297,8 @@ class VPGDiffusion(DiffusionModel):
|
||||
x = mean + std * noise
|
||||
|
||||
# clamp action at final step
|
||||
if self.clamp_action and i == len(t_all) - 1:
|
||||
x = torch.clamp(x, -1, 1)
|
||||
if self.final_action_clip_value is not None and i == len(t_all) - 1:
|
||||
x = torch.clamp(x, -self.final_action_clip_value, self.final_action_clip_value)
|
||||
|
||||
if return_chain:
|
||||
if not self.use_ddim and t <= self.ft_denoising_steps:
|
||||
|
@ -16,7 +16,6 @@ class RWR_Gaussian(GaussianModel):
|
||||
def __init__(
|
||||
self,
|
||||
actor,
|
||||
randn_clip_value=10,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(network=actor, **kwargs)
|
||||
@ -24,9 +23,6 @@ class RWR_Gaussian(GaussianModel):
|
||||
# assign actor
|
||||
self.actor = self.network
|
||||
|
||||
# Clip sampled randn (from standard deviation) such that the sampled action is not too far away from mean
|
||||
self.randn_clip_value = randn_clip_value
|
||||
|
||||
# override
|
||||
def loss(self, actions, obs, reward_weights):
|
||||
B = len(obs)
|
||||
@ -44,6 +40,5 @@ class RWR_Gaussian(GaussianModel):
|
||||
actions = super().forward(
|
||||
cond=cond,
|
||||
deterministic=deterministic,
|
||||
randn_clip_value=self.randn_clip_value,
|
||||
)
|
||||
return actions
|
||||
|
@ -15,11 +15,9 @@ class VPG_Gaussian(GaussianModel):
|
||||
self,
|
||||
actor,
|
||||
critic,
|
||||
randn_clip_value=10,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(network=actor, **kwargs)
|
||||
self.randn_clip_value = randn_clip_value
|
||||
|
||||
# Value function for obs - simple MLP
|
||||
self.critic = critic.to(self.device)
|
||||
@ -44,7 +42,6 @@ class VPG_Gaussian(GaussianModel):
|
||||
return super().forward(
|
||||
cond=cond,
|
||||
deterministic=deterministic,
|
||||
randn_clip_value=self.randn_clip_value,
|
||||
network_override=self.actor if use_base_policy else None,
|
||||
)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user