dppo/model/common/mlp.py
Allen Z. Ren e0842e71dc
v0.5 to main (#10)
* v0.5 (#9)

* update idql configs

* update awr configs

* update dipo configs

* update qsm configs

* update dqm configs

* update project version to 0.5.0
2024-10-07 16:35:13 -04:00

150 lines
4.6 KiB
Python

"""
Implementation of Multi-layer Perceptron (MLP).
Residual model is taken from https://github.com/ALRhub/d3il/blob/main/agents/models/common/mlp.py
"""
import torch
from torch import nn
from collections import OrderedDict
import logging
# Registry mapping an activation name to one default-constructed module
# instance. A lookup returns that same instance every time.
activation_dict = nn.ModuleDict(
    {
        name: getattr(nn, name)()
        for name in (
            "ReLU",
            "ELU",
            "GELU",
            "Tanh",
            "Mish",
            "Identity",
            "Softplus",
        )
    }
)
class MLP(nn.Module):
    """
    Fully connected network built from a list of layer widths.

    Layer ``idx`` is an ``nn.Sequential`` holding, in order: a linear map
    from ``dim_list[idx]`` to ``dim_list[idx + 1]`` (named ``linear_1``), an
    optional ``nn.LayerNorm`` (``norm_1``), an optional ``nn.Dropout``
    (``dropout_1``) and an activation taken from ``activation_dict``
    (``act_1``).

    Args:
        dim_list: layer widths; ``len(dim_list) - 1`` layers are created.
        append_dim: extra input width added to every layer whose index is
            in ``append_layers``. Ignored when it is not positive.
        append_layers: indices of the layers that receive the ``append``
            tensor in ``forward``. ``None`` means no layer does.
        activation_type: ``activation_dict`` key used for all but the last
            layer.
        out_activation_type: ``activation_dict`` key used for the last layer.
        use_layernorm: add LayerNorm after the linear map of non-final
            layers.
        use_layernorm_final: with ``use_layernorm``, also normalize the last
            layer.
        dropout: dropout probability; dropout is added only when it is
            positive.
        use_drop_final: with ``dropout > 0``, also apply dropout in the last
            layer.
        verbose: log the constructed module list with ``logging.info``.
    """

    def __init__(
        self,
        dim_list,
        append_dim=0,
        append_layers=None,
        activation_type="Tanh",
        out_activation_type="Identity",
        use_layernorm=False,
        use_layernorm_final=False,
        dropout=0,
        use_drop_final=False,
        verbose=False,
    ):
        super(MLP, self).__init__()

        # A plain Python list would not register the layers as submodules
        # (their parameters would be invisible to the optimizer and to
        # `state_dict`), so the layers are kept in an `nn.ModuleList`.
        self.moduleList = nn.ModuleList()
        self.append_layers = append_layers
        # `None` is the default and means "append nowhere"; a membership
        # test against `None` itself would raise TypeError.
        append_at = () if append_layers is None else append_layers

        num_layer = len(dim_list) - 1
        for idx in range(num_layer):
            is_last = idx == num_layer - 1
            i_dim = dim_list[idx]
            o_dim = dim_list[idx + 1]
            if append_dim > 0 and idx in append_at:
                # This layer's input is widened by the appended features.
                i_dim += append_dim

            # Assemble the named components of this layer.
            layers = [("linear_1", nn.Linear(i_dim, o_dim))]
            if use_layernorm and (not is_last or use_layernorm_final):
                layers.append(("norm_1", nn.LayerNorm(o_dim)))
            if dropout > 0 and (not is_last or use_drop_final):
                layers.append(("dropout_1", nn.Dropout(dropout)))

            # The last layer uses the output activation, all others the
            # hidden activation.
            act_key = out_activation_type if is_last else activation_type
            layers.append(("act_1", activation_dict[act_key]))

            self.moduleList.append(nn.Sequential(OrderedDict(layers)))

        if verbose:
            logging.info(self.moduleList)

    def forward(self, x, append=None):
        """
        Run ``x`` through every layer in order.

        Args:
            x: input tensor.
            append: optional tensor concatenated to the running activation
                along the last dimension right before each layer listed in
                ``append_layers``.

        Returns:
            The output of the last layer.
        """
        # Treat a missing `append_layers` as "no layer takes `append`".
        append_at = () if self.append_layers is None else self.append_layers
        for layer_ind, m in enumerate(self.moduleList):
            if append is not None and layer_ind in append_at:
                x = torch.cat((x, append), dim=-1)
            x = m(x)
        return x
class ResidualMLP(nn.Module):
    """
    Simple multi-layer perceptron with residual connections, used for
    benchmarking the performance of different networks. The residual layers
    are based on the IBC paper implementation, which uses 2 residual layers
    with pre-activation, with or without dropout and normalization.

    The network is: a linear map from ``dim_list[0]`` to ``dim_list[1]``,
    then ``(len(dim_list) - 3) // 2`` residual blocks of width
    ``dim_list[1]``, then a linear map to ``dim_list[-1]``, an optional
    LayerNorm and the output activation. Only the first, second and last
    entries of ``dim_list`` are read as widths.
    """

    def __init__(
        self,
        dim_list,
        activation_type="Mish",
        out_activation_type="Identity",
        use_layernorm=False,
        use_layernorm_final=False,
    ):
        super(ResidualMLP, self).__init__()
        hidden_dim = dim_list[1]
        # Each residual block accounts for two hidden layers, so the hidden
        # layer count has to be even.
        num_hidden_layers = len(dim_list) - 3
        assert num_hidden_layers % 2 == 0
        num_blocks = num_hidden_layers // 2

        # Build the stack in forward order: input projection, residual
        # blocks, output projection, optional norm, output activation.
        stack = [nn.Linear(dim_list[0], hidden_dim)]
        for _ in range(num_blocks):
            stack.append(
                TwoLayerPreActivationResNetLinear(
                    hidden_dim=hidden_dim,
                    activation_type=activation_type,
                    use_layernorm=use_layernorm,
                )
            )
        stack.append(nn.Linear(hidden_dim, dim_list[-1]))
        if use_layernorm_final:
            stack.append(nn.LayerNorm(dim_list[-1]))
        stack.append(activation_dict[out_activation_type])

        self.layers = nn.ModuleList(stack)

    def forward(self, x):
        """Apply every module in ``self.layers`` to ``x`` in order."""
        for layer in self.layers:
            x = layer(x)
        return x
class TwoLayerPreActivationResNetLinear(nn.Module):
    """
    Residual block of two equal-width linear layers with pre-activation.

    Each of the two stages applies an optional LayerNorm, then the
    activation, then a linear map; the block input is added to the result.
    """

    def __init__(
        self,
        hidden_dim,
        activation_type="Mish",
        use_layernorm=False,
    ):
        super().__init__()
        self.l1 = nn.Linear(hidden_dim, hidden_dim)
        self.l2 = nn.Linear(hidden_dim, hidden_dim)
        self.act = activation_dict[activation_type]
        if not use_layernorm:
            # Without normalization the norm attributes are left undefined;
            # `forward` detects their absence.
            return
        self.norm1 = nn.LayerNorm(hidden_dim, eps=1e-06)
        self.norm2 = nn.LayerNorm(hidden_dim, eps=1e-06)

    def forward(self, x):
        """Return ``x`` plus the output of the two pre-activation stages."""
        residual = x
        stages = (
            (getattr(self, "norm1", None), self.l1),
            (getattr(self, "norm2", None), self.l2),
        )
        for norm, linear in stages:
            if norm is not None:
                x = norm(x)
            x = linear(self.act(x))
        return x + residual