add minor docs to diffusion classes and clean up some args
This commit is contained in:
parent
ef5b14f820
commit
bc52beca1e
@ -18,6 +18,7 @@ class GaussianModel(torch.nn.Module):
|
|||||||
horizon_steps,
|
horizon_steps,
|
||||||
network_path=None,
|
network_path=None,
|
||||||
device="cuda:0",
|
device="cuda:0",
|
||||||
|
randn_clip_value=10,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.device = device
|
self.device = device
|
||||||
@ -36,6 +37,9 @@ class GaussianModel(torch.nn.Module):
|
|||||||
)
|
)
|
||||||
self.horizon_steps = horizon_steps
|
self.horizon_steps = horizon_steps
|
||||||
|
|
||||||
|
# Clip sampled randn (from standard deviation) such that the sampled action is not too far away from mean
|
||||||
|
self.randn_clip_value = randn_clip_value
|
||||||
|
|
||||||
def loss(
|
def loss(
|
||||||
self,
|
self,
|
||||||
true_action,
|
true_action,
|
||||||
@ -75,7 +79,6 @@ class GaussianModel(torch.nn.Module):
|
|||||||
self,
|
self,
|
||||||
cond,
|
cond,
|
||||||
deterministic=False,
|
deterministic=False,
|
||||||
randn_clip_value=10,
|
|
||||||
network_override=None,
|
network_override=None,
|
||||||
):
|
):
|
||||||
B = len(cond["state"]) if "state" in cond else len(cond["rgb"])
|
B = len(cond["state"]) if "state" in cond else len(cond["rgb"])
|
||||||
@ -87,7 +90,7 @@ class GaussianModel(torch.nn.Module):
|
|||||||
)
|
)
|
||||||
sampled_action = dist.sample()
|
sampled_action = dist.sample()
|
||||||
sampled_action.clamp_(
|
sampled_action.clamp_(
|
||||||
dist.loc - randn_clip_value * dist.scale,
|
dist.loc - self.randn_clip_value * dist.scale,
|
||||||
dist.loc + randn_clip_value * dist.scale,
|
dist.loc + self.randn_clip_value * dist.scale,
|
||||||
)
|
)
|
||||||
return sampled_action.view(B, T, -1)
|
return sampled_action.view(B, T, -1)
|
||||||
|
@ -19,8 +19,7 @@ class DIPODiffusion(DiffusionModel):
|
|||||||
actor,
|
actor,
|
||||||
critic,
|
critic,
|
||||||
use_ddim=False,
|
use_ddim=False,
|
||||||
randn_clip_value=10,
|
# modifying denoising schedule
|
||||||
clamp_action=False,
|
|
||||||
min_sampling_denoising_std=0.1,
|
min_sampling_denoising_std=0.1,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
@ -34,12 +33,6 @@ class DIPODiffusion(DiffusionModel):
|
|||||||
# Minimum std used in denoising process when sampling action - helps exploration
|
# Minimum std used in denoising process when sampling action - helps exploration
|
||||||
self.min_sampling_denoising_std = min_sampling_denoising_std
|
self.min_sampling_denoising_std = min_sampling_denoising_std
|
||||||
|
|
||||||
# For each denoising step, we clip sampled randn (from standard deviation) such that the sampled action is not too far away from mean
|
|
||||||
self.randn_clip_value = randn_clip_value
|
|
||||||
|
|
||||||
# Whether to clamp sampled action between [-1, 1]
|
|
||||||
self.clamp_action = clamp_action
|
|
||||||
|
|
||||||
# ---------- RL training ----------#
|
# ---------- RL training ----------#
|
||||||
|
|
||||||
def loss_critic(self, obs, next_obs, actions, rewards, dones, gamma):
|
def loss_critic(self, obs, next_obs, actions, rewards, dones, gamma):
|
||||||
@ -110,6 +103,8 @@ class DIPODiffusion(DiffusionModel):
|
|||||||
x = mean + std * noise
|
x = mean + std * noise
|
||||||
|
|
||||||
# clamp action at final step
|
# clamp action at final step
|
||||||
if self.clamp_action and i == len(t_all) - 1:
|
if self.final_action_clip_value is not None and i == len(t_all) - 1:
|
||||||
x = torch.clamp(x, -1, 1)
|
x = torch.clamp(
|
||||||
|
x, -self.final_action_clip_value, self.final_action_clip_value
|
||||||
|
)
|
||||||
return x
|
return x
|
||||||
|
@ -20,8 +20,7 @@ class DQLDiffusion(DiffusionModel):
|
|||||||
actor,
|
actor,
|
||||||
critic,
|
critic,
|
||||||
use_ddim=False,
|
use_ddim=False,
|
||||||
randn_clip_value=10,
|
# modifying denoising schedule
|
||||||
clamp_action=False,
|
|
||||||
min_sampling_denoising_std=0.1,
|
min_sampling_denoising_std=0.1,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
@ -35,12 +34,6 @@ class DQLDiffusion(DiffusionModel):
|
|||||||
# Minimum std used in denoising process when sampling action - helps exploration
|
# Minimum std used in denoising process when sampling action - helps exploration
|
||||||
self.min_sampling_denoising_std = min_sampling_denoising_std
|
self.min_sampling_denoising_std = min_sampling_denoising_std
|
||||||
|
|
||||||
# For each denoising step, we clip sampled randn (from standard deviation) such that the sampled action is not too far away from mean
|
|
||||||
self.randn_clip_value = randn_clip_value
|
|
||||||
|
|
||||||
# Whether to clamp sampled action between [-1, 1]
|
|
||||||
self.clamp_action = clamp_action
|
|
||||||
|
|
||||||
# ---------- RL training ----------#
|
# ---------- RL training ----------#
|
||||||
|
|
||||||
def loss_critic(self, obs, next_obs, actions, rewards, dones, gamma):
|
def loss_critic(self, obs, next_obs, actions, rewards, dones, gamma):
|
||||||
@ -110,7 +103,7 @@ class DQLDiffusion(DiffusionModel):
|
|||||||
# Determine the noise level
|
# Determine the noise level
|
||||||
if deterministic and t == 0:
|
if deterministic and t == 0:
|
||||||
std = torch.zeros_like(std)
|
std = torch.zeros_like(std)
|
||||||
elif deterministic: # For DDPM, sample with noise
|
elif deterministic:
|
||||||
std = torch.clip(std, min=1e-3)
|
std = torch.clip(std, min=1e-3)
|
||||||
else:
|
else:
|
||||||
std = torch.clip(std, min=self.min_sampling_denoising_std)
|
std = torch.clip(std, min=self.min_sampling_denoising_std)
|
||||||
@ -120,8 +113,10 @@ class DQLDiffusion(DiffusionModel):
|
|||||||
x = mean + std * noise
|
x = mean + std * noise
|
||||||
|
|
||||||
# clamp action at final step
|
# clamp action at final step
|
||||||
if self.clamp_action and i == len(t_all) - 1:
|
if self.final_action_clip_value is not None and i == len(t_all) - 1:
|
||||||
x = torch.clamp(x, -1, 1)
|
x = torch.clamp(
|
||||||
|
x, -self.final_action_clip_value, self.final_action_clip_value
|
||||||
|
)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def forward_train(
|
def forward_train(
|
||||||
@ -160,6 +155,6 @@ class DQLDiffusion(DiffusionModel):
|
|||||||
x = mean + std * noise
|
x = mean + std * noise
|
||||||
|
|
||||||
# clamp action at final step
|
# clamp action at final step
|
||||||
if self.clamp_action and i == len(t_all) - 1:
|
if self.final_action_clip_value and i == len(t_all) - 1:
|
||||||
x = torch.clamp(x, -1, 1)
|
x = torch.clamp(x, -1, 1)
|
||||||
return x
|
return x
|
||||||
|
@ -104,6 +104,7 @@ class IDQLDiffusion(RWRDiffusion):
|
|||||||
cond,
|
cond,
|
||||||
t,
|
t,
|
||||||
):
|
):
|
||||||
|
"""not reward-weighted, same as diffusion.py"""
|
||||||
device = x_start.device
|
device = x_start.device
|
||||||
|
|
||||||
# Forward process
|
# Forward process
|
||||||
|
@ -11,7 +11,7 @@ log = logging.getLogger(__name__)
|
|||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
from model.diffusion.diffusion import DiffusionModel
|
from model.diffusion.diffusion import DiffusionModel
|
||||||
from model.diffusion.sampling import make_timesteps, extract
|
from model.diffusion.sampling import make_timesteps
|
||||||
|
|
||||||
|
|
||||||
class RWRDiffusion(DiffusionModel):
|
class RWRDiffusion(DiffusionModel):
|
||||||
@ -19,21 +19,13 @@ class RWRDiffusion(DiffusionModel):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
use_ddim=False,
|
use_ddim=False,
|
||||||
# various clipping
|
# modifying denoising schedule
|
||||||
randn_clip_value=10,
|
|
||||||
clamp_action=None,
|
|
||||||
min_sampling_denoising_std=0.1,
|
min_sampling_denoising_std=0.1,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
super().__init__(use_ddim=use_ddim, **kwargs)
|
super().__init__(use_ddim=use_ddim, **kwargs)
|
||||||
assert not self.use_ddim, "RWR does not support DDIM"
|
assert not self.use_ddim, "RWR does not support DDIM"
|
||||||
|
|
||||||
# For each denoising step, we clip sampled randn (from standard deviation) such that the sampled action is not too far away from mean
|
|
||||||
self.randn_clip_value = randn_clip_value
|
|
||||||
|
|
||||||
# Action clamp range
|
|
||||||
self.clamp_action = clamp_action
|
|
||||||
|
|
||||||
# Minimum std used in denoising process when sampling action - helps exploration
|
# Minimum std used in denoising process when sampling action - helps exploration
|
||||||
self.min_sampling_denoising_std = min_sampling_denoising_std
|
self.min_sampling_denoising_std = min_sampling_denoising_std
|
||||||
|
|
||||||
@ -47,6 +39,7 @@ class RWRDiffusion(DiffusionModel):
|
|||||||
rewards,
|
rewards,
|
||||||
t,
|
t,
|
||||||
):
|
):
|
||||||
|
"""reward-weighted"""
|
||||||
device = x_start.device
|
device = x_start.device
|
||||||
|
|
||||||
# Forward process
|
# Forward process
|
||||||
@ -67,40 +60,6 @@ class RWRDiffusion(DiffusionModel):
|
|||||||
|
|
||||||
# ---------- Sampling ----------#
|
# ---------- Sampling ----------#
|
||||||
|
|
||||||
# override
|
|
||||||
def p_mean_var(
|
|
||||||
self,
|
|
||||||
x,
|
|
||||||
t,
|
|
||||||
cond,
|
|
||||||
):
|
|
||||||
noise = self.network(x, t, cond=cond)
|
|
||||||
|
|
||||||
# Predict x_0
|
|
||||||
if self.predict_epsilon:
|
|
||||||
"""
|
|
||||||
x₀ = √ 1\α̅ₜ xₜ - √ 1\α̅ₜ-1 ε
|
|
||||||
"""
|
|
||||||
x_recon = (
|
|
||||||
extract(self.sqrt_recip_alphas_cumprod, t, x.shape) * x
|
|
||||||
- extract(self.sqrt_recipm1_alphas_cumprod, t, x.shape) * noise
|
|
||||||
)
|
|
||||||
else: # directly predicting x₀
|
|
||||||
x_recon = noise
|
|
||||||
if self.denoised_clip_value is not None:
|
|
||||||
x_recon.clamp_(-self.denoised_clip_value, self.denoised_clip_value)
|
|
||||||
|
|
||||||
# Get mu
|
|
||||||
"""
|
|
||||||
μₜ = β̃ₜ √ α̅ₜ₋₁/(1-α̅ₜ)x₀ + √ αₜ (1-α̅ₜ₋₁)/(1-α̅ₜ)xₜ
|
|
||||||
"""
|
|
||||||
mu = (
|
|
||||||
extract(self.ddpm_mu_coef1, t, x.shape) * x_recon
|
|
||||||
+ extract(self.ddpm_mu_coef2, t, x.shape) * x
|
|
||||||
)
|
|
||||||
logvar = extract(self.ddpm_logvar_clipped, t, x.shape)
|
|
||||||
return mu, logvar
|
|
||||||
|
|
||||||
# override
|
# override
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def forward(
|
def forward(
|
||||||
@ -108,6 +67,7 @@ class RWRDiffusion(DiffusionModel):
|
|||||||
cond,
|
cond,
|
||||||
deterministic=False,
|
deterministic=False,
|
||||||
):
|
):
|
||||||
|
"""Modifying denoising schedule"""
|
||||||
device = self.betas.device
|
device = self.betas.device
|
||||||
B = len(cond["state"])
|
B = len(cond["state"])
|
||||||
|
|
||||||
@ -126,7 +86,7 @@ class RWRDiffusion(DiffusionModel):
|
|||||||
# Determine noise level
|
# Determine noise level
|
||||||
if deterministic and t == 0:
|
if deterministic and t == 0:
|
||||||
std = torch.zeros_like(std)
|
std = torch.zeros_like(std)
|
||||||
elif deterministic: # For DDPM, sample with noise
|
elif deterministic:
|
||||||
std = torch.clip(std, min=1e-3)
|
std = torch.clip(std, min=1e-3)
|
||||||
else:
|
else:
|
||||||
std = torch.clip(std, min=self.min_sampling_denoising_std)
|
std = torch.clip(std, min=self.min_sampling_denoising_std)
|
||||||
@ -136,6 +96,8 @@ class RWRDiffusion(DiffusionModel):
|
|||||||
x = mean + std * noise
|
x = mean + std * noise
|
||||||
|
|
||||||
# clamp action at final step
|
# clamp action at final step
|
||||||
if self.clamp_action is not None and i == len(t_all) - 1:
|
if self.final_action_clip_value is not None and i == len(t_all) - 1:
|
||||||
x = torch.clamp(x, -self.clamp_action, self.clamp_action)
|
x = torch.clamp(
|
||||||
|
x, -self.final_action_clip_value, self.final_action_clip_value
|
||||||
|
)
|
||||||
return x
|
return x
|
||||||
|
@ -34,13 +34,10 @@ class VPGDiffusion(DiffusionModel):
|
|||||||
ft_denoising_steps_d=0,
|
ft_denoising_steps_d=0,
|
||||||
ft_denoising_steps_t=0,
|
ft_denoising_steps_t=0,
|
||||||
network_path=None,
|
network_path=None,
|
||||||
# various clipping
|
# modifying denoising schedule
|
||||||
randn_clip_value=10,
|
|
||||||
clamp_action=False,
|
|
||||||
min_sampling_denoising_std=0.1,
|
min_sampling_denoising_std=0.1,
|
||||||
min_logprob_denoising_std=0.1, # or the scheduler class
|
min_logprob_denoising_std=0.1,
|
||||||
eps_clip_value=None, # only used with DDIM
|
# eta in DDIM
|
||||||
# DDIM related
|
|
||||||
eta=None,
|
eta=None,
|
||||||
learn_eta=False,
|
learn_eta=False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
@ -60,15 +57,6 @@ class VPGDiffusion(DiffusionModel):
|
|||||||
self.ft_denoising_steps_t = ft_denoising_steps_t # annealing interval
|
self.ft_denoising_steps_t = ft_denoising_steps_t # annealing interval
|
||||||
self.ft_denoising_steps_cnt = 0
|
self.ft_denoising_steps_cnt = 0
|
||||||
|
|
||||||
# Clip noise for numerical stability in policy gradient
|
|
||||||
self.eps_clip_value = eps_clip_value
|
|
||||||
|
|
||||||
# For each denoising step, we clip sampled randn (from standard deviation) such that the sampled action is not too far away from mean
|
|
||||||
self.randn_clip_value = randn_clip_value
|
|
||||||
|
|
||||||
# Whether to clamp sampled action between [-1, 1]
|
|
||||||
self.clamp_action = clamp_action
|
|
||||||
|
|
||||||
# Minimum std used in denoising process when sampling action - helps exploration
|
# Minimum std used in denoising process when sampling action - helps exploration
|
||||||
self.min_sampling_denoising_std = min_sampling_denoising_std
|
self.min_sampling_denoising_std = min_sampling_denoising_std
|
||||||
|
|
||||||
@ -214,7 +202,6 @@ class VPGDiffusion(DiffusionModel):
|
|||||||
if deterministic:
|
if deterministic:
|
||||||
etas = torch.zeros((x.shape[0], 1, 1)).to(x.device)
|
etas = torch.zeros((x.shape[0], 1, 1)).to(x.device)
|
||||||
else:
|
else:
|
||||||
# TODO: eta cond
|
|
||||||
etas = self.eta(cond).unsqueeze(1) # B x 1 x (Da or 1)
|
etas = self.eta(cond).unsqueeze(1) # B x 1 x (Da or 1)
|
||||||
sigma = (
|
sigma = (
|
||||||
etas
|
etas
|
||||||
@ -310,8 +297,8 @@ class VPGDiffusion(DiffusionModel):
|
|||||||
x = mean + std * noise
|
x = mean + std * noise
|
||||||
|
|
||||||
# clamp action at final step
|
# clamp action at final step
|
||||||
if self.clamp_action and i == len(t_all) - 1:
|
if self.final_action_clip_value is not None and i == len(t_all) - 1:
|
||||||
x = torch.clamp(x, -1, 1)
|
x = torch.clamp(x, -self.final_action_clip_value, self.final_action_clip_value)
|
||||||
|
|
||||||
if return_chain:
|
if return_chain:
|
||||||
if not self.use_ddim and t <= self.ft_denoising_steps:
|
if not self.use_ddim and t <= self.ft_denoising_steps:
|
||||||
|
@ -16,7 +16,6 @@ class RWR_Gaussian(GaussianModel):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
actor,
|
actor,
|
||||||
randn_clip_value=10,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
super().__init__(network=actor, **kwargs)
|
super().__init__(network=actor, **kwargs)
|
||||||
@ -24,9 +23,6 @@ class RWR_Gaussian(GaussianModel):
|
|||||||
# assign actor
|
# assign actor
|
||||||
self.actor = self.network
|
self.actor = self.network
|
||||||
|
|
||||||
# Clip sampled randn (from standard deviation) such that the sampled action is not too far away from mean
|
|
||||||
self.randn_clip_value = randn_clip_value
|
|
||||||
|
|
||||||
# override
|
# override
|
||||||
def loss(self, actions, obs, reward_weights):
|
def loss(self, actions, obs, reward_weights):
|
||||||
B = len(obs)
|
B = len(obs)
|
||||||
@ -44,6 +40,5 @@ class RWR_Gaussian(GaussianModel):
|
|||||||
actions = super().forward(
|
actions = super().forward(
|
||||||
cond=cond,
|
cond=cond,
|
||||||
deterministic=deterministic,
|
deterministic=deterministic,
|
||||||
randn_clip_value=self.randn_clip_value,
|
|
||||||
)
|
)
|
||||||
return actions
|
return actions
|
||||||
|
@ -15,11 +15,9 @@ class VPG_Gaussian(GaussianModel):
|
|||||||
self,
|
self,
|
||||||
actor,
|
actor,
|
||||||
critic,
|
critic,
|
||||||
randn_clip_value=10,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
super().__init__(network=actor, **kwargs)
|
super().__init__(network=actor, **kwargs)
|
||||||
self.randn_clip_value = randn_clip_value
|
|
||||||
|
|
||||||
# Value function for obs - simple MLP
|
# Value function for obs - simple MLP
|
||||||
self.critic = critic.to(self.device)
|
self.critic = critic.to(self.device)
|
||||||
@ -44,7 +42,6 @@ class VPG_Gaussian(GaussianModel):
|
|||||||
return super().forward(
|
return super().forward(
|
||||||
cond=cond,
|
cond=cond,
|
||||||
deterministic=deterministic,
|
deterministic=deterministic,
|
||||||
randn_clip_value=self.randn_clip_value,
|
|
||||||
network_override=self.actor if use_base_policy else None,
|
network_override=self.actor if use_base_policy else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user