""" MLP models for Gaussian policy. """ import torch import torch.nn as nn import einops from copy import deepcopy from model.common.mlp import MLP, ResidualMLP from model.common.modules import SpatialEmb, RandomShiftsAug class Gaussian_VisionMLP(nn.Module): """With ViT backbone""" def __init__( self, backbone, transition_dim, horizon_steps, cond_dim, img_cond_steps=1, mlp_dims=[256, 256, 256], activation_type="Mish", residual_style=False, use_layernorm=False, fixed_std=None, learn_fixed_std=False, std_min=0.01, std_max=1, spatial_emb=0, visual_feature_dim=128, dropout=0, num_img=1, augment=False, ): super().__init__() # vision self.backbone = backbone if augment: self.aug = RandomShiftsAug(pad=4) self.augment = augment self.num_img = num_img self.img_cond_steps = img_cond_steps if spatial_emb > 0: assert spatial_emb > 1, "this is the dimension" if num_img > 1: self.compress1 = SpatialEmb( num_patch=self.backbone.num_patch, patch_dim=self.backbone.patch_repr_dim, prop_dim=cond_dim, proj_dim=spatial_emb, dropout=dropout, ) self.compress2 = deepcopy(self.compress1) else: # TODO: clean up self.compress = SpatialEmb( num_patch=self.backbone.num_patch, patch_dim=self.backbone.patch_repr_dim, prop_dim=cond_dim, proj_dim=spatial_emb, dropout=dropout, ) visual_feature_dim = spatial_emb * num_img else: self.compress = nn.Sequential( nn.Linear(self.backbone.repr_dim, visual_feature_dim), nn.LayerNorm(visual_feature_dim), nn.Dropout(dropout), nn.ReLU(), ) # head self.transition_dim = transition_dim self.horizon_steps = horizon_steps input_dim = visual_feature_dim + cond_dim output_dim = transition_dim * horizon_steps if residual_style: model = ResidualMLP else: model = MLP self.mlp_mean = model( [input_dim] + mlp_dims + [output_dim], activation_type=activation_type, out_activation_type="Identity", use_layernorm=use_layernorm, ) if fixed_std is None: self.mlp_logvar = MLP( [input_dim] + mlp_dims[-1:] + [output_dim], activation_type=activation_type, out_activation_type="Identity", use_layernorm=use_layernorm, ) elif learn_fixed_std: # initialize to fixed_std self.logvar = torch.nn.Parameter( torch.log(torch.tensor([fixed_std**2 for _ in range(transition_dim)])), requires_grad=True, ) self.logvar_min = torch.nn.Parameter( torch.log(torch.tensor(std_min**2)), requires_grad=False ) self.logvar_max = torch.nn.Parameter( torch.log(torch.tensor(std_max**2)), requires_grad=False ) self.use_fixed_std = fixed_std is not None self.fixed_std = fixed_std self.learn_fixed_std = learn_fixed_std def forward(self, cond): B = len(cond["rgb"]) device = cond["rgb"].device _, T_rgb, C, H, W = cond["rgb"].shape # flatten history state = cond["state"].view(B, -1) # Take recent images --- sometimes we want to use fewer img_cond_steps than cond_steps (e.g., 1 image but 3 prio) rgb = cond["rgb"][:, -self.img_cond_steps :] # concatenate images in cond by channels if self.num_img > 1: rgb = rgb.reshape(B, T_rgb, self.num_img, 3, H, W) rgb = einops.rearrange(rgb, "b t n c h w -> b n (t c) h w") else: rgb = einops.rearrange(rgb, "b t c h w -> b (t c) h w") # convert rgb to float32 for augmentation rgb = rgb.float() # get vit output - pass in two images separately if self.num_img > 1: # TODO: properly handle multiple images rgb1 = rgb[:, 0] rgb2 = rgb[:, 1] if self.augment: rgb1 = self.aug(rgb1) rgb2 = self.aug(rgb2) feat1 = self.backbone(rgb1) feat2 = self.backbone(rgb2) feat1 = self.compress1.forward(feat1, state) feat2 = self.compress2.forward(feat2, state) feat = torch.cat([feat1, feat2], dim=-1) else: # single image if 
self.augment: rgb = self.aug(rgb) # uint8 -> float32 feat = self.backbone(rgb) # compress if isinstance(self.compress, SpatialEmb): feat = self.compress.forward(feat, state) else: feat = feat.flatten(1, -1) feat = self.compress(feat) # mlp x_encoded = torch.cat([feat, state], dim=-1) out_mean = self.mlp_mean(x_encoded) out_mean = torch.tanh(out_mean).view( B, self.horizon_steps * self.transition_dim ) # tanh squashing in [-1, 1] if self.learn_fixed_std: out_logvar = torch.clamp(self.logvar, self.logvar_min, self.logvar_max) out_scale = torch.exp(0.5 * out_logvar) out_scale = out_scale.view(1, self.transition_dim) out_scale = out_scale.repeat(B, self.horizon_steps) elif self.use_fixed_std: out_scale = torch.ones_like(out_mean).to(device) * self.fixed_std else: out_logvar = self.mlp_logvar(x_encoded).view( B, self.horizon_steps * self.transition_dim ) out_logvar = torch.clamp(out_logvar, self.logvar_min, self.logvar_max) out_scale = torch.exp(0.5 * out_logvar) return out_mean, out_scale class Gaussian_MLP(nn.Module): def __init__( self, transition_dim, horizon_steps, cond_dim, mlp_dims=[256, 256, 256], activation_type="Mish", residual_style=False, use_layernorm=False, fixed_std=None, learn_fixed_std=False, std_min=0.01, std_max=1, ): super().__init__() self.transition_dim = transition_dim self.horizon_steps = horizon_steps input_dim = cond_dim output_dim = transition_dim * horizon_steps if residual_style: model = ResidualMLP else: model = MLP self.mlp_mean = model( [input_dim] + mlp_dims + [output_dim], activation_type=activation_type, out_activation_type="Identity", use_layernorm=use_layernorm, ) if fixed_std is None: self.mlp_logvar = MLP( [input_dim] + mlp_dims[-1:] + [output_dim], activation_type=activation_type, out_activation_type="Identity", use_layernorm=use_layernorm, ) elif learn_fixed_std: # initialize to fixed_std self.logvar = torch.nn.Parameter( torch.log(torch.tensor([fixed_std**2 for _ in range(transition_dim)])), requires_grad=True, ) self.logvar_min = torch.nn.Parameter( torch.log(torch.tensor(std_min**2)), requires_grad=False ) self.logvar_max = torch.nn.Parameter( torch.log(torch.tensor(std_max**2)), requires_grad=False ) self.use_fixed_std = fixed_std is not None self.fixed_std = fixed_std self.learn_fixed_std = learn_fixed_std def forward(self, cond): B = len(cond["state"]) device = cond["state"].device # flatten history state = cond["state"].view(B, -1) # mlp out_mean = self.mlp_mean(state) out_mean = torch.tanh(out_mean).view( B, self.horizon_steps * self.transition_dim ) # tanh squashing in [-1, 1] if self.learn_fixed_std: out_logvar = torch.clamp(self.logvar, self.logvar_min, self.logvar_max) out_scale = torch.exp(0.5 * out_logvar) out_scale = out_scale.view(1, self.transition_dim) out_scale = out_scale.repeat(B, self.horizon_steps) elif self.use_fixed_std: out_scale = torch.ones_like(out_mean).to(device) * self.fixed_std else: out_logvar = self.mlp_logvar(state).view( B, self.horizon_steps * self.transition_dim ) out_logvar = torch.clamp(out_logvar, self.logvar_min, self.logvar_max) out_scale = torch.exp(0.5 * out_logvar) return out_mean, out_scale
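

if __name__ == "__main__":
    # Minimal smoke test for Gaussian_MLP --- a sketch only; all dimensions
    # below (cond_steps, obs_dim, horizon_steps, transition_dim, batch size)
    # are illustrative and not taken from any training config. cond_dim must
    # equal the flattened state size, i.e. cond_steps * obs_dim.
    cond_steps, obs_dim = 1, 11
    horizon_steps, transition_dim = 4, 3
    policy = Gaussian_MLP(
        transition_dim=transition_dim,
        horizon_steps=horizon_steps,
        cond_dim=cond_steps * obs_dim,
    )
    cond = {"state": torch.randn(8, cond_steps, obs_dim)}  # (B, cond_steps, obs_dim)
    mean, scale = policy(cond)
    # both outputs have shape (B, horizon_steps * transition_dim) = (8, 12)
    print(mean.shape, scale.shape)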