# dppo/model/diffusion/eta.py (snapshot 2024-09-11)
"""
Eta in DDIM.
Can be learned but always fixed to 1 during training and 0 during eval right now.
"""
import torch
from model.common.mlp import MLP
class EtaFixed(torch.nn.Module):
    """Single learnable eta shared across the batch (input-independent).

    A scalar logit is squashed with tanh and affinely mapped into
    [min_eta, max_eta]; it is initialized so the mapped value equals
    ``base_eta``.
    """

    def __init__(
        self,
        base_eta=0.5,
        min_eta=0.1,
        max_eta=1.0,
        **kwargs,
    ):
        """
        Args:
            base_eta: initial eta value after the tanh + affine map.
            min_eta: lower bound of the eta range.
            max_eta: upper bound of the eta range.
        """
        super().__init__()
        self.eta_logit = torch.nn.Parameter(torch.ones(1))
        self.min = min_eta
        self.max = max_eta
        # initialize such that eta = base_eta (invert the affine map, then atanh)
        self.eta_logit.data = torch.atanh(
            torch.tensor([2 * (base_eta - min_eta) / (max_eta - min_eta) - 1])
        )

    # Use forward (not __call__) so nn.Module's hook machinery is not bypassed;
    # callers invoking module(cond) are unaffected.
    def forward(self, cond):
        """Match input batch size, but do not depend on input.

        Args:
            cond: dict with "state" or "rgb"; only batch size and device
                are read from it.

        Returns:
            (B, 1) tensor holding the current eta value.
        """
        sample_data = cond["state"] if "state" in cond else cond["rgb"]
        B = len(sample_data)
        device = sample_data.device
        eta_normalized = torch.tanh(self.eta_logit)
        # map to min and max from [-1, 1]
        eta = 0.5 * (eta_normalized + 1) * (self.max - self.min) + self.min
        # Expand via repeat (not torch.full(..., eta.item())): .item() would
        # detach eta from the graph and block gradients into eta_logit,
        # and this is consistent with EtaAction.
        return eta.repeat(B, 1).to(device)
class EtaAction(torch.nn.Module):
    """Learnable per-action-dimension eta (input-independent).

    Keeps one logit per action dimension; each logit is squashed with tanh
    and affinely mapped into [min_eta, max_eta]. All logits start at the
    value that maps to ``base_eta``.
    """

    def __init__(
        self,
        action_dim,
        base_eta=0.5,
        min_eta=0.1,
        max_eta=1.0,
        **kwargs,
    ):
        """
        Args:
            action_dim: number of independent eta values.
            base_eta: initial eta for every dimension.
            min_eta: lower bound of the eta range.
            max_eta: upper bound of the eta range.
        """
        super().__init__()
        self.min = min_eta
        self.max = max_eta
        # Invert the affine map and apply atanh so that tanh(logit)
        # initially maps back to base_eta for every dimension.
        init_logit = torch.atanh(
            torch.tensor([2 * (base_eta - min_eta) / (max_eta - min_eta) - 1])
        )
        self.eta_logit = torch.nn.Parameter(torch.ones(action_dim) * init_logit)

    def __call__(self, cond):
        """Match input batch size, but do not depend on input.

        Returns:
            (B, action_dim) tensor of per-dimension eta values.
        """
        ref = cond["state"] if "state" in cond else cond["rgb"]
        B, device = len(ref), ref.device
        # tanh squashes each logit into [-1, 1]; rescale into [min, max]
        unit = torch.tanh(self.eta_logit)
        eta = self.min + 0.5 * (unit + 1) * (self.max - self.min)
        return eta.repeat(B, 1).to(device)
class EtaState(torch.nn.Module):
    """State-conditioned eta: base_eta plus a tanh-bounded MLP residual.

    The residual MLP is initialized near zero so eta starts at
    ``base_eta``; the output is clamped into [min_eta, max_eta].
    Image ("rgb") conditioning is not supported.
    """

    def __init__(
        self,
        input_dim,
        mlp_dims,
        activation_type="ReLU",
        out_activation_type="Identity",
        base_eta=0.5,
        min_eta=0.1,
        max_eta=1.0,
        gain=1e-2,
        **kwargs,
    ):
        """
        Args:
            input_dim: flattened state dimension fed to the MLP.
            mlp_dims: hidden-layer sizes of the residual MLP.
            activation_type: hidden activation name for the MLP.
            out_activation_type: output activation name for the MLP.
            base_eta: eta value when the residual is zero.
            min_eta: lower clamp bound for eta.
            max_eta: upper clamp bound for eta.
            gain: xavier gain; small so the residual starts near zero.
        """
        super().__init__()
        self.base = base_eta
        # bounds stored as residuals relative to base_eta
        self.min_res = min_eta - base_eta
        self.max_res = max_eta - base_eta
        self.mlp_res = MLP(
            [input_dim] + mlp_dims + [1],
            activation_type=activation_type,
            out_activation_type=out_activation_type,
        )
        # Near-zero init: with tiny weights and zero biases, mlp(x) ~= 0
        # and eta ~= base_eta at the start of training.
        for layer in self.mlp_res.modules():
            if isinstance(layer, torch.nn.Linear):
                torch.nn.init.xavier_normal_(layer.weight, gain=gain)
                layer.bias.data.fill_(0)

    def __call__(self, cond):
        """Return eta of shape (B, 1) computed from the flattened state."""
        if "rgb" in cond:
            raise NotImplementedError(
                "State-based eta not implemented for image-based training!"
            )
        B = len(cond["state"])
        flat_state = cond["state"].view(B, -1)  # flatten history dims
        # residual bounded to [-1, 1] around base_eta
        residual = torch.tanh(self.mlp_res(flat_state))
        eta = residual + self.base
        lo = self.min_res + self.base  # == min_eta
        hi = self.max_res + self.base  # == max_eta
        return torch.clamp(eta, lo, hi)
class EtaStateAction(torch.nn.Module):
    """State-conditioned, per-action-dimension eta via an MLP residual.

    Like EtaState but the MLP emits one residual per action dimension;
    each is tanh-bounded, added to ``base_eta``, and clamped into
    [min_eta, max_eta]. Image ("rgb") conditioning is not supported.
    """

    def __init__(
        self,
        input_dim,
        mlp_dims,
        action_dim,
        activation_type="ReLU",
        out_activation_type="Identity",
        base_eta=1,
        min_eta=1e-3,
        max_eta=2,
        gain=1e-2,
        **kwargs,
    ):
        """
        Args:
            input_dim: flattened state dimension fed to the MLP.
            mlp_dims: hidden-layer sizes of the residual MLP.
            action_dim: number of per-dimension eta outputs.
            activation_type: hidden activation name for the MLP.
            out_activation_type: output activation name for the MLP.
            base_eta: eta value when the residual is zero.
            min_eta: lower clamp bound for eta.
            max_eta: upper clamp bound for eta.
            gain: xavier gain; small so the residual starts near zero.
        """
        super().__init__()
        self.base = base_eta
        # bounds stored as residuals relative to base_eta
        self.min_res = min_eta - base_eta
        self.max_res = max_eta - base_eta
        self.mlp_res = MLP(
            [input_dim] + mlp_dims + [action_dim],
            activation_type=activation_type,
            out_activation_type=out_activation_type,
        )
        # Near-zero init: with tiny weights and zero biases, mlp(x) ~= 0
        # and eta ~= base_eta at the start of training.
        for layer in self.mlp_res.modules():
            if isinstance(layer, torch.nn.Linear):
                torch.nn.init.xavier_normal_(layer.weight, gain=gain)
                layer.bias.data.fill_(0)

    def __call__(self, cond):
        """Return eta of shape (B, action_dim) from the flattened state."""
        if "rgb" in cond:
            raise NotImplementedError(
                "State-action-based eta not implemented for image-based training!"
            )
        B = len(cond["state"])
        flat_state = cond["state"].view(B, -1)  # flatten history dims
        # residual bounded to [-1, 1] around base_eta
        residual = torch.tanh(self.mlp_res(flat_state))
        eta = residual + self.base
        lo = self.min_res + self.base  # == min_eta
        hi = self.max_res + self.base  # == max_eta
        return torch.clamp(eta, lo, hi)