support varying img size

This commit is contained in:
allenzren 2024-09-16 17:55:31 -04:00
parent 64595baca9
commit 1aaa6c2302
18 changed files with 131 additions and 81 deletions

View File

@ -159,7 +159,7 @@ To use DDIM fine-tuning, set `denoising_steps=100` in pre-training and set `mode
## Adding your own dataset/environment ## Adding your own dataset/environment
### Pre-training data ### Pre-training data
Pre-training script is at [`agent/pretrain/train_diffusion_agent.py`](agent/pretrain/train_diffusion_agent.py). The pre-training dataset [loader](agent/dataset/sequence.py) assumes an npz file containing numpy arrays `states`, `actions`, `images` (if using pixel) and `traj_lengths`, where `states` and `actions` have the shape of num_total_steps x obs_dim/act_dim, `images` has the shape of num_total_steps x C (concatenated if multiple images) x H x W, and `traj_lengths` is a 1-D array for indexing across num_total_steps. Pre-training script is at [`agent/pretrain/train_diffusion_agent.py`](agent/pretrain/train_diffusion_agent.py). The pre-training dataset [loader](agent/dataset/sequence.py) assumes an npz file containing numpy arrays `states`, `actions`, `images` (if using pixel; img_h = img_w and a multiple of 8) and `traj_lengths`, where `states` and `actions` have the shape of num_total_steps x obs_dim/act_dim, `images` has the shape of num_total_steps x C (concatenated if multiple images) x H x W, and `traj_lengths` is a 1-D array for indexing across num_total_steps.
<!-- One pre-processing example can be found at [`script/process_robomimic_dataset.py`](script/process_robomimic_dataset.py). --> <!-- One pre-processing example can be found at [`script/process_robomimic_dataset.py`](script/process_robomimic_dataset.py). -->
<!-- **Note:** The current implementation does not support loading history observations (only using observation at the current timestep). If needed, you can modify [here](agent/dataset/sequence.py#L130-L131). --> <!-- **Note:** The current implementation does not support loading history observations (only using observation at the current timestep). If needed, you can modify [here](agent/dataset/sequence.py#L130-L131). -->

View File

@ -244,15 +244,16 @@ class TrainPPODiffusionAgent(TrainPPOAgent):
.float() .float()
.to(self.device) .to(self.device)
} }
with torch.no_grad():
next_value = (
self.model.critic(obs_venv_ts).reshape(1, -1).cpu().numpy()
)
advantages_trajs = np.zeros_like(reward_trajs) advantages_trajs = np.zeros_like(reward_trajs)
lastgaelam = 0 lastgaelam = 0
for t in reversed(range(self.n_steps)): for t in reversed(range(self.n_steps)):
if t == self.n_steps - 1: if t == self.n_steps - 1:
nextvalues = next_value nextvalues = (
self.model.critic(obs_venv_ts)
.reshape(1, -1)
.cpu()
.numpy()
)
else: else:
nextvalues = values_trajs[t + 1] nextvalues = values_trajs[t + 1]
nonterminal = 1.0 - dones_trajs[t] nonterminal = 1.0 - dones_trajs[t]

View File

@ -240,18 +240,16 @@ class TrainPPOImgDiffusionAgent(TrainPPODiffusionAgent):
key: torch.from_numpy(obs_venv[key]).float().to(self.device) key: torch.from_numpy(obs_venv[key]).float().to(self.device)
for key in self.obs_dims for key in self.obs_dims
} }
with torch.no_grad():
next_value = (
self.model.critic(obs_venv_ts, no_augment=True)
.reshape(1, -1)
.cpu()
.numpy()
)
advantages_trajs = np.zeros_like(reward_trajs) advantages_trajs = np.zeros_like(reward_trajs)
lastgaelam = 0 lastgaelam = 0
for t in reversed(range(self.n_steps)): for t in reversed(range(self.n_steps)):
if t == self.n_steps - 1: if t == self.n_steps - 1:
nextvalues = next_value nextvalues = (
self.model.critic(obs_venv_ts, no_augment=True)
.reshape(1, -1)
.cpu()
.numpy()
)
else: else:
nextvalues = values_trajs[t + 1] nextvalues = values_trajs[t + 1]
nonterminal = 1.0 - dones_trajs[t] nonterminal = 1.0 - dones_trajs[t]

View File

@ -220,15 +220,16 @@ class TrainPPOExactDiffusionAgent(TrainPPODiffusionAgent):
.float() .float()
.to(self.device) .to(self.device)
} }
with torch.no_grad():
next_value = (
self.model.critic(obs_venv_ts).reshape(1, -1).cpu().numpy()
)
advantages_trajs = np.zeros_like(reward_trajs) advantages_trajs = np.zeros_like(reward_trajs)
lastgaelam = 0 lastgaelam = 0
for t in reversed(range(self.n_steps)): for t in reversed(range(self.n_steps)):
if t == self.n_steps - 1: if t == self.n_steps - 1:
nextvalues = next_value nextvalues = (
self.model.critic(obs_venv_ts)
.reshape(1, -1)
.cpu()
.numpy()
)
else: else:
nextvalues = values_trajs[t + 1] nextvalues = values_trajs[t + 1]
nonterminal = 1.0 - dones_trajs[t] nonterminal = 1.0 - dones_trajs[t]
@ -241,10 +242,7 @@ class TrainPPOExactDiffusionAgent(TrainPPODiffusionAgent):
# A = delta_t + gamma*lambda*delta_{t+1} + ... # A = delta_t + gamma*lambda*delta_{t+1} + ...
advantages_trajs[t] = lastgaelam = ( advantages_trajs[t] = lastgaelam = (
delta delta
+ self.gamma + self.gamma * self.gae_lambda * nonterminal * lastgaelam
* self.gae_lambda
* nonterminal
* lastgaelam
) )
returns_trajs = advantages_trajs + values_trajs returns_trajs = advantages_trajs + values_trajs

View File

@ -209,15 +209,16 @@ class TrainPPOGaussianAgent(TrainPPOAgent):
.float() .float()
.to(self.device) .to(self.device)
} }
with torch.no_grad():
next_value = (
self.model.critic(obs_venv_ts).reshape(1, -1).cpu().numpy()
)
advantages_trajs = np.zeros_like(reward_trajs) advantages_trajs = np.zeros_like(reward_trajs)
lastgaelam = 0 lastgaelam = 0
for t in reversed(range(self.n_steps)): for t in reversed(range(self.n_steps)):
if t == self.n_steps - 1: if t == self.n_steps - 1:
nextvalues = next_value nextvalues = (
self.model.critic(obs_venv_ts)
.reshape(1, -1)
.cpu()
.numpy()
)
else: else:
nextvalues = values_trajs[t + 1] nextvalues = values_trajs[t + 1]
nonterminal = 1.0 - dones_trajs[t] nonterminal = 1.0 - dones_trajs[t]

View File

@ -228,18 +228,16 @@ class TrainPPOImgGaussianAgent(TrainPPOGaussianAgent):
key: torch.from_numpy(obs_venv[key]).float().to(self.device) key: torch.from_numpy(obs_venv[key]).float().to(self.device)
for key in self.obs_dims for key in self.obs_dims
} }
with torch.no_grad():
next_value = (
self.model.critic(obs_venv_ts, no_augment=True)
.reshape(1, -1)
.cpu()
.numpy()
)
advantages_trajs = np.zeros_like(reward_trajs) advantages_trajs = np.zeros_like(reward_trajs)
lastgaelam = 0 lastgaelam = 0
for t in reversed(range(self.n_steps)): for t in reversed(range(self.n_steps)):
if t == self.n_steps - 1: if t == self.n_steps - 1:
nextvalues = next_value nextvalues = (
self.model.critic(obs_venv_ts, no_augment=True)
.reshape(1, -1)
.cpu()
.numpy()
)
else: else:
nextvalues = values_trajs[t + 1] nextvalues = values_trajs[t + 1]
nonterminal = 1.0 - dones_trajs[t] nonterminal = 1.0 - dones_trajs[t]

View File

@ -122,7 +122,9 @@ model:
backbone: backbone:
_target_: model.common.vit.VitEncoder _target_: model.common.vit.VitEncoder
obs_shape: ${shape_meta.obs.rgb.shape} obs_shape: ${shape_meta.obs.rgb.shape}
num_channel: ${eval:'${shape_meta.obs.rgb.shape[0]} * ${img_cond_steps}'} # each image patch is history concatenated num_channel: ${eval:'3 * ${img_cond_steps}'} # each image patch is history concatenated
img_h: ${shape_meta.obs.rgb.shape[1]}
img_w: ${shape_meta.obs.rgb.shape[2]}
cfg: cfg:
patch_size: 8 patch_size: 8
depth: 1 depth: 1
@ -146,7 +148,9 @@ model:
backbone: backbone:
_target_: model.common.vit.VitEncoder _target_: model.common.vit.VitEncoder
obs_shape: ${shape_meta.obs.rgb.shape} obs_shape: ${shape_meta.obs.rgb.shape}
num_channel: ${eval:'${shape_meta.obs.rgb.shape[0]} * ${img_cond_steps}'} # each image patch is history concatenated num_channel: ${eval:'3 * ${img_cond_steps}'} # each image patch is history concatenated
img_h: ${shape_meta.obs.rgb.shape[1]}
img_w: ${shape_meta.obs.rgb.shape[2]}
cfg: cfg:
patch_size: 8 patch_size: 8
depth: 1 depth: 1

View File

@ -101,7 +101,9 @@ model:
backbone: backbone:
_target_: model.common.vit.VitEncoder _target_: model.common.vit.VitEncoder
obs_shape: ${shape_meta.obs.rgb.shape} obs_shape: ${shape_meta.obs.rgb.shape}
num_channel: ${eval:'${shape_meta.obs.rgb.shape[0]} * ${img_cond_steps}'} # each image patch is history concatenated num_channel: ${eval:'3 * ${img_cond_steps}'} # each image patch is history concatenated
img_h: ${shape_meta.obs.rgb.shape[1]}
img_w: ${shape_meta.obs.rgb.shape[2]}
cfg: cfg:
patch_size: 8 patch_size: 8
depth: 1 depth: 1
@ -128,7 +130,9 @@ model:
backbone: backbone:
_target_: model.common.vit.VitEncoder _target_: model.common.vit.VitEncoder
obs_shape: ${shape_meta.obs.rgb.shape} obs_shape: ${shape_meta.obs.rgb.shape}
num_channel: ${eval:'${shape_meta.obs.rgb.shape[0]} * ${img_cond_steps}'} # each image patch is history concatenated num_channel: ${eval:'3 * ${img_cond_steps}'} # each image patch is history concatenated
img_h: ${shape_meta.obs.rgb.shape[1]}
img_w: ${shape_meta.obs.rgb.shape[2]}
cfg: cfg:
patch_size: 8 patch_size: 8
depth: 1 depth: 1

View File

@ -122,7 +122,9 @@ model:
backbone: backbone:
_target_: model.common.vit.VitEncoder _target_: model.common.vit.VitEncoder
obs_shape: ${shape_meta.obs.rgb.shape} obs_shape: ${shape_meta.obs.rgb.shape}
num_channel: ${eval:'${shape_meta.obs.rgb.shape[0]} * ${img_cond_steps}'} # each image patch is history concatenated num_channel: ${eval:'3 * ${img_cond_steps}'} # each image patch is history concatenated
img_h: ${shape_meta.obs.rgb.shape[1]}
img_w: ${shape_meta.obs.rgb.shape[2]}
cfg: cfg:
patch_size: 8 patch_size: 8
depth: 1 depth: 1
@ -146,7 +148,9 @@ model:
backbone: backbone:
_target_: model.common.vit.VitEncoder _target_: model.common.vit.VitEncoder
obs_shape: ${shape_meta.obs.rgb.shape} obs_shape: ${shape_meta.obs.rgb.shape}
num_channel: ${eval:'${shape_meta.obs.rgb.shape[0]} * ${img_cond_steps}'} # each image patch is history concatenated num_channel: ${eval:'3 * ${img_cond_steps}'} # each image patch is history concatenated
img_h: ${shape_meta.obs.rgb.shape[1]}
img_w: ${shape_meta.obs.rgb.shape[2]}
cfg: cfg:
patch_size: 8 patch_size: 8
depth: 1 depth: 1

View File

@ -101,7 +101,9 @@ model:
backbone: backbone:
_target_: model.common.vit.VitEncoder _target_: model.common.vit.VitEncoder
obs_shape: ${shape_meta.obs.rgb.shape} obs_shape: ${shape_meta.obs.rgb.shape}
num_channel: ${eval:'${shape_meta.obs.rgb.shape[0]} * ${img_cond_steps}'} # each image patch is history concatenated num_channel: ${eval:'3 * ${img_cond_steps}'} # each image patch is history concatenated
img_h: ${shape_meta.obs.rgb.shape[1]}
img_w: ${shape_meta.obs.rgb.shape[2]}
cfg: cfg:
patch_size: 8 patch_size: 8
depth: 1 depth: 1
@ -128,7 +130,9 @@ model:
backbone: backbone:
_target_: model.common.vit.VitEncoder _target_: model.common.vit.VitEncoder
obs_shape: ${shape_meta.obs.rgb.shape} obs_shape: ${shape_meta.obs.rgb.shape}
num_channel: ${eval:'${shape_meta.obs.rgb.shape[0]} * ${img_cond_steps}'} # each image patch is history concatenated num_channel: ${eval:'3 * ${img_cond_steps}'} # each image patch is history concatenated
img_h: ${shape_meta.obs.rgb.shape[1]}
img_w: ${shape_meta.obs.rgb.shape[2]}
cfg: cfg:
patch_size: 8 patch_size: 8
depth: 1 depth: 1

View File

@ -122,7 +122,9 @@ model:
backbone: backbone:
_target_: model.common.vit.VitEncoder _target_: model.common.vit.VitEncoder
obs_shape: ${shape_meta.obs.rgb.shape} obs_shape: ${shape_meta.obs.rgb.shape}
num_channel: ${eval:'${shape_meta.obs.rgb.shape[0]} * ${img_cond_steps}'} # each image patch is history concatenated num_channel: ${eval:'3 * ${img_cond_steps}'} # each image patch is history concatenated
img_h: ${shape_meta.obs.rgb.shape[1]}
img_w: ${shape_meta.obs.rgb.shape[2]}
cfg: cfg:
patch_size: 8 patch_size: 8
depth: 1 depth: 1
@ -146,7 +148,9 @@ model:
backbone: backbone:
_target_: model.common.vit.VitEncoder _target_: model.common.vit.VitEncoder
obs_shape: ${shape_meta.obs.rgb.shape} obs_shape: ${shape_meta.obs.rgb.shape}
num_channel: ${eval:'${shape_meta.obs.rgb.shape[0]} * ${img_cond_steps}'} # each image patch is history concatenated num_channel: ${eval:'3 * ${img_cond_steps}'} # each image patch is history concatenated
img_h: ${shape_meta.obs.rgb.shape[1]}
img_w: ${shape_meta.obs.rgb.shape[2]}
cfg: cfg:
patch_size: 8 patch_size: 8
depth: 1 depth: 1

View File

@ -101,7 +101,9 @@ model:
backbone: backbone:
_target_: model.common.vit.VitEncoder _target_: model.common.vit.VitEncoder
obs_shape: ${shape_meta.obs.rgb.shape} obs_shape: ${shape_meta.obs.rgb.shape}
num_channel: ${eval:'${shape_meta.obs.rgb.shape[0]} * ${img_cond_steps}'} # each image patch is history concatenated num_channel: ${eval:'3 * ${img_cond_steps}'} # each image patch is history concatenated
img_h: ${shape_meta.obs.rgb.shape[1]}
img_w: ${shape_meta.obs.rgb.shape[2]}
cfg: cfg:
patch_size: 8 patch_size: 8
depth: 1 depth: 1
@ -128,7 +130,9 @@ model:
backbone: backbone:
_target_: model.common.vit.VitEncoder _target_: model.common.vit.VitEncoder
obs_shape: ${shape_meta.obs.rgb.shape} obs_shape: ${shape_meta.obs.rgb.shape}
num_channel: ${eval:'${shape_meta.obs.rgb.shape[0]} * ${img_cond_steps}'} # each image patch is history concatenated num_channel: ${eval:'3 * ${img_cond_steps}'} # each image patch is history concatenated
img_h: ${shape_meta.obs.rgb.shape[1]}
img_w: ${shape_meta.obs.rgb.shape[2]}
cfg: cfg:
patch_size: 8 patch_size: 8
depth: 1 depth: 1

View File

@ -126,7 +126,9 @@ model:
backbone: backbone:
_target_: model.common.vit.VitEncoder _target_: model.common.vit.VitEncoder
obs_shape: ${shape_meta.obs.rgb.shape} obs_shape: ${shape_meta.obs.rgb.shape}
num_channel: ${eval:'${shape_meta.obs.rgb.shape[0]} * ${img_cond_steps}'} # each image patch is history concatenated num_channel: ${eval:'3 * ${img_cond_steps}'} # each image patch is history concatenated
img_h: ${shape_meta.obs.rgb.shape[1]}
img_w: ${shape_meta.obs.rgb.shape[2]}
cfg: cfg:
patch_size: 8 patch_size: 8
depth: 1 depth: 1
@ -152,7 +154,9 @@ model:
backbone: backbone:
_target_: model.common.vit.VitEncoder _target_: model.common.vit.VitEncoder
obs_shape: ${shape_meta.obs.rgb.shape} obs_shape: ${shape_meta.obs.rgb.shape}
num_channel: ${eval:'${shape_meta.obs.rgb.shape[0]} * ${img_cond_steps}'} # each image patch is history concatenated num_channel: ${eval:'3 * ${img_cond_steps}'} # each image patch is history concatenated
img_h: ${shape_meta.obs.rgb.shape[1]}
img_w: ${shape_meta.obs.rgb.shape[2]}
cfg: cfg:
patch_size: 8 patch_size: 8
depth: 1 depth: 1

View File

@ -105,7 +105,9 @@ model:
backbone: backbone:
_target_: model.common.vit.VitEncoder _target_: model.common.vit.VitEncoder
obs_shape: ${shape_meta.obs.rgb.shape} obs_shape: ${shape_meta.obs.rgb.shape}
num_channel: ${eval:'${shape_meta.obs.rgb.shape[0]} * ${img_cond_steps}'} # each image patch is history concatenated num_channel: ${eval:'3 * ${img_cond_steps}'} # each image patch is history concatenated
img_h: ${shape_meta.obs.rgb.shape[1]}
img_w: ${shape_meta.obs.rgb.shape[2]}
cfg: cfg:
patch_size: 8 patch_size: 8
depth: 1 depth: 1
@ -134,7 +136,9 @@ model:
backbone: backbone:
_target_: model.common.vit.VitEncoder _target_: model.common.vit.VitEncoder
obs_shape: ${shape_meta.obs.rgb.shape} obs_shape: ${shape_meta.obs.rgb.shape}
num_channel: ${eval:'${shape_meta.obs.rgb.shape[0]} * ${img_cond_steps}'} # each image patch is history concatenated num_channel: ${eval:'3 * ${img_cond_steps}'} # each image patch is history concatenated
img_h: ${shape_meta.obs.rgb.shape[1]}
img_w: ${shape_meta.obs.rgb.shape[2]}
cfg: cfg:
patch_size: 8 patch_size: 8
depth: 1 depth: 1

View File

@ -122,7 +122,6 @@ class ViTCritic(CriticObs):
cond_dim, cond_dim,
img_cond_steps=1, img_cond_steps=1,
spatial_emb=128, spatial_emb=128,
patch_repr_dim=128,
dropout=0, dropout=0,
augment=False, augment=False,
num_img=1, num_img=1,
@ -136,8 +135,8 @@ class ViTCritic(CriticObs):
self.img_cond_steps = img_cond_steps self.img_cond_steps = img_cond_steps
if num_img > 1: if num_img > 1:
self.compress1 = SpatialEmb( self.compress1 = SpatialEmb(
num_patch=121, # TODO: repr_dim // patch_repr_dim, num_patch=self.backbone.num_patch,
patch_dim=patch_repr_dim, patch_dim=self.backbone.patch_repr_dim,
prop_dim=cond_dim, prop_dim=cond_dim,
proj_dim=spatial_emb, proj_dim=spatial_emb,
dropout=dropout, dropout=dropout,
@ -145,8 +144,8 @@ class ViTCritic(CriticObs):
self.compress2 = deepcopy(self.compress1) self.compress2 = deepcopy(self.compress1)
else: # TODO: clean up else: # TODO: clean up
self.compress = SpatialEmb( self.compress = SpatialEmb(
num_patch=121, num_patch=self.backbone.num_patch,
patch_dim=patch_repr_dim, patch_dim=self.backbone.patch_repr_dim,
prop_dim=cond_dim, prop_dim=cond_dim,
proj_dim=spatial_emb, proj_dim=spatial_emb,
dropout=dropout, dropout=dropout,

View File

@ -32,8 +32,6 @@ class Gaussian_VisionMLP(nn.Module):
std_max=1, std_max=1,
spatial_emb=0, spatial_emb=0,
visual_feature_dim=128, visual_feature_dim=128,
repr_dim=96 * 96,
patch_repr_dim=128,
dropout=0, dropout=0,
num_img=1, num_img=1,
augment=False, augment=False,
@ -51,8 +49,8 @@ class Gaussian_VisionMLP(nn.Module):
assert spatial_emb > 1, "this is the dimension" assert spatial_emb > 1, "this is the dimension"
if num_img > 1: if num_img > 1:
self.compress1 = SpatialEmb( self.compress1 = SpatialEmb(
num_patch=121, # TODO: repr_dim // patch_repr_dim, num_patch=self.backbone.num_patch,
patch_dim=patch_repr_dim, patch_dim=self.backbone.patch_repr_dim,
prop_dim=cond_dim, prop_dim=cond_dim,
proj_dim=spatial_emb, proj_dim=spatial_emb,
dropout=dropout, dropout=dropout,
@ -60,8 +58,8 @@ class Gaussian_VisionMLP(nn.Module):
self.compress2 = deepcopy(self.compress1) self.compress2 = deepcopy(self.compress1)
else: # TODO: clean up else: # TODO: clean up
self.compress = SpatialEmb( self.compress = SpatialEmb(
num_patch=121, num_patch=self.backbone.num_patch,
patch_dim=patch_repr_dim, patch_dim=self.backbone.patch_repr_dim,
prop_dim=cond_dim, prop_dim=cond_dim,
proj_dim=spatial_emb, proj_dim=spatial_emb,
dropout=dropout, dropout=dropout,
@ -69,7 +67,7 @@ class Gaussian_VisionMLP(nn.Module):
visual_feature_dim = spatial_emb * num_img visual_feature_dim = spatial_emb * num_img
else: else:
self.compress = nn.Sequential( self.compress = nn.Sequential(
nn.Linear(repr_dim, visual_feature_dim), nn.Linear(self.backbone.repr_dim, visual_feature_dim),
nn.LayerNorm(visual_feature_dim), nn.LayerNorm(visual_feature_dim),
nn.Dropout(dropout), nn.Dropout(dropout),
nn.ReLU(), nn.ReLU(),

View File

@ -3,12 +3,13 @@ ViT image encoder implementation from IBRL, https://github.com/hengyuan-hu/ibrl
""" """
from dataclasses import dataclass, field from dataclasses import dataclass
from typing import List from typing import List
import einops import einops
import torch import torch
from torch import nn from torch import nn
from torch.nn.init import trunc_normal_ from torch.nn.init import trunc_normal_
import math
@dataclass @dataclass
@ -29,6 +30,8 @@ class VitEncoder(nn.Module):
obs_shape: List[int], obs_shape: List[int],
cfg: VitEncoderConfig, cfg: VitEncoderConfig,
num_channel=3, num_channel=3,
img_h=96,
img_w=96,
): ):
super().__init__() super().__init__()
self.obs_shape = obs_shape self.obs_shape = obs_shape
@ -40,8 +43,11 @@ class VitEncoder(nn.Module):
num_head=cfg.num_heads, num_head=cfg.num_heads,
depth=cfg.depth, depth=cfg.depth,
num_channel=num_channel, num_channel=num_channel,
img_h=img_h,
img_w=img_w,
) )
self.img_h = img_h
self.img_w = img_w
self.num_patch = self.vit.num_patches self.num_patch = self.vit.num_patches
self.patch_repr_dim = self.cfg.embed_dim self.patch_repr_dim = self.cfg.embed_dim
self.repr_dim = self.cfg.embed_dim * self.vit.num_patches self.repr_dim = self.cfg.embed_dim * self.vit.num_patches
@ -56,11 +62,11 @@ class VitEncoder(nn.Module):
class PatchEmbed1(nn.Module): class PatchEmbed1(nn.Module):
def __init__(self, embed_dim, num_channel=3): def __init__(self, embed_dim, num_channel=3, img_h=96, img_w=96):
super().__init__() super().__init__()
self.conv = nn.Conv2d(num_channel, embed_dim, kernel_size=8, stride=8) self.conv = nn.Conv2d(num_channel, embed_dim, kernel_size=8, stride=8)
self.num_patch = 144 self.num_patch = math.ceil(img_h / 8) * math.ceil(img_w / 8)
self.patch_dim = embed_dim self.patch_dim = embed_dim
def forward(self, x: torch.Tensor): def forward(self, x: torch.Tensor):
@ -70,7 +76,7 @@ class PatchEmbed1(nn.Module):
class PatchEmbed2(nn.Module): class PatchEmbed2(nn.Module):
def __init__(self, embed_dim, use_norm, num_channel=3): def __init__(self, embed_dim, use_norm, num_channel=3, img_h=96, img_w=96):
super().__init__() super().__init__()
layers = [ layers = [
nn.Conv2d(num_channel, embed_dim, kernel_size=8, stride=4), nn.Conv2d(num_channel, embed_dim, kernel_size=8, stride=4),
@ -80,7 +86,11 @@ class PatchEmbed2(nn.Module):
] ]
self.embed = nn.Sequential(*layers) self.embed = nn.Sequential(*layers)
self.num_patch = 121 # TODO: specifically for 96x96 set by Hengyuan? H1 = math.ceil((img_h - 8) / 4) + 1
W1 = math.ceil((img_w - 8) / 4) + 1
H2 = math.ceil((H1 - 3) / 2) + 1
W2 = math.ceil((W1 - 3) / 2) + 1
self.num_patch = H2 * W2
self.patch_dim = embed_dim self.patch_dim = embed_dim
def forward(self, x: torch.Tensor): def forward(self, x: torch.Tensor):
@ -146,14 +156,25 @@ class MinVit(nn.Module):
num_head, num_head,
depth, depth,
num_channel=3, num_channel=3,
img_h=96,
img_w=96,
): ):
super().__init__() super().__init__()
if embed_style == "embed1": if embed_style == "embed1":
self.patch_embed = PatchEmbed1(embed_dim, num_channel=num_channel) self.patch_embed = PatchEmbed1(
embed_dim,
num_channel=num_channel,
img_h=img_h,
img_w=img_w,
)
elif embed_style == "embed2": elif embed_style == "embed2":
self.patch_embed = PatchEmbed2( self.patch_embed = PatchEmbed2(
embed_dim, use_norm=embed_norm, num_channel=num_channel embed_dim,
use_norm=embed_norm,
num_channel=num_channel,
img_h=img_h,
img_w=img_w,
) )
else: else:
assert False assert False
@ -233,8 +254,14 @@ def test_transformer_layer():
if __name__ == "__main__": if __name__ == "__main__":
obs_shape = [6, 96, 96] obs_shape = [6, 128, 128]
enc = VitEncoder([6, 96, 96], VitEncoderConfig()) enc = VitEncoder(
obs_shape,
VitEncoderConfig(),
num_channel=obs_shape[0],
img_h=obs_shape[1],
img_w=obs_shape[2],
)
print(enc) print(enc)
x = torch.rand(1, *obs_shape) * 255 x = torch.rand(1, *obs_shape) * 255

View File

@ -34,8 +34,6 @@ class VisionDiffusionMLP(nn.Module):
residual_style=False, residual_style=False,
spatial_emb=0, spatial_emb=0,
visual_feature_dim=128, visual_feature_dim=128,
repr_dim=96 * 96,
patch_repr_dim=128,
dropout=0, dropout=0,
num_img=1, num_img=1,
augment=False, augment=False,
@ -53,8 +51,8 @@ class VisionDiffusionMLP(nn.Module):
assert spatial_emb > 1, "this is the dimension" assert spatial_emb > 1, "this is the dimension"
if num_img > 1: if num_img > 1:
self.compress1 = SpatialEmb( self.compress1 = SpatialEmb(
num_patch=121, # TODO: repr_dim // patch_repr_dim, num_patch=self.backbone.num_patch,
patch_dim=patch_repr_dim, patch_dim=self.backbone.patch_repr_dim,
prop_dim=cond_dim, prop_dim=cond_dim,
proj_dim=spatial_emb, proj_dim=spatial_emb,
dropout=dropout, dropout=dropout,
@ -62,8 +60,8 @@ class VisionDiffusionMLP(nn.Module):
self.compress2 = deepcopy(self.compress1) self.compress2 = deepcopy(self.compress1)
else: # TODO: clean up else: # TODO: clean up
self.compress = SpatialEmb( self.compress = SpatialEmb(
num_patch=121, num_patch=self.backbone.num_patch,
patch_dim=patch_repr_dim, patch_dim=self.backbone.patch_repr_dim,
prop_dim=cond_dim, prop_dim=cond_dim,
proj_dim=spatial_emb, proj_dim=spatial_emb,
dropout=dropout, dropout=dropout,
@ -71,7 +69,7 @@ class VisionDiffusionMLP(nn.Module):
visual_feature_dim = spatial_emb * num_img visual_feature_dim = spatial_emb * num_img
else: else:
self.compress = nn.Sequential( self.compress = nn.Sequential(
nn.Linear(repr_dim, visual_feature_dim), nn.Linear(self.backbone.repr_dim, visual_feature_dim),
nn.LayerNorm(visual_feature_dim), nn.LayerNorm(visual_feature_dim),
nn.Dropout(dropout), nn.Dropout(dropout),
nn.ReLU(), nn.ReLU(),