"""
MLP models for diffusion policies.
"""

import torch
import torch.nn as nn
import logging
import einops
from copy import deepcopy

from model.common.mlp import MLP, ResidualMLP
from model.diffusion.modules import SinusoidalPosEmb
from model.common.modules import SpatialEmb, RandomShiftsAug

log = logging.getLogger(__name__)


class VisionDiffusionMLP(nn.Module):
    """With ViT backbone"""

    def __init__(
        self,
        backbone,
        transition_dim,
        horizon_steps,
        cond_dim,
        img_cond_steps=1,
        time_dim=16,
        mlp_dims=[256, 256],
        activation_type="Mish",
        out_activation_type="Identity",
        use_layernorm=False,
        residual_style=False,
        spatial_emb=0,
        visual_feature_dim=128,
        dropout=0,
        num_img=1,
        augment=False,
    ):
        super().__init__()

        # vision
        self.backbone = backbone
        if augment:
            self.aug = RandomShiftsAug(pad=4)
        self.augment = augment
        self.num_img = num_img
        self.img_cond_steps = img_cond_steps
        if spatial_emb > 0:
            assert spatial_emb > 1, "spatial_emb is the projection dimension and must be > 1"
            if num_img > 1:
                self.compress1 = SpatialEmb(
                    num_patch=self.backbone.num_patch,
                    patch_dim=self.backbone.patch_repr_dim,
                    prop_dim=cond_dim,
                    proj_dim=spatial_emb,
                    dropout=dropout,
                )
                self.compress2 = deepcopy(self.compress1)
            else:  # TODO: clean up
                self.compress = SpatialEmb(
                    num_patch=self.backbone.num_patch,
                    patch_dim=self.backbone.patch_repr_dim,
                    prop_dim=cond_dim,
                    proj_dim=spatial_emb,
                    dropout=dropout,
                )
            visual_feature_dim = spatial_emb * num_img
        else:
            self.compress = nn.Sequential(
                nn.Linear(self.backbone.repr_dim, visual_feature_dim),
                nn.LayerNorm(visual_feature_dim),
                nn.Dropout(dropout),
                nn.ReLU(),
            )

        # diffusion
        input_dim = (
            time_dim + transition_dim * horizon_steps + visual_feature_dim + cond_dim
        )
        output_dim = transition_dim * horizon_steps
        self.time_embedding = nn.Sequential(
            SinusoidalPosEmb(time_dim),
            nn.Linear(time_dim, time_dim * 2),
            nn.Mish(),
            nn.Linear(time_dim * 2, time_dim),
        )
        if residual_style:
            model = ResidualMLP
        else:
            model = MLP
        self.mlp_mean = model(
            [input_dim] + mlp_dims + [output_dim],
            activation_type=activation_type,
            out_activation_type=out_activation_type,
            use_layernorm=use_layernorm,
        )
        self.time_dim = time_dim

    def forward(
        self,
        x,
        time,
        cond: dict,
        **kwargs,
    ):
        """
        x: (B, Ta, Da)
        time: (B,) or int, diffusion step
        cond: dict with key state/rgb; more recent obs at the end
            state: (B, To, Do)
            rgb: (B, To, C, H, W)

        TODO long term: more flexible handling of cond
        """
        B, Ta, Da = x.shape
        _, T_rgb, C, H, W = cond["rgb"].shape

        # flatten chunk
        x = x.view(B, -1)

        # flatten history
        state = cond["state"].view(B, -1)

        # Take recent images --- sometimes we want to use fewer img_cond_steps than cond_steps (e.g., 1 image but 3 proprio states)
rgb = cond["rgb"][:, -self.img_cond_steps :]
# concatenate images in cond by channels
if self.num_img > 1:
rgb = rgb.reshape(B, T_rgb, self.num_img, 3, H, W)
rgb = einops.rearrange(rgb, "b t n c h w -> b n (t c) h w")
else:
rgb = einops.rearrange(rgb, "b t c h w -> b (t c) h w")
# convert rgb to float32 for augmentation
rgb = rgb.float()
# get vit output - pass in two images separately
if self.num_img > 1: # TODO: properly handle multiple images
rgb1 = rgb[:, 0]
rgb2 = rgb[:, 1]
if self.augment:
rgb1 = self.aug(rgb1)
rgb2 = self.aug(rgb2)
feat1 = self.backbone(rgb1)
feat2 = self.backbone(rgb2)
feat1 = self.compress1.forward(feat1, state)
feat2 = self.compress2.forward(feat2, state)
feat = torch.cat([feat1, feat2], dim=-1)
else: # single image
if self.augment:
rgb = self.aug(rgb)
feat = self.backbone(rgb)
# compress
if isinstance(self.compress, SpatialEmb):
feat = self.compress.forward(feat, state)
else:
feat = feat.flatten(1, -1)
feat = self.compress(feat)
cond_encoded = torch.cat([feat, state], dim=-1)
# append time and cond
time = time.view(B, 1)
time_emb = self.time_embedding(time).view(B, self.time_dim)
x = torch.cat([x, time_emb, cond_encoded], dim=-1)
# mlp
out = self.mlp_mean(x)
return out.view(B, Ta, Da)
class DiffusionMLP(nn.Module):
    """State-only diffusion policy head (no vision backbone)."""

    def __init__(
        self,
        transition_dim,
        horizon_steps,
        cond_dim,
        time_dim=16,
        mlp_dims=[256, 256],
        cond_mlp_dims=None,
        activation_type="Mish",
        out_activation_type="Identity",
        use_layernorm=False,
        residual_style=False,
    ):
        super().__init__()
        output_dim = transition_dim * horizon_steps
        self.time_embedding = nn.Sequential(
            SinusoidalPosEmb(time_dim),
            nn.Linear(time_dim, time_dim * 2),
            nn.Mish(),
            nn.Linear(time_dim * 2, time_dim),
        )
        if residual_style:
            model = ResidualMLP
        else:
            model = MLP
        if cond_mlp_dims is not None:
            self.cond_mlp = MLP(
                [cond_dim] + cond_mlp_dims,
                activation_type=activation_type,
                out_activation_type="Identity",
            )
            input_dim = time_dim + transition_dim * horizon_steps + cond_mlp_dims[-1]
        else:
            input_dim = time_dim + transition_dim * horizon_steps + cond_dim
        self.mlp_mean = model(
            [input_dim] + mlp_dims + [output_dim],
            activation_type=activation_type,
            out_activation_type=out_activation_type,
            use_layernorm=use_layernorm,
        )
        self.time_dim = time_dim

    def forward(
        self,
        x,
        time,
        cond,
        **kwargs,
    ):
        """
        x: (B, Ta, Da)
        time: (B,) or int, diffusion step
        cond: dict with key state/rgb; more recent obs at the end
            state: (B, To, Do)
        """
        B, Ta, Da = x.shape

        # flatten chunk
        x = x.view(B, -1)

        # flatten history
        state = cond["state"].view(B, -1)

        # obs encoder
        if hasattr(self, "cond_mlp"):
            state = self.cond_mlp(state)

        # append time and cond
        time = time.view(B, 1)
        time_emb = self.time_embedding(time).view(B, self.time_dim)
        x = torch.cat([x, time_emb, state], dim=-1)

        # mlp head
        out = self.mlp_mean(x)
        return out.view(B, Ta, Da)
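

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (not part of the original module). It assumes the
# repo's `model` package is importable (i.e., run from the repo root). The
# shapes (B, Ta, Da, To, Do) and the _DummyBackbone class are illustrative
# stand-ins; the real backbone is the repo's ViT image encoder.
# ---------------------------------------------------------------------------
if __name__ == "__main__":

    class _DummyBackbone(nn.Module):
        """Hypothetical stand-in for the ViT backbone: maps flattened 3x32x32
        images to a fixed-size feature vector and exposes `repr_dim`."""

        def __init__(self, repr_dim=64, in_dim=3 * 32 * 32):
            super().__init__()
            self.repr_dim = repr_dim
            self.net = nn.Linear(in_dim, repr_dim)

        def forward(self, x):
            return self.net(x.flatten(1))

    B, Ta, Da, To, Do = 2, 4, 3, 1, 11

    # State-only model: cond_dim is the flattened state dimension (To * Do).
    model = DiffusionMLP(transition_dim=Da, horizon_steps=Ta, cond_dim=To * Do)
    x = torch.randn(B, Ta, Da)
    time = torch.randint(0, 20, (B,)).float()
    cond = {"state": torch.randn(B, To, Do)}
    print(model(x, time, cond).shape)  # expected: (B, Ta, Da)

    # Vision model with the dummy backbone; spatial_emb=0 uses the linear
    # compression path, so only backbone.repr_dim is needed.
    vis_model = VisionDiffusionMLP(
        backbone=_DummyBackbone(),
        transition_dim=Da,
        horizon_steps=Ta,
        cond_dim=To * Do,
        img_cond_steps=1,
        num_img=1,
        augment=False,
    )
    cond["rgb"] = torch.randn(B, To, 3, 32, 32)
    print(vis_model(x, time, cond).shape)  # expected: (B, Ta, Da)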