# v0.5 (#9): updated idql, awr, dipo, qsm, and dqm configs; project version bumped to 0.5.0.
"""
|
|
MLP models for GMM policy.
|
|
|
|
"""

import torch
import torch.nn as nn

from model.common.mlp import MLP, ResidualMLP


class GMM_MLP(nn.Module):
    """MLP heads for a Gaussian-mixture (GMM) policy.

    From a flattened state condition, predicts per-mode action means
    (tanh-squashed into [-1, 1]), per-mode scales (standard deviations),
    and unnormalized mode-weight logits.
    """

    def __init__(
        self,
        action_dim,
        horizon_steps,
        cond_dim=None,
        mlp_dims=(256, 256, 256),
        num_modes=5,
        activation_type="Mish",
        residual_style=False,
        use_layernorm=False,
        fixed_std=None,
        learn_fixed_std=False,
        std_min=0.01,
        std_max=1,
    ):
        """
        Args:
            action_dim: dimension of a single action.
            horizon_steps: number of action steps predicted per forward pass.
            cond_dim: flattened dimension of the state condition (MLP input).
            mlp_dims: hidden-layer sizes shared by all three heads.
            num_modes: number of mixture components.
            activation_type: hidden activation used inside the MLPs.
            residual_style: use ResidualMLP instead of plain MLP.
            use_layernorm: apply layer normalization inside the MLPs.
            fixed_std: if set, use this constant std instead of predicting it.
            learn_fixed_std: if True (and fixed_std is set), learn a std per
                action dim and mode, initialized at fixed_std.
            std_min: lower clamp for predicted/learned std.
            std_max: upper clamp for predicted/learned std.
        """
        super().__init__()
        self.action_dim = action_dim
        self.horizon_steps = horizon_steps
        input_dim = cond_dim
        output_dim = action_dim * horizon_steps * num_modes
        self.num_modes = num_modes
        model = ResidualMLP if residual_style else MLP

        # mean head
        self.mlp_mean = model(
            [input_dim, *mlp_dims, output_dim],
            activation_type=activation_type,
            out_activation_type="Identity",
            use_layernorm=use_layernorm,
        )

        # Log-variance clamp bounds are needed by BOTH the learned-fixed-std
        # path and the predicted-logvar path in forward(), so define them
        # unconditionally. (Previously they were only created when
        # learn_fixed_std was True, so forward() raised AttributeError
        # whenever fixed_std was None.)
        self.logvar_min = torch.nn.Parameter(
            torch.log(torch.tensor(std_min**2)), requires_grad=False
        )
        self.logvar_max = torch.nn.Parameter(
            torch.log(torch.tensor(std_max**2)), requires_grad=False
        )

        if fixed_std is None:
            # scale head: predict log-variance per action dim, step, and mode
            self.mlp_logvar = model(
                [input_dim, *mlp_dims, output_dim],
                activation_type=activation_type,
                out_activation_type="Identity",
                use_layernorm=use_layernorm,
            )
        elif learn_fixed_std:
            # initialize to fixed_std, separate for each action dim and mode
            self.logvar = torch.nn.Parameter(
                torch.log(
                    torch.tensor(
                        [fixed_std**2 for _ in range(action_dim * num_modes)]
                    )
                ),
                requires_grad=True,
            )
        self.use_fixed_std = fixed_std is not None
        self.fixed_std = fixed_std
        self.learn_fixed_std = learn_fixed_std

        # mode weights head (unnormalized logits)
        self.mlp_weights = model(
            [input_dim, *mlp_dims, num_modes],
            activation_type=activation_type,
            out_activation_type="Identity",
            use_layernorm=use_layernorm,
        )

    def forward(self, cond):
        """Compute GMM parameters from the state condition.

        Args:
            cond: dict with key "state", a tensor of shape (B, ...) that is
                flattened to (B, cond_dim) internally.

        Returns:
            out_mean: (B, num_modes, horizon_steps * action_dim), in [-1, 1].
            out_scale: stds, same shape as out_mean.
            out_weights: (B, num_modes) unnormalized mixture logits.
        """
        B = len(cond["state"])

        # flatten history
        state = cond["state"].view(B, -1)

        # mean head; tanh squashing into [-1, 1]
        out_mean = self.mlp_mean(state)
        out_mean = torch.tanh(out_mean).view(
            B, self.num_modes, self.horizon_steps * self.action_dim
        )

        if self.learn_fixed_std:
            # learned per-(action dim, mode) std, tiled across the horizon
            out_logvar = torch.clamp(self.logvar, self.logvar_min, self.logvar_max)
            out_scale = torch.exp(0.5 * out_logvar)
            out_scale = out_scale.view(1, self.num_modes, self.action_dim)
            out_scale = out_scale.repeat(B, 1, self.horizon_steps)
        elif self.use_fixed_std:
            # constant std; ones_like already matches device and dtype,
            # so no extra .to(device) is needed
            out_scale = torch.ones_like(out_mean) * self.fixed_std
        else:
            # predicted std per action dim, step, and mode
            out_logvar = self.mlp_logvar(state).view(
                B, self.num_modes, self.horizon_steps * self.action_dim
            )
            out_logvar = torch.clamp(out_logvar, self.logvar_min, self.logvar_max)
            out_scale = torch.exp(0.5 * out_logvar)

        out_weights = self.mlp_weights(state).view(B, self.num_modes)
        return out_mean, out_scale, out_weights
|