* Sampling over both env and denoising steps in DPPO updates (#13) * sample one from each chain * full random sampling * Add Proficient Human (PH) Configs and Pipeline (#16) * fix missing cfg * add ph config * fix how terminated flags are added to buffer in ibrl * add ph config * offline calql for 1M gradient updates * bug fix: number of calql online gradient steps is the number of new transitions collected * add sample config for DPPO with ta=1 * Sampling over both env and denoising steps in DPPO updates (#13) * sample one from each chain * full random sampling * fix diffusion loss when predicting initial noise * fix dppo inds * fix typo * remove print statement --------- Co-authored-by: Justin M. Lidard <jlidard@neuronic.cs.princeton.edu> Co-authored-by: allenzren <allen.ren@princeton.edu> * update robomimic configs * better calql formulation * optimize calql and ibrl training * optimize data transfer in ppo agents * add kitchen configs * re-organize config folders, rerun calql and rlpd * add scratch gym locomotion configs * add kitchen installation dependencies * use truncated for termination in furniture env * update furniture and gym configs * update README and dependencies with kitchen * add url for new data and checkpoints * update demo RL configs * update batch sizes for furniture unet configs * raise error about dropout in residual mlp * fix observation bug in bc loss --------- Co-authored-by: Justin Lidard <60638575+jlidard@users.noreply.github.com> Co-authored-by: Justin M. Lidard <jlidard@neuronic.cs.princeton.edu>
283 lines
9.7 KiB
Python
283 lines
9.7 KiB
Python
"""
|
|
MLP models for Gaussian policy.
|
|
|
|
"""
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import einops
|
|
from copy import deepcopy
|
|
|
|
from model.common.mlp import MLP, ResidualMLP
|
|
from model.common.modules import SpatialEmb, RandomShiftsAug
|
|
|
|
|
|
class Gaussian_VisionMLP(nn.Module):
    """Gaussian policy with a vision backbone (e.g., ViT) and MLP heads.

    Images are encoded by ``backbone``, compressed to a feature vector,
    concatenated with the flattened state, and fed to an MLP that predicts the
    action mean. The standard deviation is either fixed, a learned
    per-action-dimension parameter, or predicted by a second MLP.

    NOTE: with ``num_img > 1`` the forward pass uses ``compress1`` /
    ``compress2``, which are only created when ``spatial_emb > 0``, and only
    the first two images are encoded.
    """

    def __init__(
        self,
        backbone,
        action_dim,
        horizon_steps,
        cond_dim,
        img_cond_steps=1,
        mlp_dims=[256, 256, 256],
        activation_type="Mish",
        residual_style=False,
        use_layernorm=False,
        fixed_std=None,
        learn_fixed_std=False,
        std_min=0.01,
        std_max=1,
        spatial_emb=0,
        visual_feature_dim=128,
        dropout=0,
        num_img=1,
        augment=False,
    ):
        """Build the vision encoder wrapper and the mean / std heads.

        Args:
            backbone: image encoder module. Must expose ``num_patch`` and
                ``patch_repr_dim`` when ``spatial_emb > 0``, ``repr_dim``
                otherwise.
            action_dim: dimension of a single action.
            horizon_steps: number of actions predicted per forward pass.
            cond_dim: size of the flattened state vector concatenated with
                the visual feature.
            img_cond_steps: number of most recent image frames to use.
            mlp_dims: hidden layer sizes of the mean MLP; the last entry is
                also the hidden size of the log-variance MLP.
            activation_type: activation name forwarded to the MLPs.
            residual_style: use ``ResidualMLP`` instead of ``MLP`` for the
                mean head.
            use_layernorm: forwarded to the MLPs.
            fixed_std: if not None, the std is not predicted from the input.
            learn_fixed_std: if True (and ``fixed_std`` is given), the std is
                a learnable per-action-dimension parameter initialized to
                ``fixed_std``.
            std_min: lower bound on the std (stored as log-variance).
            std_max: upper bound on the std (stored as log-variance).
            spatial_emb: if > 0, output dimension of the ``SpatialEmb``
                compression; otherwise a linear compression is used.
            visual_feature_dim: output size of the linear compression;
                overwritten by ``spatial_emb * num_img`` when
                ``spatial_emb > 0``.
            dropout: dropout rate of the compression layers.
            num_img: number of images stacked along the channel axis of
                ``cond["rgb"]``.
            augment: apply ``RandomShiftsAug(pad=4)`` before the backbone.
        """
        super().__init__()
        # Copy so tuples (or other sequence types) work with the list
        # concatenations below and the caller's object is never shared.
        mlp_dims = list(mlp_dims)

        # vision
        self.backbone = backbone
        if augment:
            self.aug = RandomShiftsAug(pad=4)
        self.augment = augment
        self.num_img = num_img
        self.img_cond_steps = img_cond_steps
        if spatial_emb > 0:
            assert spatial_emb > 1, "this is the dimension"
            spatial_emb_kwargs = dict(
                num_patch=self.backbone.num_patch,
                patch_dim=self.backbone.patch_repr_dim,
                prop_dim=cond_dim,
                proj_dim=spatial_emb,
                dropout=dropout,
            )
            if num_img > 1:
                # one independent compression module per image
                self.compress1 = SpatialEmb(**spatial_emb_kwargs)
                self.compress2 = deepcopy(self.compress1)
            else:  # TODO: clean up
                self.compress = SpatialEmb(**spatial_emb_kwargs)
            visual_feature_dim = spatial_emb * num_img
        else:
            self.compress = nn.Sequential(
                nn.Linear(self.backbone.repr_dim, visual_feature_dim),
                nn.LayerNorm(visual_feature_dim),
                nn.Dropout(dropout),
                nn.ReLU(),
            )

        # head
        self.action_dim = action_dim
        self.horizon_steps = horizon_steps
        input_dim = visual_feature_dim + cond_dim
        output_dim = action_dim * horizon_steps
        model = ResidualMLP if residual_style else MLP
        self.mlp_mean = model(
            [input_dim] + mlp_dims + [output_dim],
            activation_type=activation_type,
            out_activation_type="Identity",
            use_layernorm=use_layernorm,
        )
        if fixed_std is None:
            # state-dependent std: separate, shallower head
            self.mlp_logvar = MLP(
                [input_dim] + mlp_dims[-1:] + [output_dim],
                activation_type=activation_type,
                out_activation_type="Identity",
                use_layernorm=use_layernorm,
            )
        elif learn_fixed_std:  # initialize to fixed_std
            self.logvar = torch.nn.Parameter(
                torch.log(torch.tensor([fixed_std**2 for _ in range(action_dim)])),
                requires_grad=True,
            )
        # Bounds are stored as non-trainable parameters so they follow the
        # module across devices and are saved in the state dict.
        self.logvar_min = torch.nn.Parameter(
            torch.log(torch.tensor(std_min**2)), requires_grad=False
        )
        self.logvar_max = torch.nn.Parameter(
            torch.log(torch.tensor(std_max**2)), requires_grad=False
        )
        self.use_fixed_std = fixed_std is not None
        self.fixed_std = fixed_std
        self.learn_fixed_std = learn_fixed_std

    def forward(self, cond):
        """Compute the mean and std of the action distribution.

        Args:
            cond: dict with ``"rgb"`` of shape (B, T, C, H, W), where the
                images of all cameras are stacked along C, and ``"state"``
                with leading dimension B.

        Returns:
            Tuple ``(out_mean, out_scale)``, both of shape
            (B, horizon_steps * action_dim). The mean is tanh-squashed.
        """
        B, _, C, H, W = cond["rgb"].shape

        # flatten history
        state = cond["state"].reshape(B, -1)

        # Take recent images --- sometimes we want to use fewer img_cond_steps
        # than cond_steps (e.g., 1 image but 3 proprioception steps)
        rgb = cond["rgb"][:, -self.img_cond_steps :]
        # Use the frame count AFTER slicing; the pre-slice count does not
        # match the sliced tensor when fewer image steps are used.
        T_img = rgb.shape[1]

        # concatenate images in cond by channels
        if self.num_img > 1:
            rgb = rgb.reshape(B, T_img, self.num_img, C // self.num_img, H, W)
            # "b t n c h w -> b n (t c) h w"
            rgb = rgb.transpose(1, 2).flatten(2, 3)
        else:
            # "b t c h w -> b (t c) h w"
            rgb = rgb.flatten(1, 2)

        # convert rgb to float32 for augmentation
        rgb = rgb.float()

        # get vit output - pass in two images separately
        if self.num_img > 1:  # TODO: properly handle multiple images
            rgb1 = rgb[:, 0]
            rgb2 = rgb[:, 1]
            if self.augment:
                rgb1 = self.aug(rgb1)
                rgb2 = self.aug(rgb2)
            feat1 = self.backbone(rgb1)
            feat2 = self.backbone(rgb2)
            feat1 = self.compress1.forward(feat1, state)
            feat2 = self.compress2.forward(feat2, state)
            feat = torch.cat([feat1, feat2], dim=-1)
        else:  # single image
            if self.augment:
                rgb = self.aug(rgb)
            feat = self.backbone(rgb)

            # compress
            if isinstance(self.compress, SpatialEmb):
                feat = self.compress.forward(feat, state)
            else:
                feat = feat.flatten(1, -1)
                feat = self.compress(feat)

        # mlp
        x_encoded = torch.cat([feat, state], dim=-1)
        out_mean = self.mlp_mean(x_encoded)
        out_mean = torch.tanh(out_mean).view(
            B, self.horizon_steps * self.action_dim
        )  # tanh squashing in [-1, 1]

        # `self.logvar` only exists when a fixed std was given, so the
        # learned-parameter branch is guarded by both flags.
        if self.use_fixed_std and self.learn_fixed_std:
            out_logvar = torch.clamp(self.logvar, self.logvar_min, self.logvar_max)
            out_scale = torch.exp(0.5 * out_logvar)
            out_scale = out_scale.view(1, self.action_dim)
            # same per-dimension std for every sample and horizon step
            out_scale = out_scale.repeat(B, self.horizon_steps)
        elif self.use_fixed_std:
            # ones_like already matches out_mean's device and dtype
            out_scale = torch.ones_like(out_mean) * self.fixed_std
        else:
            out_logvar = self.mlp_logvar(x_encoded).view(
                B, self.horizon_steps * self.action_dim
            )
            out_logvar = torch.clamp(out_logvar, self.logvar_min, self.logvar_max)
            out_scale = torch.exp(0.5 * out_logvar)
        return out_mean, out_scale
|
|
|
|
|
|
class Gaussian_MLP(nn.Module):
    """State-conditioned Gaussian policy built from MLPs.

    The network maps the flattened state to the mean of the action
    distribution. The standard deviation is either fixed, a learned
    per-action-dimension parameter, or predicted by a separate head that
    shares a base MLP with the mean head.
    """

    def __init__(
        self,
        action_dim,
        horizon_steps,
        cond_dim,
        mlp_dims=[256, 256, 256],
        activation_type="Mish",
        tanh_output=True,  # sometimes we want to apply tanh after sampling instead of here, e.g., in SAC
        residual_style=False,
        use_layernorm=False,
        dropout=0.0,
        fixed_std=None,
        learn_fixed_std=False,
        std_min=0.01,
        std_max=1,
    ):
        """Build the mean / std networks.

        Args:
            action_dim: dimension of a single action.
            horizon_steps: number of actions predicted per forward pass.
            cond_dim: size of the flattened state input.
            mlp_dims: hidden layer sizes.
            activation_type: activation name forwarded to the MLPs.
            tanh_output: apply tanh to the predicted mean.
            residual_style: use ``ResidualMLP`` instead of ``MLP`` for the
                main network.
            use_layernorm: forwarded to the main network.
            dropout: forwarded to the main network.
            fixed_std: if None, the std is predicted from the state by a
                dedicated head; otherwise it is not input-dependent.
            learn_fixed_std: if True (and ``fixed_std`` is given), the std is
                a learnable per-action-dimension parameter initialized to
                ``fixed_std``.
            std_min: lower bound on the std (stored as log-variance).
            std_max: upper bound on the std (stored as log-variance).
        """
        super().__init__()
        # Copy so tuples (or other sequence types) work with the list
        # concatenations below and the caller's object is never shared.
        mlp_dims = list(mlp_dims)

        self.action_dim = action_dim
        self.horizon_steps = horizon_steps
        input_dim = cond_dim
        output_dim = action_dim * horizon_steps
        model = ResidualMLP if residual_style else MLP
        if fixed_std is None:
            # learning std: shared base, then one linear head each for the
            # mean and the log-variance
            self.mlp_base = model(
                [input_dim] + mlp_dims,
                activation_type=activation_type,
                out_activation_type=activation_type,
                use_layernorm=use_layernorm,
                use_layernorm_final=use_layernorm,
                dropout=dropout,
            )
            self.mlp_mean = MLP(
                mlp_dims[-1:] + [output_dim],
                out_activation_type="Identity",
            )
            self.mlp_logvar = MLP(
                mlp_dims[-1:] + [output_dim],
                out_activation_type="Identity",
            )
        else:
            # no separate head for mean and std
            self.mlp_mean = model(
                [input_dim] + mlp_dims + [output_dim],
                activation_type=activation_type,
                out_activation_type="Identity",
                use_layernorm=use_layernorm,
                dropout=dropout,
            )
            if learn_fixed_std:
                # initialize to fixed_std
                self.logvar = torch.nn.Parameter(
                    torch.log(
                        torch.tensor([fixed_std**2 for _ in range(action_dim)])
                    ),
                    requires_grad=True,
                )
        # Bounds are stored as non-trainable parameters so they follow the
        # module across devices and are saved in the state dict.
        self.logvar_min = torch.nn.Parameter(
            torch.log(torch.tensor(std_min**2)), requires_grad=False
        )
        self.logvar_max = torch.nn.Parameter(
            torch.log(torch.tensor(std_max**2)), requires_grad=False
        )
        self.use_fixed_std = fixed_std is not None
        self.fixed_std = fixed_std
        self.learn_fixed_std = learn_fixed_std
        self.tanh_output = tanh_output

    def forward(self, cond):
        """Compute the mean and std of the action distribution.

        Args:
            cond: dict with ``"state"`` whose leading dimension is the batch
                size B; remaining dimensions are flattened.

        Returns:
            Tuple ``(out_mean, out_scale)``, both of shape
            (B, horizon_steps * action_dim).
        """
        B = len(cond["state"])

        # flatten history (reshape also accepts non-contiguous inputs)
        state = cond["state"].reshape(B, -1)

        # mlp
        if hasattr(self, "mlp_base"):
            # shared trunk, only present when the std is predicted
            state = self.mlp_base(state)
        out_mean = self.mlp_mean(state)
        if self.tanh_output:
            out_mean = torch.tanh(out_mean)
        out_mean = out_mean.view(B, self.horizon_steps * self.action_dim)

        # `self.logvar` only exists when a fixed std was given, so the
        # learned-parameter branch is guarded by both flags.
        if self.use_fixed_std and self.learn_fixed_std:
            out_logvar = torch.clamp(self.logvar, self.logvar_min, self.logvar_max)
            out_scale = torch.exp(0.5 * out_logvar)
            out_scale = out_scale.view(1, self.action_dim)
            # same per-dimension std for every sample and horizon step
            out_scale = out_scale.repeat(B, self.horizon_steps)
        elif self.use_fixed_std:
            # ones_like already matches out_mean's device and dtype
            out_scale = torch.ones_like(out_mean) * self.fixed_std
        else:
            out_logvar = self.mlp_logvar(state).view(
                B, self.horizon_steps * self.action_dim
            )
            # squash to [-1, 1], then rescale to [logvar_min, logvar_max]
            out_logvar = torch.tanh(out_logvar)
            out_logvar = self.logvar_min + 0.5 * (self.logvar_max - self.logvar_min) * (
                out_logvar + 1
            )  # put back to full range
            out_scale = torch.exp(0.5 * out_logvar)
        return out_mean, out_scale
|