dppo/model/common/mlp.py
2024-09-03 21:03:27 -04:00

161 lines
4.9 KiB
Python

"""
Implementation of Multi-layer Perceptron (MLP).
Residual model is taken from https://github.com/ALRhub/d3il/blob/main/agents/models/common/mlp.py
"""
import torch
from torch import nn
from torch.nn.utils import spectral_norm
from collections import OrderedDict
import logging
# Name -> activation-module registry used by the models in this file.
# Each lookup returns the same module instance, so every model that selects
# an activation by name shares that (parameter-free) module.
activation_dict = nn.ModuleDict(
    [
        ("ReLU", nn.ReLU()),
        ("ELU", nn.ELU()),
        ("GELU", nn.GELU()),
        ("Tanh", nn.Tanh()),
        ("Mish", nn.Mish()),
        ("Identity", nn.Identity()),
        ("Softplus", nn.Softplus()),
    ]
)
class MLP(nn.Module):
    """
    Feed-forward MLP built from a list of layer widths.

    Layer ``i`` maps ``dim_list[i]`` features to ``dim_list[i + 1]`` features.
    Every hidden layer is ``Linear -> [LayerNorm] -> activation``; the last
    layer is ``Linear -> out_activation`` (never LayerNorm). Optionally, an
    extra tensor can be concatenated to the input of selected layers.

    Args:
        dim_list: layer widths, ``[in_dim, hidden_1, ..., out_dim]``.
        append_dim: size of the extra tensor concatenated (along the last
            dim) to the input of every layer listed in ``append_layers``;
            0 disables the feature.
        append_layers: indices of the layers that receive the extra tensor.
            Required when ``append_dim > 0``.
        activation_type: key into ``activation_dict`` for hidden layers.
        out_activation_type: key into ``activation_dict`` for the last layer.
        use_layernorm: insert ``nn.LayerNorm(o_dim)`` after each hidden Linear.
        use_spectralnorm: wrap every Linear with ``spectral_norm``.
        verbose: log the constructed module list via ``logging.info``.

    Raises:
        ValueError: if ``append_dim > 0`` but ``append_layers`` is None.
    """

    def __init__(
        self,
        dim_list,
        append_dim=0,
        append_layers=None,
        activation_type="Tanh",
        out_activation_type="Identity",
        use_layernorm=False,
        use_spectralnorm=False,
        verbose=False,
    ):
        super(MLP, self).__init__()
        if append_dim > 0 and append_layers is None:
            # Previously surfaced as "argument of type 'NoneType' is not iterable".
            raise ValueError(
                "MLP: append_layers must be provided when append_dim > 0"
            )
        # Must be nn.ModuleList (not a Python list) so the sub-modules are
        # registered with this module (parameters, .to(), state_dict, ...).
        self.moduleList = nn.ModuleList()
        self.append_layers = append_layers
        num_layer = len(dim_list) - 1
        for idx in range(num_layer):
            i_dim = dim_list[idx]
            o_dim = dim_list[idx + 1]
            if append_dim > 0 and idx in append_layers:
                i_dim += append_dim
            linear_layer = nn.Linear(i_dim, o_dim)
            if use_spectralnorm:
                linear_layer = spectral_norm(linear_layer)

            is_output_layer = idx == num_layer - 1
            # Sub-module names ("linear_1", "norm_1", "act_1") are kept stable
            # because they appear in state_dict keys of saved checkpoints.
            components = [("linear_1", linear_layer)]
            if use_layernorm and not is_output_layer:
                components.append(("norm_1", nn.LayerNorm(o_dim)))
            act_name = out_activation_type if is_output_layer else activation_type
            components.append(("act_1", activation_dict[act_name]))
            self.moduleList.append(nn.Sequential(OrderedDict(components)))
        if verbose:
            logging.info(self.moduleList)

    def forward(self, x, append=None):
        """
        Apply the layers in order.

        Args:
            x: input tensor whose last dimension is ``dim_list[0]``.
            append: optional tensor concatenated to the running activation
                along the last dim right before every layer whose index is in
                ``append_layers``. All its other dimensions must match for
                ``torch.cat``.

        Returns:
            The output of the last layer.

        Raises:
            ValueError: if ``append`` is given but the MLP was built without
                ``append_layers``.
        """
        if append is not None and self.append_layers is None:
            raise ValueError(
                "MLP.forward: got `append` but the MLP was built without append_layers"
            )
        for layer_ind, m in enumerate(self.moduleList):
            if append is not None and layer_ind in self.append_layers:
                x = torch.cat((x, append), dim=-1)
            x = m(x)
        return x
class ResidualMLP(nn.Module):
    """
    Simple multi-layer perceptron network with residual connections for
    benchmarking the performance of different networks. The residual layers
    are based on the IBC paper implementation, which uses 2-layer residual
    blocks with pre-activation, with or without dropout and normalization
    (this implementation has no dropout; LayerNorm is optional).

    Layout built here: ``Linear(dim_list[0], dim_list[1])``, then
    ``(len(dim_list) - 3) // 2`` ``TwoLayerPreActivationResNetLinear`` blocks of
    width ``dim_list[1]``, then ``Linear(dim_list[1], dim_list[-1])`` followed
    by the output activation. Entries strictly between ``dim_list[1]`` and
    ``dim_list[-1]`` are not read; they only determine the number of blocks.

    Args:
        dim_list: ``[in_dim, hidden_dim, ..., hidden_dim, out_dim]``;
            ``len(dim_list) - 3`` must be a non-negative even number.
        activation_type: key into ``activation_dict`` used inside the
            residual blocks.
        out_activation_type: key into ``activation_dict`` applied last.
        use_layernorm: enable LayerNorm inside the residual blocks.

    Raises:
        ValueError: if ``len(dim_list) - 3`` is negative or odd.
    """

    def __init__(
        self,
        dim_list,
        activation_type="Mish",
        out_activation_type="Identity",
        use_layernorm=False,
    ):
        super(ResidualMLP, self).__init__()
        num_hidden_layers = len(dim_list) - 3
        # Each residual block contains two hidden linear layers, so the number
        # of hidden layers must be even. Explicit check instead of `assert`
        # so validation survives `python -O`.
        if num_hidden_layers < 0 or num_hidden_layers % 2 != 0:
            raise ValueError(
                "ResidualMLP: len(dim_list) - 3 must be a non-negative even "
                f"number, got len(dim_list)={len(dim_list)}"
            )
        hidden_dim = dim_list[1]
        self.layers = nn.ModuleList([nn.Linear(dim_list[0], hidden_dim)])
        self.layers.extend(
            [
                TwoLayerPreActivationResNetLinear(
                    hidden_dim=hidden_dim,
                    activation_type=activation_type,
                    use_layernorm=use_layernorm,
                )
                for _ in range(num_hidden_layers // 2)
            ]
        )
        self.layers.append(nn.Linear(hidden_dim, dim_list[-1]))
        self.layers.append(activation_dict[out_activation_type])

    def forward(self, x):
        """Apply all layers sequentially and return the result."""
        for layer in self.layers:
            x = layer(x)
        return x
class TwoLayerPreActivationResNetLinear(nn.Module):
    """
    Residual block of two square linear layers in pre-activation order.

    ``forward`` returns ``l2(act(n2(l1(act(n1(x)))))) + x`` where ``n1``/``n2``
    are the LayerNorms ``norm1``/``norm2`` when the block was built with
    ``use_layernorm=True`` and are skipped otherwise.

    Args:
        hidden_dim: feature size of the input, the hidden layer and the output.
        activation_type: key into ``activation_dict`` selecting the activation.
        use_layernorm: if True, register ``norm1``/``norm2`` (eps=1e-6).
    """

    def __init__(self, hidden_dim, activation_type="Mish", use_layernorm=False):
        super().__init__()
        self.l1 = nn.Linear(hidden_dim, hidden_dim)
        self.l2 = nn.Linear(hidden_dim, hidden_dim)
        self.act = activation_dict[activation_type]
        # Norm modules are only registered when requested; forward() detects
        # them by attribute presence, so nothing else is registered otherwise.
        if not use_layernorm:
            return
        self.norm1 = nn.LayerNorm(hidden_dim, eps=1e-06)
        self.norm2 = nn.LayerNorm(hidden_dim, eps=1e-06)

    def forward(self, x):
        residual = x
        for norm_name, linear in (("norm1", self.l1), ("norm2", self.l2)):
            norm = getattr(self, norm_name, None)
            if norm is not None:
                x = norm(x)
            x = linear(self.act(x))
        return x + residual