From 1aaa6c2302382c39885378c1a3b1eb72d7b940cf Mon Sep 17 00:00:00 2001 From: allenzren Date: Mon, 16 Sep 2024 17:55:31 -0400 Subject: [PATCH] support varying img size --- README.md | 2 +- agent/finetune/train_ppo_diffusion_agent.py | 11 +++-- .../finetune/train_ppo_diffusion_img_agent.py | 14 +++--- .../train_ppo_exact_diffusion_agent.py | 16 +++---- agent/finetune/train_ppo_gaussian_agent.py | 11 +++-- .../finetune/train_ppo_gaussian_img_agent.py | 14 +++--- .../can/ft_ppo_diffusion_mlp_img.yaml | 8 +++- .../finetune/can/ft_ppo_gaussian_mlp_img.yaml | 8 +++- .../lift/ft_ppo_diffusion_mlp_img.yaml | 8 +++- .../lift/ft_ppo_gaussian_mlp_img.yaml | 8 +++- .../square/ft_ppo_diffusion_mlp_img.yaml | 8 +++- .../square/ft_ppo_gaussian_mlp_img.yaml | 8 +++- .../transport/ft_ppo_diffusion_mlp_img.yaml | 8 +++- .../transport/ft_ppo_gaussian_mlp_img.yaml | 8 +++- model/common/critic.py | 9 ++-- model/common/mlp_gaussian.py | 12 ++--- model/common/vit.py | 47 +++++++++++++++---- model/diffusion/mlp_diffusion.py | 12 ++--- 18 files changed, 131 insertions(+), 81 deletions(-) diff --git a/README.md b/README.md index 9a2e4d1..fbcfaee 100644 --- a/README.md +++ b/README.md @@ -159,7 +159,7 @@ To use DDIM fine-tuning, set `denoising_steps=100` in pre-training and set `mode ## Adding your own dataset/environment ### Pre-training data -Pre-training script is at [`agent/pretrain/train_diffusion_agent.py`](agent/pretrain/train_diffusion_agent.py). The pre-training dataset [loader](agent/dataset/sequence.py) assumes a npz file containing numpy arrays `states`, `actions`, `images` (if using pixel) and `traj_lengths`, where `states` and `actions` have the shape of num_total_steps x obs_dim/act_dim, `images` num_total_steps x C (concatenated if multiple images) x H x W, and `traj_lengths` is a 1-D array for indexing across num_total_steps. +Pre-training script is at [`agent/pretrain/train_diffusion_agent.py`](agent/pretrain/train_diffusion_agent.py). 
The pre-training dataset [loader](agent/dataset/sequence.py) assumes an npz file containing numpy arrays `states`, `actions`, `images` (if using pixel; img_h = img_w and a multiple of 8) and `traj_lengths`, where `states` and `actions` have the shape of num_total_steps x obs_dim/act_dim, `images` has the shape of num_total_steps x C (concatenated if multiple images) x H x W, and `traj_lengths` is a 1-D array for indexing across num_total_steps. diff --git a/agent/finetune/train_ppo_diffusion_agent.py b/agent/finetune/train_ppo_diffusion_agent.py index 1897933..73878ce 100644 --- a/agent/finetune/train_ppo_diffusion_agent.py +++ b/agent/finetune/train_ppo_diffusion_agent.py @@ -244,15 +244,16 @@ class TrainPPODiffusionAgent(TrainPPOAgent): .float() .to(self.device) } - with torch.no_grad(): - next_value = ( - self.model.critic(obs_venv_ts).reshape(1, -1).cpu().numpy() - ) advantages_trajs = np.zeros_like(reward_trajs) lastgaelam = 0 for t in reversed(range(self.n_steps)): if t == self.n_steps - 1: - nextvalues = next_value + nextvalues = ( + self.model.critic(obs_venv_ts) + .reshape(1, -1) + .cpu() + .numpy() + ) else: nextvalues = values_trajs[t + 1] nonterminal = 1.0 - dones_trajs[t] diff --git a/agent/finetune/train_ppo_diffusion_img_agent.py b/agent/finetune/train_ppo_diffusion_img_agent.py index ed85b60..2abb148 100644 --- a/agent/finetune/train_ppo_diffusion_img_agent.py +++ b/agent/finetune/train_ppo_diffusion_img_agent.py @@ -240,18 +240,16 @@ class TrainPPOImgDiffusionAgent(TrainPPODiffusionAgent): key: torch.from_numpy(obs_venv[key]).float().to(self.device) for key in self.obs_dims } - with torch.no_grad(): - next_value = ( - self.model.critic(obs_venv_ts, no_augment=True) - .reshape(1, -1) - .cpu() - .numpy() - ) advantages_trajs = np.zeros_like(reward_trajs) lastgaelam = 0 for t in reversed(range(self.n_steps)): if t == self.n_steps - 1: - nextvalues = next_value + nextvalues = ( + self.model.critic(obs_venv_ts, no_augment=True) + .reshape(1, -1) + .cpu() + .numpy() + ) 
else: nextvalues = values_trajs[t + 1] nonterminal = 1.0 - dones_trajs[t] diff --git a/agent/finetune/train_ppo_exact_diffusion_agent.py b/agent/finetune/train_ppo_exact_diffusion_agent.py index 9661ec9..00c4c90 100644 --- a/agent/finetune/train_ppo_exact_diffusion_agent.py +++ b/agent/finetune/train_ppo_exact_diffusion_agent.py @@ -220,15 +220,16 @@ class TrainPPOExactDiffusionAgent(TrainPPODiffusionAgent): .float() .to(self.device) } - with torch.no_grad(): - next_value = ( - self.model.critic(obs_venv_ts).reshape(1, -1).cpu().numpy() - ) advantages_trajs = np.zeros_like(reward_trajs) lastgaelam = 0 for t in reversed(range(self.n_steps)): if t == self.n_steps - 1: - nextvalues = next_value + nextvalues = ( + self.model.critic(obs_venv_ts) + .reshape(1, -1) + .cpu() + .numpy() + ) else: nextvalues = values_trajs[t + 1] nonterminal = 1.0 - dones_trajs[t] @@ -241,10 +242,7 @@ class TrainPPOExactDiffusionAgent(TrainPPODiffusionAgent): # A = delta_t + gamma*lamdba*delta_{t+1} + ... advantages_trajs[t] = lastgaelam = ( delta - + self.gamma - * self.gae_lambda - * nonterminal - * lastgaelam + + self.gamma * self.gae_lambda * nonterminal * lastgaelam ) returns_trajs = advantages_trajs + values_trajs diff --git a/agent/finetune/train_ppo_gaussian_agent.py b/agent/finetune/train_ppo_gaussian_agent.py index f871b6b..cfc1546 100644 --- a/agent/finetune/train_ppo_gaussian_agent.py +++ b/agent/finetune/train_ppo_gaussian_agent.py @@ -209,15 +209,16 @@ class TrainPPOGaussianAgent(TrainPPOAgent): .float() .to(self.device) } - with torch.no_grad(): - next_value = ( - self.model.critic(obs_venv_ts).reshape(1, -1).cpu().numpy() - ) advantages_trajs = np.zeros_like(reward_trajs) lastgaelam = 0 for t in reversed(range(self.n_steps)): if t == self.n_steps - 1: - nextvalues = next_value + nextvalues = ( + self.model.critic(obs_venv_ts) + .reshape(1, -1) + .cpu() + .numpy() + ) else: nextvalues = values_trajs[t + 1] nonterminal = 1.0 - dones_trajs[t] diff --git 
a/agent/finetune/train_ppo_gaussian_img_agent.py b/agent/finetune/train_ppo_gaussian_img_agent.py index 964031d..eabd531 100644 --- a/agent/finetune/train_ppo_gaussian_img_agent.py +++ b/agent/finetune/train_ppo_gaussian_img_agent.py @@ -228,18 +228,16 @@ class TrainPPOImgGaussianAgent(TrainPPOGaussianAgent): key: torch.from_numpy(obs_venv[key]).float().to(self.device) for key in self.obs_dims } - with torch.no_grad(): - next_value = ( - self.model.critic(obs_venv_ts, no_augment=True) - .reshape(1, -1) - .cpu() - .numpy() - ) advantages_trajs = np.zeros_like(reward_trajs) lastgaelam = 0 for t in reversed(range(self.n_steps)): if t == self.n_steps - 1: - nextvalues = next_value + nextvalues = ( + self.model.critic(obs_venv_ts, no_augment=True) + .reshape(1, -1) + .cpu() + .numpy() + ) else: nextvalues = values_trajs[t + 1] nonterminal = 1.0 - dones_trajs[t] diff --git a/cfg/robomimic/finetune/can/ft_ppo_diffusion_mlp_img.yaml b/cfg/robomimic/finetune/can/ft_ppo_diffusion_mlp_img.yaml index d47e664..ef77c2f 100644 --- a/cfg/robomimic/finetune/can/ft_ppo_diffusion_mlp_img.yaml +++ b/cfg/robomimic/finetune/can/ft_ppo_diffusion_mlp_img.yaml @@ -122,7 +122,9 @@ model: backbone: _target_: model.common.vit.VitEncoder obs_shape: ${shape_meta.obs.rgb.shape} - num_channel: ${eval:'${shape_meta.obs.rgb.shape[0]} * ${img_cond_steps}'} # each image patch is history concatenated + num_channel: ${eval:'3 * ${img_cond_steps}'} # each image patch is history concatenated + img_h: ${shape_meta.obs.rgb.shape[1]} + img_w: ${shape_meta.obs.rgb.shape[2]} cfg: patch_size: 8 depth: 1 @@ -146,7 +148,9 @@ model: backbone: _target_: model.common.vit.VitEncoder obs_shape: ${shape_meta.obs.rgb.shape} - num_channel: ${eval:'${shape_meta.obs.rgb.shape[0]} * ${img_cond_steps}'} # each image patch is history concatenated + num_channel: ${eval:'3 * ${img_cond_steps}'} # each image patch is history concatenated + img_h: ${shape_meta.obs.rgb.shape[1]} + img_w: ${shape_meta.obs.rgb.shape[2]} cfg: 
patch_size: 8 depth: 1 diff --git a/cfg/robomimic/finetune/can/ft_ppo_gaussian_mlp_img.yaml b/cfg/robomimic/finetune/can/ft_ppo_gaussian_mlp_img.yaml index 9613b10..ea9b229 100644 --- a/cfg/robomimic/finetune/can/ft_ppo_gaussian_mlp_img.yaml +++ b/cfg/robomimic/finetune/can/ft_ppo_gaussian_mlp_img.yaml @@ -101,7 +101,9 @@ model: backbone: _target_: model.common.vit.VitEncoder obs_shape: ${shape_meta.obs.rgb.shape} - num_channel: ${eval:'${shape_meta.obs.rgb.shape[0]} * ${img_cond_steps}'} # each image patch is history concatenated + num_channel: ${eval:'3 * ${img_cond_steps}'} # each image patch is history concatenated + img_h: ${shape_meta.obs.rgb.shape[1]} + img_w: ${shape_meta.obs.rgb.shape[2]} cfg: patch_size: 8 depth: 1 @@ -128,7 +130,9 @@ model: backbone: _target_: model.common.vit.VitEncoder obs_shape: ${shape_meta.obs.rgb.shape} - num_channel: ${eval:'${shape_meta.obs.rgb.shape[0]} * ${img_cond_steps}'} # each image patch is history concatenated + num_channel: ${eval:'3 * ${img_cond_steps}'} # each image patch is history concatenated + img_h: ${shape_meta.obs.rgb.shape[1]} + img_w: ${shape_meta.obs.rgb.shape[2]} cfg: patch_size: 8 depth: 1 diff --git a/cfg/robomimic/finetune/lift/ft_ppo_diffusion_mlp_img.yaml b/cfg/robomimic/finetune/lift/ft_ppo_diffusion_mlp_img.yaml index fe36cb5..03f9f0b 100644 --- a/cfg/robomimic/finetune/lift/ft_ppo_diffusion_mlp_img.yaml +++ b/cfg/robomimic/finetune/lift/ft_ppo_diffusion_mlp_img.yaml @@ -122,7 +122,9 @@ model: backbone: _target_: model.common.vit.VitEncoder obs_shape: ${shape_meta.obs.rgb.shape} - num_channel: ${eval:'${shape_meta.obs.rgb.shape[0]} * ${img_cond_steps}'} # each image patch is history concatenated + num_channel: ${eval:'3 * ${img_cond_steps}'} # each image patch is history concatenated + img_h: ${shape_meta.obs.rgb.shape[1]} + img_w: ${shape_meta.obs.rgb.shape[2]} cfg: patch_size: 8 depth: 1 @@ -146,7 +148,9 @@ model: backbone: _target_: model.common.vit.VitEncoder obs_shape: ${shape_meta.obs.rgb.shape} 
- num_channel: ${eval:'${shape_meta.obs.rgb.shape[0]} * ${img_cond_steps}'} # each image patch is history concatenated + num_channel: ${eval:'3 * ${img_cond_steps}'} # each image patch is history concatenated + img_h: ${shape_meta.obs.rgb.shape[1]} + img_w: ${shape_meta.obs.rgb.shape[2]} cfg: patch_size: 8 depth: 1 diff --git a/cfg/robomimic/finetune/lift/ft_ppo_gaussian_mlp_img.yaml b/cfg/robomimic/finetune/lift/ft_ppo_gaussian_mlp_img.yaml index 3432e4b..26a9a52 100644 --- a/cfg/robomimic/finetune/lift/ft_ppo_gaussian_mlp_img.yaml +++ b/cfg/robomimic/finetune/lift/ft_ppo_gaussian_mlp_img.yaml @@ -101,7 +101,9 @@ model: backbone: _target_: model.common.vit.VitEncoder obs_shape: ${shape_meta.obs.rgb.shape} - num_channel: ${eval:'${shape_meta.obs.rgb.shape[0]} * ${img_cond_steps}'} # each image patch is history concatenated + num_channel: ${eval:'3 * ${img_cond_steps}'} # each image patch is history concatenated + img_h: ${shape_meta.obs.rgb.shape[1]} + img_w: ${shape_meta.obs.rgb.shape[2]} cfg: patch_size: 8 depth: 1 @@ -128,7 +130,9 @@ model: backbone: _target_: model.common.vit.VitEncoder obs_shape: ${shape_meta.obs.rgb.shape} - num_channel: ${eval:'${shape_meta.obs.rgb.shape[0]} * ${img_cond_steps}'} # each image patch is history concatenated + num_channel: ${eval:'3 * ${img_cond_steps}'} # each image patch is history concatenated + img_h: ${shape_meta.obs.rgb.shape[1]} + img_w: ${shape_meta.obs.rgb.shape[2]} cfg: patch_size: 8 depth: 1 diff --git a/cfg/robomimic/finetune/square/ft_ppo_diffusion_mlp_img.yaml b/cfg/robomimic/finetune/square/ft_ppo_diffusion_mlp_img.yaml index 1e8ebfa..fa1db67 100644 --- a/cfg/robomimic/finetune/square/ft_ppo_diffusion_mlp_img.yaml +++ b/cfg/robomimic/finetune/square/ft_ppo_diffusion_mlp_img.yaml @@ -122,7 +122,9 @@ model: backbone: _target_: model.common.vit.VitEncoder obs_shape: ${shape_meta.obs.rgb.shape} - num_channel: ${eval:'${shape_meta.obs.rgb.shape[0]} * ${img_cond_steps}'} # each image patch is history concatenated + 
num_channel: ${eval:'3 * ${img_cond_steps}'} # each image patch is history concatenated + img_h: ${shape_meta.obs.rgb.shape[1]} + img_w: ${shape_meta.obs.rgb.shape[2]} cfg: patch_size: 8 depth: 1 @@ -146,7 +148,9 @@ model: backbone: _target_: model.common.vit.VitEncoder obs_shape: ${shape_meta.obs.rgb.shape} - num_channel: ${eval:'${shape_meta.obs.rgb.shape[0]} * ${img_cond_steps}'} # each image patch is history concatenated + num_channel: ${eval:'3 * ${img_cond_steps}'} # each image patch is history concatenated + img_h: ${shape_meta.obs.rgb.shape[1]} + img_w: ${shape_meta.obs.rgb.shape[2]} cfg: patch_size: 8 depth: 1 diff --git a/cfg/robomimic/finetune/square/ft_ppo_gaussian_mlp_img.yaml b/cfg/robomimic/finetune/square/ft_ppo_gaussian_mlp_img.yaml index 7bbe198..0ffddd7 100644 --- a/cfg/robomimic/finetune/square/ft_ppo_gaussian_mlp_img.yaml +++ b/cfg/robomimic/finetune/square/ft_ppo_gaussian_mlp_img.yaml @@ -101,7 +101,9 @@ model: backbone: _target_: model.common.vit.VitEncoder obs_shape: ${shape_meta.obs.rgb.shape} - num_channel: ${eval:'${shape_meta.obs.rgb.shape[0]} * ${img_cond_steps}'} # each image patch is history concatenated + num_channel: ${eval:'3 * ${img_cond_steps}'} # each image patch is history concatenated + img_h: ${shape_meta.obs.rgb.shape[1]} + img_w: ${shape_meta.obs.rgb.shape[2]} cfg: patch_size: 8 depth: 1 @@ -128,7 +130,9 @@ model: backbone: _target_: model.common.vit.VitEncoder obs_shape: ${shape_meta.obs.rgb.shape} - num_channel: ${eval:'${shape_meta.obs.rgb.shape[0]} * ${img_cond_steps}'} # each image patch is history concatenated + num_channel: ${eval:'3 * ${img_cond_steps}'} # each image patch is history concatenated + img_h: ${shape_meta.obs.rgb.shape[1]} + img_w: ${shape_meta.obs.rgb.shape[2]} cfg: patch_size: 8 depth: 1 diff --git a/cfg/robomimic/finetune/transport/ft_ppo_diffusion_mlp_img.yaml b/cfg/robomimic/finetune/transport/ft_ppo_diffusion_mlp_img.yaml index 87f336a..19c2b72 100644 --- 
a/cfg/robomimic/finetune/transport/ft_ppo_diffusion_mlp_img.yaml +++ b/cfg/robomimic/finetune/transport/ft_ppo_diffusion_mlp_img.yaml @@ -126,7 +126,9 @@ model: backbone: _target_: model.common.vit.VitEncoder obs_shape: ${shape_meta.obs.rgb.shape} - num_channel: ${eval:'${shape_meta.obs.rgb.shape[0]} * ${img_cond_steps}'} # each image patch is history concatenated + num_channel: ${eval:'3 * ${img_cond_steps}'} # each image patch is history concatenated + img_h: ${shape_meta.obs.rgb.shape[1]} + img_w: ${shape_meta.obs.rgb.shape[2]} cfg: patch_size: 8 depth: 1 @@ -152,7 +154,9 @@ model: backbone: _target_: model.common.vit.VitEncoder obs_shape: ${shape_meta.obs.rgb.shape} - num_channel: ${eval:'${shape_meta.obs.rgb.shape[0]} * ${img_cond_steps}'} # each image patch is history concatenated + num_channel: ${eval:'3 * ${img_cond_steps}'} # each image patch is history concatenated + img_h: ${shape_meta.obs.rgb.shape[1]} + img_w: ${shape_meta.obs.rgb.shape[2]} cfg: patch_size: 8 depth: 1 diff --git a/cfg/robomimic/finetune/transport/ft_ppo_gaussian_mlp_img.yaml b/cfg/robomimic/finetune/transport/ft_ppo_gaussian_mlp_img.yaml index 9503c1d..dac3e6e 100644 --- a/cfg/robomimic/finetune/transport/ft_ppo_gaussian_mlp_img.yaml +++ b/cfg/robomimic/finetune/transport/ft_ppo_gaussian_mlp_img.yaml @@ -105,7 +105,9 @@ model: backbone: _target_: model.common.vit.VitEncoder obs_shape: ${shape_meta.obs.rgb.shape} - num_channel: ${eval:'${shape_meta.obs.rgb.shape[0]} * ${img_cond_steps}'} # each image patch is history concatenated + num_channel: ${eval:'3 * ${img_cond_steps}'} # each image patch is history concatenated + img_h: ${shape_meta.obs.rgb.shape[1]} + img_w: ${shape_meta.obs.rgb.shape[2]} cfg: patch_size: 8 depth: 1 @@ -134,7 +136,9 @@ model: backbone: _target_: model.common.vit.VitEncoder obs_shape: ${shape_meta.obs.rgb.shape} - num_channel: ${eval:'${shape_meta.obs.rgb.shape[0]} * ${img_cond_steps}'} # each image patch is history concatenated + num_channel: ${eval:'3 * 
${img_cond_steps}'} # each image patch is history concatenated + img_h: ${shape_meta.obs.rgb.shape[1]} + img_w: ${shape_meta.obs.rgb.shape[2]} cfg: patch_size: 8 depth: 1 diff --git a/model/common/critic.py b/model/common/critic.py index e425bd3..5552f37 100644 --- a/model/common/critic.py +++ b/model/common/critic.py @@ -122,7 +122,6 @@ class ViTCritic(CriticObs): cond_dim, img_cond_steps=1, spatial_emb=128, - patch_repr_dim=128, dropout=0, augment=False, num_img=1, @@ -136,8 +135,8 @@ class ViTCritic(CriticObs): self.img_cond_steps = img_cond_steps if num_img > 1: self.compress1 = SpatialEmb( - num_patch=121, # TODO: repr_dim // patch_repr_dim, - patch_dim=patch_repr_dim, + num_patch=self.backbone.num_patch, + patch_dim=self.backbone.patch_repr_dim, prop_dim=cond_dim, proj_dim=spatial_emb, dropout=dropout, @@ -145,8 +144,8 @@ class ViTCritic(CriticObs): self.compress2 = deepcopy(self.compress1) else: # TODO: clean up self.compress = SpatialEmb( - num_patch=121, - patch_dim=patch_repr_dim, + num_patch=self.backbone.num_patch, + patch_dim=self.backbone.patch_repr_dim, prop_dim=cond_dim, proj_dim=spatial_emb, dropout=dropout, diff --git a/model/common/mlp_gaussian.py b/model/common/mlp_gaussian.py index 73adb16..a60a1d8 100644 --- a/model/common/mlp_gaussian.py +++ b/model/common/mlp_gaussian.py @@ -32,8 +32,6 @@ class Gaussian_VisionMLP(nn.Module): std_max=1, spatial_emb=0, visual_feature_dim=128, - repr_dim=96 * 96, - patch_repr_dim=128, dropout=0, num_img=1, augment=False, @@ -51,8 +49,8 @@ class Gaussian_VisionMLP(nn.Module): assert spatial_emb > 1, "this is the dimension" if num_img > 1: self.compress1 = SpatialEmb( - num_patch=121, # TODO: repr_dim // patch_repr_dim, - patch_dim=patch_repr_dim, + num_patch=self.backbone.num_patch, + patch_dim=self.backbone.patch_repr_dim, prop_dim=cond_dim, proj_dim=spatial_emb, dropout=dropout, @@ -60,8 +58,8 @@ class Gaussian_VisionMLP(nn.Module): self.compress2 = deepcopy(self.compress1) else: # TODO: clean up self.compress 
= SpatialEmb( - num_patch=121, - patch_dim=patch_repr_dim, + num_patch=self.backbone.num_patch, + patch_dim=self.backbone.patch_repr_dim, prop_dim=cond_dim, proj_dim=spatial_emb, dropout=dropout, @@ -69,7 +67,7 @@ class Gaussian_VisionMLP(nn.Module): visual_feature_dim = spatial_emb * num_img else: self.compress = nn.Sequential( - nn.Linear(repr_dim, visual_feature_dim), + nn.Linear(self.backbone.repr_dim, visual_feature_dim), nn.LayerNorm(visual_feature_dim), nn.Dropout(dropout), nn.ReLU(), diff --git a/model/common/vit.py b/model/common/vit.py index ec6f2ea..1c0eae3 100644 --- a/model/common/vit.py +++ b/model/common/vit.py @@ -3,12 +3,13 @@ ViT image encoder implementation from IBRL, https://github.com/hengyuan-hu/ibrl """ -from dataclasses import dataclass, field +from dataclasses import dataclass from typing import List import einops import torch from torch import nn from torch.nn.init import trunc_normal_ +import math @dataclass @@ -29,6 +30,8 @@ class VitEncoder(nn.Module): obs_shape: List[int], cfg: VitEncoderConfig, num_channel=3, + img_h=96, + img_w=96, ): super().__init__() self.obs_shape = obs_shape @@ -40,8 +43,11 @@ class VitEncoder(nn.Module): num_head=cfg.num_heads, depth=cfg.depth, num_channel=num_channel, + img_h=img_h, + img_w=img_w, ) - + self.img_h = img_h + self.img_w = img_w self.num_patch = self.vit.num_patches self.patch_repr_dim = self.cfg.embed_dim self.repr_dim = self.cfg.embed_dim * self.vit.num_patches @@ -56,11 +62,11 @@ class VitEncoder(nn.Module): class PatchEmbed1(nn.Module): - def __init__(self, embed_dim, num_channel=3): + def __init__(self, embed_dim, num_channel=3, img_h=96, img_w=96): super().__init__() self.conv = nn.Conv2d(num_channel, embed_dim, kernel_size=8, stride=8) - self.num_patch = 144 + self.num_patch = math.ceil(img_h / 8) * math.ceil(img_w / 8) self.patch_dim = embed_dim def forward(self, x: torch.Tensor): @@ -70,7 +76,7 @@ class PatchEmbed1(nn.Module): class PatchEmbed2(nn.Module): - def __init__(self, embed_dim, 
use_norm, num_channel=3): + def __init__(self, embed_dim, use_norm, num_channel=3, img_h=96, img_w=96): super().__init__() layers = [ nn.Conv2d(num_channel, embed_dim, kernel_size=8, stride=4), @@ -80,7 +86,11 @@ class PatchEmbed2(nn.Module): ] self.embed = nn.Sequential(*layers) - self.num_patch = 121 # TODO: specifically for 96x96 set by Hengyuan? + H1 = math.ceil((img_h - 8) / 4) + 1 + W1 = math.ceil((img_w - 8) / 4) + 1 + H2 = math.ceil((H1 - 3) / 2) + 1 + W2 = math.ceil((W1 - 3) / 2) + 1 + self.num_patch = H2 * W2 self.patch_dim = embed_dim def forward(self, x: torch.Tensor): @@ -146,14 +156,25 @@ class MinVit(nn.Module): num_head, depth, num_channel=3, + img_h=96, + img_w=96, ): super().__init__() if embed_style == "embed1": - self.patch_embed = PatchEmbed1(embed_dim, num_channel=num_channel) + self.patch_embed = PatchEmbed1( + embed_dim, + num_channel=num_channel, + img_h=img_h, + img_w=img_w, + ) elif embed_style == "embed2": self.patch_embed = PatchEmbed2( - embed_dim, use_norm=embed_norm, num_channel=num_channel + embed_dim, + use_norm=embed_norm, + num_channel=num_channel, + img_h=img_h, + img_w=img_w, ) else: assert False @@ -233,8 +254,14 @@ def test_transformer_layer(): if __name__ == "__main__": - obs_shape = [6, 96, 96] - enc = VitEncoder([6, 96, 96], VitEncoderConfig()) + obs_shape = [6, 128, 128] + enc = VitEncoder( + obs_shape, + VitEncoderConfig(), + num_channel=obs_shape[0], + img_h=obs_shape[1], + img_w=obs_shape[2], + ) print(enc) x = torch.rand(1, *obs_shape) * 255 diff --git a/model/diffusion/mlp_diffusion.py b/model/diffusion/mlp_diffusion.py index 5b6174b..9bc940e 100644 --- a/model/diffusion/mlp_diffusion.py +++ b/model/diffusion/mlp_diffusion.py @@ -34,8 +34,6 @@ class VisionDiffusionMLP(nn.Module): residual_style=False, spatial_emb=0, visual_feature_dim=128, - repr_dim=96 * 96, - patch_repr_dim=128, dropout=0, num_img=1, augment=False, @@ -53,8 +51,8 @@ class VisionDiffusionMLP(nn.Module): assert spatial_emb > 1, "this is the 
dimension" if num_img > 1: self.compress1 = SpatialEmb( - num_patch=121, # TODO: repr_dim // patch_repr_dim, - patch_dim=patch_repr_dim, + num_patch=self.backbone.num_patch, + patch_dim=self.backbone.patch_repr_dim, prop_dim=cond_dim, proj_dim=spatial_emb, dropout=dropout, @@ -62,8 +60,8 @@ class VisionDiffusionMLP(nn.Module): self.compress2 = deepcopy(self.compress1) else: # TODO: clean up self.compress = SpatialEmb( - num_patch=121, - patch_dim=patch_repr_dim, + num_patch=self.backbone.num_patch, + patch_dim=self.backbone.patch_repr_dim, prop_dim=cond_dim, proj_dim=spatial_emb, dropout=dropout, @@ -71,7 +69,7 @@ class VisionDiffusionMLP(nn.Module): visual_feature_dim = spatial_emb * num_img else: self.compress = nn.Sequential( - nn.Linear(repr_dim, visual_feature_dim), + nn.Linear(self.backbone.repr_dim, visual_feature_dim), nn.LayerNorm(visual_feature_dim), nn.Dropout(dropout), nn.ReLU(),