dppo/model/common/mlp_gaussian.py
Allen Z. Ren dc8e0c9edc
v0.6 (#18)
* Sampling over both env and denoising steps in DPPO updates (#13)

* sample one from each chain

* full random sampling

* Add Proficient Human (PH) Configs and Pipeline (#16)

* fix missing cfg

* add ph config

* fix how terminated flags are added to buffer in ibrl

* add ph config

* offline calql for 1M gradient updates

* bug fix: number of calql online gradient steps is the number of new transitions collected

* add sample config for DPPO with ta=1

* Sampling over both env and denoising steps in DPPO updates (#13)

* sample one from each chain

* full random sampling

* fix diffusion loss when predicting initial noise

* fix dppo inds

* fix typo

* remove print statement

---------

Co-authored-by: Justin M. Lidard <jlidard@neuronic.cs.princeton.edu>
Co-authored-by: allenzren <allen.ren@princeton.edu>

* update robomimic configs

* better calql formulation

* optimize calql and ibrl training

* optimize data transfer in ppo agents

* add kitchen configs

* re-organize config folders, rerun calql and rlpd

* add scratch gym locomotion configs

* add kitchen installation dependencies

* use truncated for termination in furniture env

* update furniture and gym configs

* update README and dependencies with kitchen

* add url for new data and checkpoints

* update demo RL configs

* update batch sizes for furniture unet configs

* raise error about dropout in residual mlp

* fix observation bug in bc loss

---------

Co-authored-by: Justin Lidard <60638575+jlidard@users.noreply.github.com>
Co-authored-by: Justin M. Lidard <jlidard@neuronic.cs.princeton.edu>
2024-10-30 19:58:06 -04:00

283 lines
9.7 KiB
Python

"""
MLP models for Gaussian policy.
"""
import torch
import torch.nn as nn
import einops
from copy import deepcopy
from model.common.mlp import MLP, ResidualMLP
from model.common.modules import SpatialEmb, RandomShiftsAug
class Gaussian_VisionMLP(nn.Module):
    """Gaussian policy with a vision backbone (e.g., ViT) followed by an MLP head.

    ``forward`` encodes the image history with ``backbone``, compresses the
    visual features, concatenates them with the flattened state, and returns
    the mean and standard deviation of a diagonal Gaussian over the flattened
    action chunk (``horizon_steps * action_dim`` values per sample).

    The standard deviation is produced in one of three ways:

    * ``fixed_std is None``: predicted from the encoded input by ``mlp_logvar``
      and clamped to ``[std_min, std_max]`` (in log-variance space).
    * ``fixed_std`` given and ``learn_fixed_std=True``: a learnable,
      state-independent log-variance per action dimension, initialised to
      ``fixed_std`` and clamped to the same range.
    * ``fixed_std`` given and ``learn_fixed_std=False``: the constant
      ``fixed_std``.

    NOTE: the multi-image path (``num_img > 1``) uses only the first two
    images and relies on ``compress1``/``compress2``, which are created only
    when ``spatial_emb > 0``.
    """

    def __init__(
        self,
        backbone,
        action_dim,
        horizon_steps,
        cond_dim,
        img_cond_steps=1,
        mlp_dims=(256, 256, 256),
        activation_type="Mish",
        residual_style=False,
        use_layernorm=False,
        fixed_std=None,
        learn_fixed_std=False,
        std_min=0.01,
        std_max=1,
        spatial_emb=0,
        visual_feature_dim=128,
        dropout=0,
        num_img=1,
        augment=False,
    ):
        """Build the vision encoder, feature compressor, and Gaussian heads.

        Args:
            backbone: image encoder module. Its ``num_patch`` and
                ``patch_repr_dim`` attributes are read when ``spatial_emb > 0``;
                otherwise its ``repr_dim`` attribute is read.
            action_dim: dimension of a single action.
            horizon_steps: number of actions predicted per forward pass.
            cond_dim: size of the flattened state fed to the head.
            img_cond_steps: number of most recent image frames to use.
            mlp_dims: hidden layer sizes of the mean head.
            activation_type: activation name forwarded to the MLP classes.
            residual_style: use ``ResidualMLP`` instead of ``MLP`` for the mean.
            use_layernorm: forwarded to the MLP classes.
            fixed_std: constant (or initial) standard deviation; ``None`` to
                predict it from the input.
            learn_fixed_std: make the state-independent std learnable.
                Requires ``fixed_std``.
            std_min: lower bound on the standard deviation.
            std_max: upper bound on the standard deviation.
            spatial_emb: if > 0, projection size of the ``SpatialEmb``
                compressor (must then be > 1); otherwise a linear compressor
                with ``visual_feature_dim`` outputs is used.
            visual_feature_dim: output size of the linear compressor.
            dropout: dropout rate used inside the compressor.
            num_img: number of camera images stacked along the channel axis.
            augment: apply ``RandomShiftsAug`` to the images before encoding.

        Raises:
            ValueError: if ``learn_fixed_std`` is set without ``fixed_std``.
            AssertionError: if ``0 < spatial_emb <= 1``.
        """
        super().__init__()

        # Fail fast: without an initial value there is no `logvar` parameter,
        # and forward() would otherwise die later with an AttributeError.
        if learn_fixed_std and fixed_std is None:
            raise ValueError("learn_fixed_std=True requires fixed_std to be set")

        # Copy into a plain list so concatenation below works for any sequence
        # and the (immutable) default is never shared between instances.
        mlp_dims = list(mlp_dims)

        # vision
        self.backbone = backbone
        if augment:
            self.aug = RandomShiftsAug(pad=4)
        self.augment = augment
        self.num_img = num_img
        self.img_cond_steps = img_cond_steps
        if spatial_emb > 0:
            assert spatial_emb > 1, "this is the dimension"
            if num_img > 1:
                # one compressor per image; the second starts as a copy of the first
                self.compress1 = SpatialEmb(
                    num_patch=self.backbone.num_patch,
                    patch_dim=self.backbone.patch_repr_dim,
                    prop_dim=cond_dim,
                    proj_dim=spatial_emb,
                    dropout=dropout,
                )
                self.compress2 = deepcopy(self.compress1)
            else:  # TODO: clean up
                self.compress = SpatialEmb(
                    num_patch=self.backbone.num_patch,
                    patch_dim=self.backbone.patch_repr_dim,
                    prop_dim=cond_dim,
                    proj_dim=spatial_emb,
                    dropout=dropout,
                )
            visual_feature_dim = spatial_emb * num_img
        else:
            self.compress = nn.Sequential(
                nn.Linear(self.backbone.repr_dim, visual_feature_dim),
                nn.LayerNorm(visual_feature_dim),
                nn.Dropout(dropout),
                nn.ReLU(),
            )

        # head
        self.action_dim = action_dim
        self.horizon_steps = horizon_steps
        input_dim = visual_feature_dim + cond_dim
        output_dim = action_dim * horizon_steps
        model = ResidualMLP if residual_style else MLP
        self.mlp_mean = model(
            [input_dim] + mlp_dims + [output_dim],
            activation_type=activation_type,
            out_activation_type="Identity",
            use_layernorm=use_layernorm,
        )
        if fixed_std is None:
            # state-dependent log-variance with a single hidden layer
            self.mlp_logvar = MLP(
                [input_dim] + mlp_dims[-1:] + [output_dim],
                activation_type=activation_type,
                out_activation_type="Identity",
                use_layernorm=use_layernorm,
            )
        elif learn_fixed_std:
            # learnable state-independent log-variance, initialized to fixed_std
            self.logvar = torch.nn.Parameter(
                torch.log(torch.tensor([fixed_std**2 for _ in range(action_dim)])),
                requires_grad=True,
            )
        # bounds are stored as non-trainable parameters so they follow the module's device
        self.logvar_min = torch.nn.Parameter(
            torch.log(torch.tensor(std_min**2)), requires_grad=False
        )
        self.logvar_max = torch.nn.Parameter(
            torch.log(torch.tensor(std_max**2)), requires_grad=False
        )
        self.use_fixed_std = fixed_std is not None
        self.fixed_std = fixed_std
        self.learn_fixed_std = learn_fixed_std

    def forward(self, cond):
        """Compute the Gaussian action distribution parameters.

        Args:
            cond: dict with
                ``"rgb"``: 5-D image tensor ``(B, T, C, H, W)``; when
                ``num_img > 1`` the channel axis must hold ``num_img * 3``
                channels.
                ``"state"``: tensor with leading batch dimension ``B``; all
                remaining dimensions are flattened.

        Returns:
            Tuple ``(out_mean, out_scale)``, both of shape
            ``(B, horizon_steps * action_dim)``. The mean is tanh-squashed
            into [-1, 1].
        """
        B = len(cond["rgb"])
        device = cond["rgb"].device
        # Unpacking enforces a 5-D input; only the spatial size is needed below.
        _, _, _, H, W = cond["rgb"].shape

        # flatten history
        state = cond["state"].view(B, -1)

        # Keep only the most recent frames --- the image history may be shorter
        # than the state history (e.g., 1 image but 3 proprioception steps).
        rgb = cond["rgb"][:, -self.img_cond_steps :]
        # Use the time length *after* slicing: using the original length breaks
        # the reshape below whenever fewer frames are kept than were supplied.
        T_rgb = rgb.shape[1]

        # concatenate the frames of each image along the channel axis
        if self.num_img > 1:
            # (B, T, N*3, H, W) -> (B, T, N, 3, H, W) -> (B, N, T*3, H, W)
            rgb = rgb.reshape(B, T_rgb, self.num_img, 3, H, W)
            rgb = rgb.permute(0, 2, 1, 3, 4, 5).reshape(
                B, self.num_img, T_rgb * 3, H, W
            )
        else:
            # (B, T, C, H, W) -> (B, T*C, H, W)
            rgb = rgb.flatten(1, 2)

        # convert rgb to float32 before augmentation / encoding
        rgb = rgb.float()

        # get backbone output - pass in two images separately
        if self.num_img > 1:  # TODO: properly handle multiple images
            rgb1 = rgb[:, 0]
            rgb2 = rgb[:, 1]
            if self.augment:
                rgb1 = self.aug(rgb1)
                rgb2 = self.aug(rgb2)
            feat1 = self.backbone(rgb1)
            feat2 = self.backbone(rgb2)
            feat1 = self.compress1.forward(feat1, state)
            feat2 = self.compress2.forward(feat2, state)
            feat = torch.cat([feat1, feat2], dim=-1)
        else:  # single image
            if self.augment:
                rgb = self.aug(rgb)
            feat = self.backbone(rgb)

            # compress
            if isinstance(self.compress, SpatialEmb):
                feat = self.compress.forward(feat, state)
            else:
                feat = feat.flatten(1, -1)
                feat = self.compress(feat)

        # mlp
        x_encoded = torch.cat([feat, state], dim=-1)
        out_mean = self.mlp_mean(x_encoded)
        out_mean = torch.tanh(out_mean).view(
            B, self.horizon_steps * self.action_dim
        )  # tanh squashing in [-1, 1]

        if self.learn_fixed_std:
            # one learnable std per action dim, tiled over batch and horizon
            out_logvar = torch.clamp(self.logvar, self.logvar_min, self.logvar_max)
            out_scale = torch.exp(0.5 * out_logvar)
            out_scale = out_scale.view(1, self.action_dim)
            out_scale = out_scale.repeat(B, self.horizon_steps)
        elif self.use_fixed_std:
            out_scale = torch.ones_like(out_mean).to(device) * self.fixed_std
        else:
            out_logvar = self.mlp_logvar(x_encoded).view(
                B, self.horizon_steps * self.action_dim
            )
            out_logvar = torch.clamp(out_logvar, self.logvar_min, self.logvar_max)
            out_scale = torch.exp(0.5 * out_logvar)
        return out_mean, out_scale
class Gaussian_MLP(nn.Module):
    """State-only Gaussian policy implemented with MLPs.

    ``forward`` returns the mean and standard deviation of a diagonal Gaussian
    over the flattened action chunk (``horizon_steps * action_dim`` values).

    The standard deviation is produced in one of three ways:

    * ``fixed_std is None``: a shared trunk (``mlp_base``) feeds separate mean
      and log-variance heads; the log-variance is squashed with tanh and
      rescaled to ``[log(std_min**2), log(std_max**2)]``.
    * ``fixed_std`` given and ``learn_fixed_std=True``: a learnable,
      state-independent log-variance per action dimension, initialised to
      ``fixed_std`` and clamped to the same range.
    * ``fixed_std`` given and ``learn_fixed_std=False``: the constant
      ``fixed_std``.
    """

    def __init__(
        self,
        action_dim,
        horizon_steps,
        cond_dim,
        mlp_dims=(256, 256, 256),
        activation_type="Mish",
        tanh_output=True,  # sometimes we want to apply tanh after sampling instead of here, e.g., in SAC
        residual_style=False,
        use_layernorm=False,
        dropout=0.0,
        fixed_std=None,
        learn_fixed_std=False,
        std_min=0.01,
        std_max=1,
    ):
        """Build the mean (and, if needed, log-variance) networks.

        Args:
            action_dim: dimension of a single action.
            horizon_steps: number of actions predicted per forward pass.
            cond_dim: size of the flattened state input.
            mlp_dims: hidden layer sizes.
            activation_type: activation name forwarded to the MLP classes.
            tanh_output: squash the mean into [-1, 1] with tanh.
            residual_style: use ``ResidualMLP`` instead of ``MLP`` for the
                trunk / mean network.
            use_layernorm: forwarded to the MLP classes.
            dropout: forwarded to the MLP classes.
            fixed_std: constant (or initial) standard deviation; ``None`` to
                predict it from the state.
            learn_fixed_std: make the state-independent std learnable.
                Requires ``fixed_std``.
            std_min: lower bound on the standard deviation.
            std_max: upper bound on the standard deviation.

        Raises:
            ValueError: if ``learn_fixed_std`` is set without ``fixed_std``.
        """
        super().__init__()

        # Fail fast: without an initial value there is no `logvar` parameter,
        # and forward() would otherwise die later with an AttributeError.
        if learn_fixed_std and fixed_std is None:
            raise ValueError("learn_fixed_std=True requires fixed_std to be set")

        # Copy into a plain list so concatenation below works for any sequence
        # and the (immutable) default is never shared between instances.
        mlp_dims = list(mlp_dims)

        self.action_dim = action_dim
        self.horizon_steps = horizon_steps
        input_dim = cond_dim
        output_dim = action_dim * horizon_steps
        model = ResidualMLP if residual_style else MLP
        if fixed_std is None:
            # learning std: shared trunk followed by two single-layer heads
            self.mlp_base = model(
                [input_dim] + mlp_dims,
                activation_type=activation_type,
                out_activation_type=activation_type,
                use_layernorm=use_layernorm,
                use_layernorm_final=use_layernorm,
                dropout=dropout,
            )
            self.mlp_mean = MLP(
                mlp_dims[-1:] + [output_dim],
                out_activation_type="Identity",
            )
            self.mlp_logvar = MLP(
                mlp_dims[-1:] + [output_dim],
                out_activation_type="Identity",
            )
        else:
            # no separate head for mean and std
            self.mlp_mean = model(
                [input_dim] + mlp_dims + [output_dim],
                activation_type=activation_type,
                out_activation_type="Identity",
                use_layernorm=use_layernorm,
                dropout=dropout,
            )
            if learn_fixed_std:
                # learnable state-independent log-variance, initialized to fixed_std
                self.logvar = torch.nn.Parameter(
                    torch.log(
                        torch.tensor([fixed_std**2 for _ in range(action_dim)])
                    ),
                    requires_grad=True,
                )
        # bounds are stored as non-trainable parameters so they follow the module's device
        self.logvar_min = torch.nn.Parameter(
            torch.log(torch.tensor(std_min**2)), requires_grad=False
        )
        self.logvar_max = torch.nn.Parameter(
            torch.log(torch.tensor(std_max**2)), requires_grad=False
        )
        self.use_fixed_std = fixed_std is not None
        self.fixed_std = fixed_std
        self.learn_fixed_std = learn_fixed_std
        self.tanh_output = tanh_output

    def forward(self, cond):
        """Compute the Gaussian action distribution parameters.

        Args:
            cond: dict with ``"state"``, a tensor with leading batch dimension
                ``B``; all remaining dimensions are flattened.

        Returns:
            Tuple ``(out_mean, out_scale)``, both of shape
            ``(B, horizon_steps * action_dim)``.
        """
        B = len(cond["state"])
        device = cond["state"].device

        # flatten history
        state = cond["state"].view(B, -1)

        # mlp: the shared trunk only exists when the std is state-dependent
        if hasattr(self, "mlp_base"):
            state = self.mlp_base(state)
        out_mean = self.mlp_mean(state)
        if self.tanh_output:
            out_mean = torch.tanh(out_mean)
        out_mean = out_mean.view(B, self.horizon_steps * self.action_dim)

        if self.learn_fixed_std:
            # one learnable std per action dim, tiled over batch and horizon
            out_logvar = torch.clamp(self.logvar, self.logvar_min, self.logvar_max)
            out_scale = torch.exp(0.5 * out_logvar)
            out_scale = out_scale.view(1, self.action_dim)
            out_scale = out_scale.repeat(B, self.horizon_steps)
        elif self.use_fixed_std:
            out_scale = torch.ones_like(out_mean).to(device) * self.fixed_std
        else:
            out_logvar = self.mlp_logvar(state).view(
                B, self.horizon_steps * self.action_dim
            )
            # squash to [-1, 1], then map linearly onto [logvar_min, logvar_max]
            out_logvar = torch.tanh(out_logvar)
            out_logvar = self.logvar_min + 0.5 * (self.logvar_max - self.logvar_min) * (
                out_logvar + 1
            )  # put back to full range
            out_scale = torch.exp(0.5 * out_logvar)
        return out_mean, out_scale