* v0.5 (#9) * update idql configs * update awr configs * update dipo configs * update qsm configs * update dqm configs * update project version to 0.5.0
150 lines
4.6 KiB
Python
150 lines
4.6 KiB
Python
"""
Implementation of Multi-layer Perceptron (MLP).

Residual model is taken from https://github.com/ALRhub/d3il/blob/main/agents/models/common/mlp.py

"""
|
|
|
|
import torch
|
|
from torch import nn
|
|
from collections import OrderedDict
|
|
import logging
|
|
|
|
|
|
# Registry of activation modules, looked up by the string names passed as
# ``activation_type`` / ``out_activation_type`` in the networks below.
# Each key is the class name of the corresponding ``torch.nn`` activation
# (e.g. "Tanh" -> nn.Tanh()). Every lookup returns the same parameter-free
# instance, so networks share these modules rather than owning copies.
activation_dict = nn.ModuleDict(
    {
        act_cls.__name__: act_cls()
        for act_cls in (
            nn.ReLU,
            nn.ELU,
            nn.GELU,
            nn.Tanh,
            nn.Mish,
            nn.Identity,
            nn.Softplus,
        )
    }
)
|
|
|
|
|
|
class MLP(nn.Module):
    """Feed-forward network built from a list of layer widths.

    Layer ``i`` maps ``dim_list[i]`` -> ``dim_list[i + 1]`` and is an
    ``nn.Sequential`` of, in order: ``linear_1`` (Linear), optional ``norm_1``
    (LayerNorm), optional ``dropout_1`` (Dropout) and ``act_1`` (activation).
    Hidden layers use ``activation_type``; the last layer uses
    ``out_activation_type``.

    Args:
        dim_list: Layer widths ``[in_dim, hidden_1, ..., out_dim]``; the
            network has ``len(dim_list) - 1`` layers.
        append_dim: Width of an extra tensor concatenated (along the last
            dimension) to the input of every layer listed in
            ``append_layers``. 0 disables appending.
        append_layers: Indices of the layers that receive ``append`` in
            ``forward``. Required when ``append_dim > 0``.
        activation_type: Key into ``activation_dict`` for hidden layers.
        out_activation_type: Key into ``activation_dict`` for the last layer.
        use_layernorm: Add a LayerNorm after each hidden Linear.
        use_layernorm_final: Also add a LayerNorm after the last Linear (only
            takes effect when ``use_layernorm`` is set).
        dropout: Dropout probability; 0 disables dropout.
        use_drop_final: Also apply dropout on the last layer (only takes
            effect when ``dropout > 0``).
        verbose: Log the constructed module list via ``logging.info``.

    Raises:
        ValueError: if ``append_dim > 0`` but ``append_layers`` is None.
    """

    def __init__(
        self,
        dim_list,
        append_dim=0,
        append_layers=None,
        activation_type="Tanh",
        out_activation_type="Identity",
        use_layernorm=False,
        use_layernorm_final=False,
        dropout=0,
        use_drop_final=False,
        verbose=False,
    ):
        super().__init__()
        if append_dim > 0 and append_layers is None:
            # Fail early with a clear message rather than a TypeError from
            # `idx in None` inside the construction loop.
            raise ValueError(
                "`append_layers` must be provided when `append_dim > 0`."
            )

        # A plain Python list would not register the submodules (their
        # parameters would be invisible to .parameters(), .to(), state_dict,
        # ...), so an nn.ModuleList is required.
        self.moduleList = nn.ModuleList()
        self.append_layers = append_layers
        num_layer = len(dim_list) - 1
        for idx in range(num_layer):
            is_last = idx == num_layer - 1
            i_dim = dim_list[idx]
            o_dim = dim_list[idx + 1]
            if append_dim > 0 and idx in append_layers:
                # This layer also consumes the appended tensor.
                i_dim += append_dim

            layers = [("linear_1", nn.Linear(i_dim, o_dim))]
            if use_layernorm and (not is_last or use_layernorm_final):
                layers.append(("norm_1", nn.LayerNorm(o_dim)))
            if dropout > 0 and (not is_last or use_drop_final):
                layers.append(("dropout_1", nn.Dropout(dropout)))
            act_name = out_activation_type if is_last else activation_type
            layers.append(("act_1", activation_dict[act_name]))

            self.moduleList.append(nn.Sequential(OrderedDict(layers)))
        if verbose:
            logging.info(self.moduleList)

    def forward(self, x, append=None):
        """Run ``x`` through every layer in order.

        If ``append`` is given, it is concatenated to the running activation
        along the last dimension right before each layer whose index is in
        ``append_layers``.

        Raises:
            ValueError: if ``append`` is given but the network was built
                without ``append_layers``.
        """
        if append is not None and self.append_layers is None:
            raise ValueError(
                "`append` was passed, but this MLP was built without `append_layers`."
            )
        for layer_ind, m in enumerate(self.moduleList):
            if append is not None and layer_ind in self.append_layers:
                x = torch.cat((x, append), dim=-1)
            x = m(x)
        return x
|
|
|
|
|
|
class ResidualMLP(nn.Module):
|
|
"""
|
|
Simple multi layer perceptron network with residual connections for
|
|
benchmarking the performance of different networks. The resiudal layers
|
|
are based on the IBC paper implementation, which uses 2 residual lalyers
|
|
with pre-actication with or without dropout and normalization.
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
dim_list,
|
|
activation_type="Mish",
|
|
out_activation_type="Identity",
|
|
use_layernorm=False,
|
|
use_layernorm_final=False,
|
|
):
|
|
super(ResidualMLP, self).__init__()
|
|
hidden_dim = dim_list[1]
|
|
num_hidden_layers = len(dim_list) - 3
|
|
assert num_hidden_layers % 2 == 0
|
|
self.layers = nn.ModuleList([nn.Linear(dim_list[0], hidden_dim)])
|
|
self.layers.extend(
|
|
[
|
|
TwoLayerPreActivationResNetLinear(
|
|
hidden_dim=hidden_dim,
|
|
activation_type=activation_type,
|
|
use_layernorm=use_layernorm,
|
|
)
|
|
for _ in range(1, num_hidden_layers, 2)
|
|
]
|
|
)
|
|
self.layers.append(nn.Linear(hidden_dim, dim_list[-1]))
|
|
if use_layernorm_final:
|
|
self.layers.append(nn.LayerNorm(dim_list[-1]))
|
|
self.layers.append(activation_dict[out_activation_type])
|
|
|
|
def forward(self, x):
|
|
for _, layer in enumerate(self.layers):
|
|
x = layer(x)
|
|
return x
|
|
|
|
|
|
class TwoLayerPreActivationResNetLinear(nn.Module):
    """Residual block of two pre-activated linear layers plus a skip path.

    Computes ``x + l2(act(n2(l1(act(n1(x))))))``. ``n1``/``n2`` are
    LayerNorms (eps=1e-6) when ``use_layernorm`` is set and are skipped
    otherwise. Both linear layers map ``hidden_dim`` -> ``hidden_dim``, so
    the output width equals the input width.

    Args:
        hidden_dim: Feature width of the block's input and output.
        activation_type: Key into ``activation_dict`` for the activation
            applied before each linear layer.
        use_layernorm: Whether to create ``norm1``/``norm2``.
    """

    def __init__(
        self,
        hidden_dim,
        activation_type="Mish",
        use_layernorm=False,
    ):
        super().__init__()
        # l1 is created before l2; their random initialisation follows this order.
        self.l1 = nn.Linear(hidden_dim, hidden_dim)
        self.l2 = nn.Linear(hidden_dim, hidden_dim)
        self.act = activation_dict[activation_type]
        if not use_layernorm:
            # Without layer norm the block never defines norm1/norm2;
            # forward() detects their absence and skips them.
            return
        self.norm1 = nn.LayerNorm(hidden_dim, eps=1e-06)
        self.norm2 = nn.LayerNorm(hidden_dim, eps=1e-06)

    def forward(self, x):
        skip = x
        # Two identical stages: (optional norm) -> activation -> linear.
        for norm_attr, linear in (("norm1", self.l1), ("norm2", self.l2)):
            norm = getattr(self, norm_attr, None)
            if norm is not None:
                x = norm(x)
            x = linear(self.act(x))
        return x + skip
|