* v0.5 (#9) * update idql configs * update awr configs * update dipo configs * update qsm configs * update dqm configs * update project version to 0.5.0
207 lines
5.9 KiB
Python
207 lines
5.9 KiB
Python
"""
|
|
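Critic networks: state-only (CriticObs), state-action with optional double Q (CriticObsAct), and ViT-based state-only (ViTCritic).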
|
|
|
|
"""
|
|
|
|
from typing import Union
|
|
import torch
|
|
import einops
|
|
from copy import deepcopy
|
|
|
|
from model.common.mlp import MLP, ResidualMLP
|
|
from model.common.modules import SpatialEmb, RandomShiftsAug
|
|
|
|
|
|
class CriticObs(torch.nn.Module):
    """State-only critic network.

    Wraps a single network ``self.Q1`` (an ``MLP`` or ``ResidualMLP``) whose
    layer widths are ``[cond_dim] + mlp_dims + [1]``.
    """

    def __init__(
        self,
        cond_dim,
        mlp_dims,
        activation_type="Mish",
        use_layernorm=False,
        residual_style=False,
        **kwargs,
    ):
        """
        cond_dim: width of the flattened observation fed to the network
        mlp_dims: list of hidden layer widths
        activation_type: hidden activation name, forwarded to the network
        use_layernorm: forwarded to the network
        residual_style: build a ResidualMLP instead of a plain MLP
        kwargs: accepted and ignored
        """
        super().__init__()
        # input width, hidden widths, then a single output unit
        mlp_dims = [cond_dim] + mlp_dims + [1]
        model = ResidualMLP if residual_style else MLP
        self.Q1 = model(
            mlp_dims,
            activation_type=activation_type,
            out_activation_type="Identity",
            use_layernorm=use_layernorm,
        )

    def forward(self, cond: Union[dict, torch.Tensor]):
        """
        cond: dict with key state/rgb; more recent obs at the end
            state: (B, To, Do)
            or (B, num_feature) from ViT encoder

        Returns the output of Q1 as-is (no squeeze is applied).
        """
        if isinstance(cond, dict):
            state = cond["state"]
            # flatten history; reshape (unlike view) also accepts
            # non-contiguous tensors, e.g. transposed or permuted inputs
            state = state.reshape(len(state), -1)
        else:
            # already a feature tensor; pass through untouched
            state = cond
        return self.Q1(state)
|
|
|
|
|
|
class CriticObsAct(torch.nn.Module):
    """State-action double critic network.

    Holds ``self.Q1`` and, when ``double_q`` is set, a second independently
    constructed ``self.Q2`` with the same layer widths. Both take the
    concatenation of the flattened state and the flattened action.
    """

    def __init__(
        self,
        cond_dim,
        mlp_dims,
        action_dim,
        action_steps=1,
        activation_type="Mish",
        use_layernorm=False,
        residual_tyle=False,
        double_q=True,
        residual_style=False,
        **kwargs,
    ):
        """
        cond_dim: width of the flattened observation
        mlp_dims: list of hidden layer widths
        action_dim: width of a single action
        action_steps: number of action steps concatenated into the input
        activation_type: hidden activation name, forwarded to the network
        use_layernorm: forwarded to the network
        residual_tyle: misspelled legacy alias of ``residual_style``; kept so
            existing callers keep working
        double_q: also build a second Q network (``self.Q2``)
        residual_style: build ResidualMLP instead of MLP. Before this
            parameter existed, ``residual_style=True`` was swallowed by
            ``**kwargs`` and ignored.
            NOTE(review): checkpoints trained while the flag was ignored hold
            plain-MLP weights; confirm configs before loading them.
        kwargs: accepted and ignored
        """
        super().__init__()
        # state + all action steps in, hidden widths, one scalar out
        mlp_dims = [cond_dim + action_dim * action_steps] + mlp_dims + [1]
        # honor either spelling of the flag
        model = ResidualMLP if (residual_style or residual_tyle) else MLP
        q_kwargs = dict(
            activation_type=activation_type,
            out_activation_type="Identity",
            use_layernorm=use_layernorm,
        )
        self.Q1 = model(mlp_dims, **q_kwargs)
        if double_q:
            self.Q2 = model(mlp_dims, **q_kwargs)

    def forward(self, cond: dict, action):
        """
        cond: dict with key state/rgb; more recent obs at the end
            state: (B, To, Do)
        action: (B, Ta, Da)

        Returns (q1, q2) when a second Q network exists, otherwise q1 alone;
        dimension 1 is squeezed from each output.
        """
        B = len(cond["state"])

        # flatten history and action; reshape (unlike view) also accepts
        # non-contiguous tensors
        state = cond["state"].reshape(B, -1)
        action = action.reshape(B, -1)

        x = torch.cat((state, action), dim=-1)
        q1 = self.Q1(x).squeeze(1)
        # Q2 is only registered when the critic was built with double_q=True
        if hasattr(self, "Q2"):
            return q1, self.Q2(x).squeeze(1)
        return q1
|
|
|
|
|
|
class ViTCritic(CriticObs):
    """ViT + MLP, state only

    Images are encoded by ``backbone``, compressed together with the
    flattened state by a ``SpatialEmb`` module, concatenated with the state,
    and passed to the parent ``CriticObs`` network.
    """

    def __init__(
        self,
        backbone,
        cond_dim,
        img_cond_steps=1,
        spatial_emb=128,
        dropout=0,
        augment=False,
        num_img=1,
        **kwargs,
    ):
        """
        backbone: image encoder; must expose ``num_patch`` and ``patch_repr_dim``
        cond_dim: width of the flattened state
        img_cond_steps: number of most recent image frames to use
        spatial_emb: ``proj_dim`` passed to SpatialEmb
        dropout: forwarded to SpatialEmb
        augment: apply RandomShiftsAug(pad=4) to images in forward
        num_img: number of camera images per step
            NOTE(review): forward only consumes the first two images when
            num_img > 1, so values above 2 look unsupported; confirm.
        kwargs: forwarded to CriticObs (e.g. mlp_dims)
        """
        # update input dim to mlp
        mlp_obs_dim = spatial_emb * num_img + cond_dim
        super().__init__(cond_dim=mlp_obs_dim, **kwargs)
        self.backbone = backbone
        self.num_img = num_img
        self.img_cond_steps = img_cond_steps
        if num_img > 1:
            self.compress1 = SpatialEmb(
                num_patch=self.backbone.num_patch,
                patch_dim=self.backbone.patch_repr_dim,
                prop_dim=cond_dim,
                proj_dim=spatial_emb,
                dropout=dropout,
            )
            # second compressor starts as an exact copy of the first
            self.compress2 = deepcopy(self.compress1)
        else:  # TODO: clean up
            self.compress = SpatialEmb(
                num_patch=self.backbone.num_patch,
                patch_dim=self.backbone.patch_repr_dim,
                prop_dim=cond_dim,
                proj_dim=spatial_emb,
                dropout=dropout,
            )
        if augment:
            self.aug = RandomShiftsAug(pad=4)
        self.augment = augment

    def forward(
        self,
        cond: dict,
        no_augment=False,
    ):
        """
        cond: dict with key state/rgb; more recent obs at the end
            state: (B, To, Do)
            rgb: (B, To, C, H, W)
        no_augment: whether to skip augmentation

        TODO long term: more flexible handling of cond
        """
        B, _, _, H, W = cond["rgb"].shape

        # flatten history; reshape (unlike view) also accepts non-contiguous
        # tensors
        state = cond["state"].reshape(B, -1)

        # Take recent images --- sometimes we want to use fewer img_cond_steps
        # than cond_steps (e.g., 1 image but 3 proprio)
        rgb = cond["rgb"][:, -self.img_cond_steps :]

        # concatenate images in cond by channels
        if self.num_img > 1:
            # Infer the time dimension from the sliced tensor: using the
            # pre-slice step count here fails whenever img_cond_steps is
            # smaller than the number of frames in cond["rgb"].
            # The channel axis is split into num_img images of 3 channels.
            rgb = rgb.reshape(B, -1, self.num_img, 3, H, W)
            rgb = einops.rearrange(rgb, "b t n c h w -> b n (t c) h w")
        else:
            rgb = einops.rearrange(rgb, "b t c h w -> b (t c) h w")

        # convert rgb to float32 for augmentation
        rgb = rgb.float()

        # get vit output - pass in two images separately
        if self.num_img > 1:  # TODO: properly handle multiple images
            rgb1 = rgb[:, 0]
            rgb2 = rgb[:, 1]
            if self.augment and not no_augment:
                rgb1 = self.aug(rgb1)
                rgb2 = self.aug(rgb2)
            feat1 = self.backbone(rgb1)
            feat2 = self.backbone(rgb2)
            feat1 = self.compress1.forward(feat1, state)
            feat2 = self.compress2.forward(feat2, state)
            feat = torch.cat([feat1, feat2], dim=-1)
        else:  # single image
            if self.augment and not no_augment:
                # rgb is already float32 at this point
                rgb = self.aug(rgb)
            feat = self.backbone(rgb)
            feat = self.compress.forward(feat, state)
        # append the flattened state to the image features
        feat = torch.cat([feat, state], dim=-1)
        return super().forward(feat)
|