dppo/model/common/mlp_gmm.py
Allen Z. Ren e0842e71dc
v0.5 to main (#10)
* v0.5 (#9)

* update idql configs

* update awr configs

* update dipo configs

* update qsm configs

* update dqm configs

* update project version to 0.5.0
2024-10-07 16:35:13 -04:00

111 lines
3.5 KiB
Python

"""
MLP models for GMM policy.
"""
import torch
import torch.nn as nn
from model.common.mlp import MLP, ResidualMLP
class GMM_MLP(nn.Module):
    """MLP head parameterizing a Gaussian-mixture (GMM) policy.

    Given a flattened state/condition vector, produces for each of
    ``num_modes`` mixture components:
      * a mean over the flattened action chunk, tanh-squashed into [-1, 1],
      * a scale (standard deviation), either state-dependent, a single
        learned value per (action dim, mode), or a fixed constant,
      * unnormalized mixture weights (logits).

    Args:
        action_dim: dimension of a single action.
        horizon_steps: number of action steps predicted per forward pass.
        cond_dim: flattened dimension of the conditioning input (state history).
        mlp_dims: hidden-layer widths shared by all three heads.
        num_modes: number of GMM components.
        activation_type: hidden activation for the MLPs.
        residual_style: use ``ResidualMLP`` instead of plain ``MLP``.
        use_layernorm: apply layer norm inside the MLPs.
        fixed_std: if set, std is this constant (unless ``learn_fixed_std``).
        learn_fixed_std: if set (with ``fixed_std``), learn one std per
            action dim and mode, initialized at ``fixed_std``.
        std_min: lower clamp for the std.
        std_max: upper clamp for the std.
    """

    def __init__(
        self,
        action_dim,
        horizon_steps,
        cond_dim=None,
        mlp_dims=(256, 256, 256),  # tuple default: avoids shared mutable default
        num_modes=5,
        activation_type="Mish",
        residual_style=False,
        use_layernorm=False,
        fixed_std=None,
        learn_fixed_std=False,
        std_min=0.01,
        std_max=1,
    ):
        super().__init__()
        self.action_dim = action_dim
        self.horizon_steps = horizon_steps
        self.num_modes = num_modes

        mlp_dims = list(mlp_dims)  # accept any sequence of widths
        input_dim = cond_dim
        output_dim = action_dim * horizon_steps * num_modes
        model = ResidualMLP if residual_style else MLP

        # Mean head: one flattened action chunk per mixture mode.
        self.mlp_mean = model(
            [input_dim] + mlp_dims + [output_dim],
            activation_type=activation_type,
            out_activation_type="Identity",
            use_layernorm=use_layernorm,
        )

        if fixed_std is None:
            # State-dependent log-variance head, same output layout as the mean.
            self.mlp_logvar = model(
                [input_dim] + mlp_dims + [output_dim],
                activation_type=activation_type,
                out_activation_type="Identity",
                use_layernorm=use_layernorm,
            )
        elif learn_fixed_std:
            # One learnable log-variance per (action dim, mode),
            # initialized so that std == fixed_std.
            self.logvar = torch.nn.Parameter(
                torch.log(
                    torch.tensor(
                        [fixed_std**2 for _ in range(action_dim * num_modes)]
                    )
                ),
                requires_grad=True,
            )
        # Clamp bounds must exist unconditionally: forward() clamps with them
        # both in the learned-fixed-std branch and in the state-dependent
        # (fixed_std is None) branch. Registered as non-trainable parameters
        # so they move with the module across devices.
        self.logvar_min = torch.nn.Parameter(
            torch.log(torch.tensor(std_min**2)), requires_grad=False
        )
        self.logvar_max = torch.nn.Parameter(
            torch.log(torch.tensor(std_max**2)), requires_grad=False
        )
        self.use_fixed_std = fixed_std is not None
        self.fixed_std = fixed_std
        self.learn_fixed_std = learn_fixed_std

        # Mixture-weight head: one logit per mode.
        self.mlp_weights = model(
            [input_dim] + mlp_dims + [num_modes],
            activation_type=activation_type,
            out_activation_type="Identity",
            use_layernorm=use_layernorm,
        )

    def forward(self, cond):
        """Compute GMM parameters from the conditioning dict.

        Args:
            cond: dict with key ``"state"`` holding a tensor whose first
                dimension is the batch; any history dims are flattened.

        Returns:
            Tuple ``(means, scales, weight_logits)`` with shapes
            ``(B, num_modes, horizon_steps * action_dim)`` for means and
            scales, and ``(B, num_modes)`` for the weight logits.
        """
        B = len(cond["state"])
        state = cond["state"].view(B, -1)  # flatten observation history

        # tanh squashes the means into [-1, 1]
        out_mean = torch.tanh(self.mlp_mean(state)).view(
            B, self.num_modes, self.horizon_steps * self.action_dim
        )

        if self.learn_fixed_std:
            out_logvar = torch.clamp(self.logvar, self.logvar_min, self.logvar_max)
            out_scale = torch.exp(0.5 * out_logvar).view(
                1, self.num_modes, self.action_dim
            )
            # Tile the per-action std across the horizon and the batch.
            out_scale = out_scale.repeat(B, 1, self.horizon_steps)
        elif self.use_fixed_std:
            # full_like inherits device/dtype from out_mean — no extra .to() needed
            out_scale = torch.full_like(out_mean, self.fixed_std)
        else:
            out_logvar = self.mlp_logvar(state).view(
                B, self.num_modes, self.horizon_steps * self.action_dim
            )
            out_logvar = torch.clamp(out_logvar, self.logvar_min, self.logvar_max)
            out_scale = torch.exp(0.5 * out_logvar)

        out_weights = self.mlp_weights(state).view(B, self.num_modes)
        return out_mean, out_scale, out_weights