dppo/model/common/mlp_gaussian.py

"""
MLP models for Gaussian policy.
"""
import torch
import torch.nn as nn
import einops
from copy import deepcopy
from model.common.mlp import MLP, ResidualMLP
from model.common.modules import SpatialEmb, RandomShiftsAug


class Gaussian_VisionMLP(nn.Module):
    """MLP Gaussian policy head with a ViT image backbone."""

    def __init__(
        self,
        backbone,
        transition_dim,
        horizon_steps,
        cond_dim,
        img_cond_steps=1,
        mlp_dims=[256, 256, 256],
        activation_type="Mish",
        residual_style=False,
        use_layernorm=False,
        fixed_std=None,
        learn_fixed_std=False,
        std_min=0.01,
        std_max=1,
        spatial_emb=0,
        visual_feature_dim=128,
        dropout=0,
        num_img=1,
        augment=False,
    ):
        super().__init__()

        # vision
        self.backbone = backbone
        if augment:
            self.aug = RandomShiftsAug(pad=4)
        self.augment = augment
        self.num_img = num_img
        self.img_cond_steps = img_cond_steps
        if spatial_emb > 0:
            assert spatial_emb > 1, "spatial_emb is the projection dimension and must be > 1"
            if num_img > 1:
                self.compress1 = SpatialEmb(
                    num_patch=self.backbone.num_patch,
                    patch_dim=self.backbone.patch_repr_dim,
                    prop_dim=cond_dim,
                    proj_dim=spatial_emb,
                    dropout=dropout,
                )
                self.compress2 = deepcopy(self.compress1)
            else:  # TODO: clean up
                self.compress = SpatialEmb(
                    num_patch=self.backbone.num_patch,
                    patch_dim=self.backbone.patch_repr_dim,
                    prop_dim=cond_dim,
                    proj_dim=spatial_emb,
                    dropout=dropout,
                )
            visual_feature_dim = spatial_emb * num_img
        else:
            self.compress = nn.Sequential(
                nn.Linear(self.backbone.repr_dim, visual_feature_dim),
                nn.LayerNorm(visual_feature_dim),
                nn.Dropout(dropout),
                nn.ReLU(),
            )

        # head
        self.transition_dim = transition_dim
        self.horizon_steps = horizon_steps
        input_dim = visual_feature_dim + cond_dim
        output_dim = transition_dim * horizon_steps
        if residual_style:
            model = ResidualMLP
        else:
            model = MLP
        self.mlp_mean = model(
            [input_dim] + mlp_dims + [output_dim],
            activation_type=activation_type,
            out_activation_type="Identity",
            use_layernorm=use_layernorm,
        )
        if fixed_std is None:
            # state-dependent std, predicted by a small head
            self.mlp_logvar = MLP(
                [input_dim] + mlp_dims[-1:] + [output_dim],
                activation_type=activation_type,
                out_activation_type="Identity",
                use_layernorm=use_layernorm,
            )
        elif learn_fixed_std:  # state-independent but learnable, initialized to fixed_std
            self.logvar = torch.nn.Parameter(
                torch.log(torch.tensor([fixed_std**2 for _ in range(transition_dim)])),
                requires_grad=True,
            )
        # clamp bounds are used by both the learned and the predicted logvar paths
        self.logvar_min = torch.nn.Parameter(
            torch.log(torch.tensor(std_min**2)), requires_grad=False
        )
        self.logvar_max = torch.nn.Parameter(
            torch.log(torch.tensor(std_max**2)), requires_grad=False
        )
        self.use_fixed_std = fixed_std is not None
        self.fixed_std = fixed_std
        self.learn_fixed_std = learn_fixed_std

    def forward(self, cond):
        B = len(cond["rgb"])
        device = cond["rgb"].device
        _, T_rgb, C, H, W = cond["rgb"].shape

        # flatten proprioceptive history
        state = cond["state"].view(B, -1)

        # Take the most recent images --- sometimes fewer img_cond_steps than
        # cond_steps are used (e.g., 1 image but 3 proprioceptive steps)
        rgb = cond["rgb"][:, -self.img_cond_steps :]

        # concatenate images in cond by channels
        if self.num_img > 1:
            rgb = rgb.reshape(B, T_rgb, self.num_img, 3, H, W)
            rgb = einops.rearrange(rgb, "b t n c h w -> b n (t c) h w")
        else:
            rgb = einops.rearrange(rgb, "b t c h w -> b (t c) h w")

        # convert rgb to float32 for augmentation
        rgb = rgb.float()

        # get vit output - pass in two images separately
        if self.num_img > 1:  # TODO: properly handle multiple images
            rgb1 = rgb[:, 0]
            rgb2 = rgb[:, 1]
            if self.augment:
                rgb1 = self.aug(rgb1)
                rgb2 = self.aug(rgb2)
            feat1 = self.backbone(rgb1)
            feat2 = self.backbone(rgb2)
            feat1 = self.compress1.forward(feat1, state)
            feat2 = self.compress2.forward(feat2, state)
            feat = torch.cat([feat1, feat2], dim=-1)
        else:  # single image
            if self.augment:
                rgb = self.aug(rgb)  # uint8 -> float32
            feat = self.backbone(rgb)

            # compress
            if isinstance(self.compress, SpatialEmb):
                feat = self.compress.forward(feat, state)
            else:
                feat = feat.flatten(1, -1)
                feat = self.compress(feat)

        # mlp head
        x_encoded = torch.cat([feat, state], dim=-1)
        out_mean = self.mlp_mean(x_encoded)
        out_mean = torch.tanh(out_mean).view(
            B, self.horizon_steps * self.transition_dim
        )  # tanh squashing into [-1, 1]

        if self.learn_fixed_std:
            out_logvar = torch.clamp(self.logvar, self.logvar_min, self.logvar_max)
            out_scale = torch.exp(0.5 * out_logvar)
            out_scale = out_scale.view(1, self.transition_dim)
            out_scale = out_scale.repeat(B, self.horizon_steps)
        elif self.use_fixed_std:
            out_scale = torch.ones_like(out_mean).to(device) * self.fixed_std
        else:
            out_logvar = self.mlp_logvar(x_encoded).view(
                B, self.horizon_steps * self.transition_dim
            )
            out_logvar = torch.clamp(out_logvar, self.logvar_min, self.logvar_max)
            out_scale = torch.exp(0.5 * out_logvar)
        return out_mean, out_scale
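
# Usage note (added for clarity; shapes inferred from forward() above):
# Gaussian_VisionMLP expects
#     cond = {
#         "rgb":   (B, T_rgb, C, H, W)   image history (uint8 ok; cast to float),
#         "state": (B, T_state, D)       proprioceptive history, flattened to (B, T_state * D),
#     }
# and the backbone is assumed to expose num_patch / patch_repr_dim (SpatialEmb
# path) or repr_dim (nn.Linear path), matching the constructor arguments above.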


class Gaussian_MLP(nn.Module):
    """State-only MLP Gaussian policy (no vision backbone)."""

    def __init__(
        self,
        transition_dim,
        horizon_steps,
        cond_dim,
        mlp_dims=[256, 256, 256],
        activation_type="Mish",
        residual_style=False,
        use_layernorm=False,
        fixed_std=None,
        learn_fixed_std=False,
        std_min=0.01,
        std_max=1,
    ):
        super().__init__()
        self.transition_dim = transition_dim
        self.horizon_steps = horizon_steps
        input_dim = cond_dim
        output_dim = transition_dim * horizon_steps
        if residual_style:
            model = ResidualMLP
        else:
            model = MLP
        self.mlp_mean = model(
            [input_dim] + mlp_dims + [output_dim],
            activation_type=activation_type,
            out_activation_type="Identity",
            use_layernorm=use_layernorm,
        )
        if fixed_std is None:
            # state-dependent std, predicted by a small head
            self.mlp_logvar = MLP(
                [input_dim] + mlp_dims[-1:] + [output_dim],
                activation_type=activation_type,
                out_activation_type="Identity",
                use_layernorm=use_layernorm,
            )
        elif learn_fixed_std:  # state-independent but learnable, initialized to fixed_std
            self.logvar = torch.nn.Parameter(
                torch.log(torch.tensor([fixed_std**2 for _ in range(transition_dim)])),
                requires_grad=True,
            )
        # clamp bounds are used by both the learned and the predicted logvar paths
        self.logvar_min = torch.nn.Parameter(
            torch.log(torch.tensor(std_min**2)), requires_grad=False
        )
        self.logvar_max = torch.nn.Parameter(
            torch.log(torch.tensor(std_max**2)), requires_grad=False
        )
        self.use_fixed_std = fixed_std is not None
        self.fixed_std = fixed_std
        self.learn_fixed_std = learn_fixed_std

    def forward(self, cond):
        B = len(cond["state"])
        device = cond["state"].device

        # flatten history
        state = cond["state"].view(B, -1)

        # mlp
        out_mean = self.mlp_mean(state)
        out_mean = torch.tanh(out_mean).view(
            B, self.horizon_steps * self.transition_dim
        )  # tanh squashing into [-1, 1]

        if self.learn_fixed_std:
            out_logvar = torch.clamp(self.logvar, self.logvar_min, self.logvar_max)
            out_scale = torch.exp(0.5 * out_logvar)
            out_scale = out_scale.view(1, self.transition_dim)
            out_scale = out_scale.repeat(B, self.horizon_steps)
        elif self.use_fixed_std:
            out_scale = torch.ones_like(out_mean).to(device) * self.fixed_std
        else:
            out_logvar = self.mlp_logvar(state).view(
                B, self.horizon_steps * self.transition_dim
            )
            out_logvar = torch.clamp(out_logvar, self.logvar_min, self.logvar_max)
            out_scale = torch.exp(0.5 * out_logvar)
        return out_mean, out_scale
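

if __name__ == "__main__":
    # Minimal smoke test (an illustrative sketch, not part of the original
    # module). The dimensions below -- 11-dim flattened state, 7-dim action,
    # 4-step horizon, batch of 8 -- are assumptions chosen for the example.
    torch.manual_seed(0)
    policy = Gaussian_MLP(
        transition_dim=7,
        horizon_steps=4,
        cond_dim=11,
        fixed_std=None,  # exercise the state-dependent logvar head
    )
    cond = {"state": torch.randn(8, 1, 11)}  # (B, cond_steps, state_dim)
    mean, scale = policy(cond)
    # mean is tanh-squashed into [-1, 1]; scale is clamped to [std_min, std_max]
    assert mean.shape == (8, 28) and scale.shape == (8, 28)
    print(mean.shape, scale.shape)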