dppo/model/common/mlp.py
Allen Z. Ren dc8e0c9edc
v0.6 (#18)
* Sampling over both env and denoising steps in DPPO updates (#13)

* sample one from each chain

* full random sampling

* Add Proficient Human (PH) Configs and Pipeline (#16)

* fix missing cfg

* add ph config

* fix how terminated flags are added to buffer in ibrl

* add ph config

* offline calql for 1M gradient updates

* bug fix: number of calql online gradient steps is the number of new transitions collected

* add sample config for DPPO with ta=1

* Sampling over both env and denoising steps in DPPO updates (#13)

* sample one from each chain

* full random sampling

* fix diffusion loss when predicting initial noise

* fix dppo inds

* fix typo

* remove print statement

---------

Co-authored-by: Justin M. Lidard <jlidard@neuronic.cs.princeton.edu>
Co-authored-by: allenzren <allen.ren@princeton.edu>

* update robomimic configs

* better calql formulation

* optimize calql and ibrl training

* optimize data transfer in ppo agents

* add kitchen configs

* re-organize config folders, rerun calql and rlpd

* add scratch gym locomotion configs

* add kitchen installation dependencies

* use truncated for termination in furniture env

* update furniture and gym configs

* update README and dependencies with kitchen

* add url for new data and checkpoints

* update demo RL configs

* update batch sizes for furniture unet configs

* raise error about dropout in residual mlp

* fix observation bug in bc loss

---------

Co-authored-by: Justin Lidard <60638575+jlidard@users.noreply.github.com>
Co-authored-by: Justin M. Lidard <jlidard@neuronic.cs.princeton.edu>
2024-10-30 19:58:06 -04:00

155 lines
4.8 KiB
Python

"""
Implementation of Multi-layer Perceptron (MLP).
Residual model is taken from https://github.com/ALRhub/d3il/blob/main/agents/models/common/mlp.py
"""
import torch
from torch import nn
from collections import OrderedDict
import logging
# Name -> activation module registry used by the networks in this file
# (looked up via ``activation_type`` / ``out_activation_type``).
# A single instance per name is shared by every layer that selects it;
# none of these activations has learnable parameters.
activation_dict = nn.ModuleDict(
    dict(
        ReLU=nn.ReLU(),
        ELU=nn.ELU(),
        GELU=nn.GELU(),
        Tanh=nn.Tanh(),
        Mish=nn.Mish(),
        Identity=nn.Identity(),
        Softplus=nn.Softplus(),
    )
)
class MLP(nn.Module):
    """
    Feed-forward network whose layer ``i`` maps ``dim_list[i]`` to
    ``dim_list[i + 1]``.

    Each layer is an ``nn.Sequential`` of:
        Linear -> optional LayerNorm -> optional Dropout -> activation
    Hidden layers use ``activation_type``; the last layer uses
    ``out_activation_type``. LayerNorm / Dropout are added to the last layer
    only when ``use_layernorm_final`` / ``use_drop_final`` are set.

    Args:
        dim_list: layer widths, input width first and output width last.
        append_dim: width of an extra input concatenated (on the last axis)
            to the input of every layer listed in ``append_layers``.
        append_layers: indices of the layers that receive the extra input.
            Required when ``append_dim > 0``.
        activation_type: key into ``activation_dict`` for hidden layers.
        out_activation_type: key into ``activation_dict`` for the last layer.
        use_layernorm: add LayerNorm after each hidden Linear.
        use_layernorm_final: also add LayerNorm after the last Linear.
        dropout: dropout probability; 0 disables dropout.
        use_drop_final: also apply dropout in the last layer.
        verbose: log the constructed module list.

    Raises:
        ValueError: if ``append_dim > 0`` but ``append_layers`` is None.
        KeyError: if an activation name is not in ``activation_dict``.
    """

    def __init__(
        self,
        dim_list,
        append_dim=0,
        append_layers=None,
        activation_type="Tanh",
        out_activation_type="Identity",
        use_layernorm=False,
        use_layernorm_final=False,
        dropout=0,
        use_drop_final=False,
        verbose=False,
    ):
        super(MLP, self).__init__()
        # Fail with a clear message instead of a "NoneType is not iterable"
        # TypeError from the membership test below.
        if append_dim > 0 and append_layers is None:
            raise ValueError("append_layers must be provided when append_dim > 0")
        # Use nn.ModuleList (not a plain Python list) so the sub-modules are
        # registered with this module (parameters, .to(), state_dict).
        self.moduleList = nn.ModuleList()
        self.append_layers = append_layers
        num_layer = len(dim_list) - 1
        for idx in range(num_layer):
            is_last = idx == num_layer - 1
            i_dim = dim_list[idx]
            o_dim = dim_list[idx + 1]
            if append_dim > 0 and idx in append_layers:
                # this layer also consumes the appended input in forward()
                i_dim += append_dim
            layers = [("linear_1", nn.Linear(i_dim, o_dim))]
            if use_layernorm and (not is_last or use_layernorm_final):
                layers.append(("norm_1", nn.LayerNorm(o_dim)))
            if dropout > 0 and (not is_last or use_drop_final):
                layers.append(("dropout_1", nn.Dropout(dropout)))
            act_name = out_activation_type if is_last else activation_type
            layers.append(("act_1", activation_dict[act_name]))
            # Named sub-modules keep state_dict keys stable (e.g. "linear_1").
            self.moduleList.append(nn.Sequential(OrderedDict(layers)))
        if verbose:
            logging.info(self.moduleList)

    def forward(self, x, append=None):
        """
        Run ``x`` through all layers.

        If ``append`` is given, it is concatenated to the input of each layer
        whose index is in ``append_layers`` (on the last axis).

        Raises:
            ValueError: if ``append`` is given but the network was built
                without ``append_layers``.
        """
        if append is not None and self.append_layers is None:
            raise ValueError(
                "append was given but this MLP was built without append_layers"
            )
        for layer_ind, m in enumerate(self.moduleList):
            if append is not None and layer_ind in self.append_layers:
                x = torch.cat((x, append), dim=-1)
            x = m(x)
        return x
class ResidualMLP(nn.Module):
    """
    Simple multi-layer perceptron with residual connections, used for
    benchmarking the performance of different networks. The residual layers
    are based on the IBC paper implementation, which uses two-layer residual
    blocks with pre-activation, with or without dropout and normalization.

    Architecture built from ``dim_list``:
        Linear(dim_list[0] -> dim_list[1])
        -> (len(dim_list) - 3) // 2 residual blocks of width dim_list[1]
        -> Linear(dim_list[1] -> dim_list[-1])
        -> optional LayerNorm(dim_list[-1])   (if use_layernorm_final)
        -> activation_dict[out_activation_type]

    Only dim_list[0], dim_list[1] and dim_list[-1] set layer widths; the
    intermediate entries only determine the depth, so every hidden layer
    has width dim_list[1].

    Note: the residual blocks (TwoLayerPreActivationResNetLinear) raise
    NotImplementedError when dropout > 0.

    Raises:
        ValueError: if len(dim_list) - 3 is not even (each residual block
            consumes two hidden layers).
    """

    def __init__(
        self,
        dim_list,
        activation_type="Mish",
        out_activation_type="Identity",
        use_layernorm=False,
        use_layernorm_final=False,
        dropout=0,
    ):
        super(ResidualMLP, self).__init__()
        hidden_dim = dim_list[1]
        num_hidden_layers = len(dim_list) - 3
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if num_hidden_layers % 2 != 0:
            raise ValueError(
                "ResidualMLP needs an even number of hidden layers "
                f"(len(dim_list) - 3), got {num_hidden_layers}"
            )
        self.layers = nn.ModuleList([nn.Linear(dim_list[0], hidden_dim)])
        # Each residual block covers two hidden layers.
        self.layers.extend(
            [
                TwoLayerPreActivationResNetLinear(
                    hidden_dim=hidden_dim,
                    activation_type=activation_type,
                    use_layernorm=use_layernorm,
                    dropout=dropout,
                )
                for _ in range(num_hidden_layers // 2)
            ]
        )
        self.layers.append(nn.Linear(hidden_dim, dim_list[-1]))
        if use_layernorm_final:
            self.layers.append(nn.LayerNorm(dim_list[-1]))
        self.layers.append(activation_dict[out_activation_type])

    def forward(self, x):
        """Apply all layers in sequence and return the output."""
        for layer in self.layers:
            x = layer(x)
        return x
class TwoLayerPreActivationResNetLinear(nn.Module):
    """
    Residual block made of two pre-activation linear layers of equal width.

    Computes ``x + l2(act(norm2(l1(act(norm1(x))))))``, where ``norm1`` and
    ``norm2`` are ``LayerNorm(hidden_dim, eps=1e-6)`` and are created and
    applied only when ``use_layernorm`` is True (otherwise skipped).

    Args:
        hidden_dim: input and output width of both linear layers.
        activation_type: key into ``activation_dict``; the same activation
            module is applied before each linear layer.
        use_layernorm: whether to create and apply norm1 / norm2.
        dropout: must be 0; a positive value raises NotImplementedError.
    """

    def __init__(
        self,
        hidden_dim,
        activation_type="Mish",
        use_layernorm=False,
        dropout=0,
    ):
        super().__init__()
        # l1 is created (and registered) before l2, matching parameter order
        # and the order of random draws used for initialization.
        self.l1, self.l2 = (nn.Linear(hidden_dim, hidden_dim) for _ in range(2))
        self.act = activation_dict[activation_type]
        if use_layernorm:
            self.norm1, self.norm2 = (
                nn.LayerNorm(hidden_dim, eps=1e-06) for _ in range(2)
            )
        if dropout > 0:
            raise NotImplementedError("Dropout not implemented for residual MLP!")

    def forward(self, x):
        """Return ``x`` plus the two-layer pre-activation branch applied to ``x``."""
        out = x
        # Each stage: optional norm, then activation, then linear.
        for norm_name, linear in (("norm1", self.l1), ("norm2", self.l2)):
            if hasattr(self, norm_name):
                out = getattr(self, norm_name)(out)
            out = linear(self.act(out))
        return out + x