support varying img size
This commit is contained in:
parent
64595baca9
commit
1aaa6c2302
@ -159,7 +159,7 @@ To use DDIM fine-tuning, set `denoising_steps=100` in pre-training and set `mode
|
||||
## Adding your own dataset/environment
|
||||
|
||||
### Pre-training data
|
||||
Pre-training script is at [`agent/pretrain/train_diffusion_agent.py`](agent/pretrain/train_diffusion_agent.py). The pre-training dataset [loader](agent/dataset/sequence.py) assumes a npz file containing numpy arrays `states`, `actions`, `images` (if using pixel) and `traj_lengths`, where `states` and `actions` have the shape of num_total_steps x obs_dim/act_dim, `images` num_total_steps x C (concatenated if multiple images) x H x W, and `traj_lengths` is a 1-D array for indexing across num_total_steps.
|
||||
The pre-training script is at [`agent/pretrain/train_diffusion_agent.py`](agent/pretrain/train_diffusion_agent.py). The pre-training dataset [loader](agent/dataset/sequence.py) assumes an npz file containing numpy arrays `states`, `actions`, `images` (if using pixels; img_h = img_w and a multiple of 8) and `traj_lengths`, where `states` and `actions` have the shape num_total_steps x obs_dim/act_dim, `images` has the shape num_total_steps x C (concatenated if multiple images) x H x W, and `traj_lengths` is a 1-D array for indexing across num_total_steps.
|
||||
<!-- One pre-processing example can be found at [`script/process_robomimic_dataset.py`](script/process_robomimic_dataset.py). -->
|
||||
|
||||
<!-- **Note:** The current implementation does not support loading history observations (only using observation at the current timestep). If needed, you can modify [here](agent/dataset/sequence.py#L130-L131). -->
|
||||
|
@ -244,15 +244,16 @@ class TrainPPODiffusionAgent(TrainPPOAgent):
|
||||
.float()
|
||||
.to(self.device)
|
||||
}
|
||||
with torch.no_grad():
|
||||
next_value = (
|
||||
self.model.critic(obs_venv_ts).reshape(1, -1).cpu().numpy()
|
||||
)
|
||||
advantages_trajs = np.zeros_like(reward_trajs)
|
||||
lastgaelam = 0
|
||||
for t in reversed(range(self.n_steps)):
|
||||
if t == self.n_steps - 1:
|
||||
nextvalues = next_value
|
||||
nextvalues = (
|
||||
self.model.critic(obs_venv_ts)
|
||||
.reshape(1, -1)
|
||||
.cpu()
|
||||
.numpy()
|
||||
)
|
||||
else:
|
||||
nextvalues = values_trajs[t + 1]
|
||||
nonterminal = 1.0 - dones_trajs[t]
|
||||
|
@ -240,18 +240,16 @@ class TrainPPOImgDiffusionAgent(TrainPPODiffusionAgent):
|
||||
key: torch.from_numpy(obs_venv[key]).float().to(self.device)
|
||||
for key in self.obs_dims
|
||||
}
|
||||
with torch.no_grad():
|
||||
next_value = (
|
||||
self.model.critic(obs_venv_ts, no_augment=True)
|
||||
.reshape(1, -1)
|
||||
.cpu()
|
||||
.numpy()
|
||||
)
|
||||
advantages_trajs = np.zeros_like(reward_trajs)
|
||||
lastgaelam = 0
|
||||
for t in reversed(range(self.n_steps)):
|
||||
if t == self.n_steps - 1:
|
||||
nextvalues = next_value
|
||||
nextvalues = (
|
||||
self.model.critic(obs_venv_ts, no_augment=True)
|
||||
.reshape(1, -1)
|
||||
.cpu()
|
||||
.numpy()
|
||||
)
|
||||
else:
|
||||
nextvalues = values_trajs[t + 1]
|
||||
nonterminal = 1.0 - dones_trajs[t]
|
||||
|
@ -220,15 +220,16 @@ class TrainPPOExactDiffusionAgent(TrainPPODiffusionAgent):
|
||||
.float()
|
||||
.to(self.device)
|
||||
}
|
||||
with torch.no_grad():
|
||||
next_value = (
|
||||
self.model.critic(obs_venv_ts).reshape(1, -1).cpu().numpy()
|
||||
)
|
||||
advantages_trajs = np.zeros_like(reward_trajs)
|
||||
lastgaelam = 0
|
||||
for t in reversed(range(self.n_steps)):
|
||||
if t == self.n_steps - 1:
|
||||
nextvalues = next_value
|
||||
nextvalues = (
|
||||
self.model.critic(obs_venv_ts)
|
||||
.reshape(1, -1)
|
||||
.cpu()
|
||||
.numpy()
|
||||
)
|
||||
else:
|
||||
nextvalues = values_trajs[t + 1]
|
||||
nonterminal = 1.0 - dones_trajs[t]
|
||||
@ -241,10 +242,7 @@ class TrainPPOExactDiffusionAgent(TrainPPODiffusionAgent):
|
||||
# A = delta_t + gamma*lambda*delta_{t+1} + ...
|
||||
advantages_trajs[t] = lastgaelam = (
|
||||
delta
|
||||
+ self.gamma
|
||||
* self.gae_lambda
|
||||
* nonterminal
|
||||
* lastgaelam
|
||||
+ self.gamma * self.gae_lambda * nonterminal * lastgaelam
|
||||
)
|
||||
returns_trajs = advantages_trajs + values_trajs
|
||||
|
||||
|
@ -209,15 +209,16 @@ class TrainPPOGaussianAgent(TrainPPOAgent):
|
||||
.float()
|
||||
.to(self.device)
|
||||
}
|
||||
with torch.no_grad():
|
||||
next_value = (
|
||||
self.model.critic(obs_venv_ts).reshape(1, -1).cpu().numpy()
|
||||
)
|
||||
advantages_trajs = np.zeros_like(reward_trajs)
|
||||
lastgaelam = 0
|
||||
for t in reversed(range(self.n_steps)):
|
||||
if t == self.n_steps - 1:
|
||||
nextvalues = next_value
|
||||
nextvalues = (
|
||||
self.model.critic(obs_venv_ts)
|
||||
.reshape(1, -1)
|
||||
.cpu()
|
||||
.numpy()
|
||||
)
|
||||
else:
|
||||
nextvalues = values_trajs[t + 1]
|
||||
nonterminal = 1.0 - dones_trajs[t]
|
||||
|
@ -228,18 +228,16 @@ class TrainPPOImgGaussianAgent(TrainPPOGaussianAgent):
|
||||
key: torch.from_numpy(obs_venv[key]).float().to(self.device)
|
||||
for key in self.obs_dims
|
||||
}
|
||||
with torch.no_grad():
|
||||
next_value = (
|
||||
self.model.critic(obs_venv_ts, no_augment=True)
|
||||
.reshape(1, -1)
|
||||
.cpu()
|
||||
.numpy()
|
||||
)
|
||||
advantages_trajs = np.zeros_like(reward_trajs)
|
||||
lastgaelam = 0
|
||||
for t in reversed(range(self.n_steps)):
|
||||
if t == self.n_steps - 1:
|
||||
nextvalues = next_value
|
||||
nextvalues = (
|
||||
self.model.critic(obs_venv_ts, no_augment=True)
|
||||
.reshape(1, -1)
|
||||
.cpu()
|
||||
.numpy()
|
||||
)
|
||||
else:
|
||||
nextvalues = values_trajs[t + 1]
|
||||
nonterminal = 1.0 - dones_trajs[t]
|
||||
|
@ -122,7 +122,9 @@ model:
|
||||
backbone:
|
||||
_target_: model.common.vit.VitEncoder
|
||||
obs_shape: ${shape_meta.obs.rgb.shape}
|
||||
num_channel: ${eval:'${shape_meta.obs.rgb.shape[0]} * ${img_cond_steps}'} # each image patch is history concatenated
|
||||
num_channel: ${eval:'3 * ${img_cond_steps}'} # each image patch is history concatenated
|
||||
img_h: ${shape_meta.obs.rgb.shape[1]}
|
||||
img_w: ${shape_meta.obs.rgb.shape[2]}
|
||||
cfg:
|
||||
patch_size: 8
|
||||
depth: 1
|
||||
@ -146,7 +148,9 @@ model:
|
||||
backbone:
|
||||
_target_: model.common.vit.VitEncoder
|
||||
obs_shape: ${shape_meta.obs.rgb.shape}
|
||||
num_channel: ${eval:'${shape_meta.obs.rgb.shape[0]} * ${img_cond_steps}'} # each image patch is history concatenated
|
||||
num_channel: ${eval:'3 * ${img_cond_steps}'} # each image patch is history concatenated
|
||||
img_h: ${shape_meta.obs.rgb.shape[1]}
|
||||
img_w: ${shape_meta.obs.rgb.shape[2]}
|
||||
cfg:
|
||||
patch_size: 8
|
||||
depth: 1
|
||||
|
@ -101,7 +101,9 @@ model:
|
||||
backbone:
|
||||
_target_: model.common.vit.VitEncoder
|
||||
obs_shape: ${shape_meta.obs.rgb.shape}
|
||||
num_channel: ${eval:'${shape_meta.obs.rgb.shape[0]} * ${img_cond_steps}'} # each image patch is history concatenated
|
||||
num_channel: ${eval:'3 * ${img_cond_steps}'} # each image patch is history concatenated
|
||||
img_h: ${shape_meta.obs.rgb.shape[1]}
|
||||
img_w: ${shape_meta.obs.rgb.shape[2]}
|
||||
cfg:
|
||||
patch_size: 8
|
||||
depth: 1
|
||||
@ -128,7 +130,9 @@ model:
|
||||
backbone:
|
||||
_target_: model.common.vit.VitEncoder
|
||||
obs_shape: ${shape_meta.obs.rgb.shape}
|
||||
num_channel: ${eval:'${shape_meta.obs.rgb.shape[0]} * ${img_cond_steps}'} # each image patch is history concatenated
|
||||
num_channel: ${eval:'3 * ${img_cond_steps}'} # each image patch is history concatenated
|
||||
img_h: ${shape_meta.obs.rgb.shape[1]}
|
||||
img_w: ${shape_meta.obs.rgb.shape[2]}
|
||||
cfg:
|
||||
patch_size: 8
|
||||
depth: 1
|
||||
|
@ -122,7 +122,9 @@ model:
|
||||
backbone:
|
||||
_target_: model.common.vit.VitEncoder
|
||||
obs_shape: ${shape_meta.obs.rgb.shape}
|
||||
num_channel: ${eval:'${shape_meta.obs.rgb.shape[0]} * ${img_cond_steps}'} # each image patch is history concatenated
|
||||
num_channel: ${eval:'3 * ${img_cond_steps}'} # each image patch is history concatenated
|
||||
img_h: ${shape_meta.obs.rgb.shape[1]}
|
||||
img_w: ${shape_meta.obs.rgb.shape[2]}
|
||||
cfg:
|
||||
patch_size: 8
|
||||
depth: 1
|
||||
@ -146,7 +148,9 @@ model:
|
||||
backbone:
|
||||
_target_: model.common.vit.VitEncoder
|
||||
obs_shape: ${shape_meta.obs.rgb.shape}
|
||||
num_channel: ${eval:'${shape_meta.obs.rgb.shape[0]} * ${img_cond_steps}'} # each image patch is history concatenated
|
||||
num_channel: ${eval:'3 * ${img_cond_steps}'} # each image patch is history concatenated
|
||||
img_h: ${shape_meta.obs.rgb.shape[1]}
|
||||
img_w: ${shape_meta.obs.rgb.shape[2]}
|
||||
cfg:
|
||||
patch_size: 8
|
||||
depth: 1
|
||||
|
@ -101,7 +101,9 @@ model:
|
||||
backbone:
|
||||
_target_: model.common.vit.VitEncoder
|
||||
obs_shape: ${shape_meta.obs.rgb.shape}
|
||||
num_channel: ${eval:'${shape_meta.obs.rgb.shape[0]} * ${img_cond_steps}'} # each image patch is history concatenated
|
||||
num_channel: ${eval:'3 * ${img_cond_steps}'} # each image patch is history concatenated
|
||||
img_h: ${shape_meta.obs.rgb.shape[1]}
|
||||
img_w: ${shape_meta.obs.rgb.shape[2]}
|
||||
cfg:
|
||||
patch_size: 8
|
||||
depth: 1
|
||||
@ -128,7 +130,9 @@ model:
|
||||
backbone:
|
||||
_target_: model.common.vit.VitEncoder
|
||||
obs_shape: ${shape_meta.obs.rgb.shape}
|
||||
num_channel: ${eval:'${shape_meta.obs.rgb.shape[0]} * ${img_cond_steps}'} # each image patch is history concatenated
|
||||
num_channel: ${eval:'3 * ${img_cond_steps}'} # each image patch is history concatenated
|
||||
img_h: ${shape_meta.obs.rgb.shape[1]}
|
||||
img_w: ${shape_meta.obs.rgb.shape[2]}
|
||||
cfg:
|
||||
patch_size: 8
|
||||
depth: 1
|
||||
|
@ -122,7 +122,9 @@ model:
|
||||
backbone:
|
||||
_target_: model.common.vit.VitEncoder
|
||||
obs_shape: ${shape_meta.obs.rgb.shape}
|
||||
num_channel: ${eval:'${shape_meta.obs.rgb.shape[0]} * ${img_cond_steps}'} # each image patch is history concatenated
|
||||
num_channel: ${eval:'3 * ${img_cond_steps}'} # each image patch is history concatenated
|
||||
img_h: ${shape_meta.obs.rgb.shape[1]}
|
||||
img_w: ${shape_meta.obs.rgb.shape[2]}
|
||||
cfg:
|
||||
patch_size: 8
|
||||
depth: 1
|
||||
@ -146,7 +148,9 @@ model:
|
||||
backbone:
|
||||
_target_: model.common.vit.VitEncoder
|
||||
obs_shape: ${shape_meta.obs.rgb.shape}
|
||||
num_channel: ${eval:'${shape_meta.obs.rgb.shape[0]} * ${img_cond_steps}'} # each image patch is history concatenated
|
||||
num_channel: ${eval:'3 * ${img_cond_steps}'} # each image patch is history concatenated
|
||||
img_h: ${shape_meta.obs.rgb.shape[1]}
|
||||
img_w: ${shape_meta.obs.rgb.shape[2]}
|
||||
cfg:
|
||||
patch_size: 8
|
||||
depth: 1
|
||||
|
@ -101,7 +101,9 @@ model:
|
||||
backbone:
|
||||
_target_: model.common.vit.VitEncoder
|
||||
obs_shape: ${shape_meta.obs.rgb.shape}
|
||||
num_channel: ${eval:'${shape_meta.obs.rgb.shape[0]} * ${img_cond_steps}'} # each image patch is history concatenated
|
||||
num_channel: ${eval:'3 * ${img_cond_steps}'} # each image patch is history concatenated
|
||||
img_h: ${shape_meta.obs.rgb.shape[1]}
|
||||
img_w: ${shape_meta.obs.rgb.shape[2]}
|
||||
cfg:
|
||||
patch_size: 8
|
||||
depth: 1
|
||||
@ -128,7 +130,9 @@ model:
|
||||
backbone:
|
||||
_target_: model.common.vit.VitEncoder
|
||||
obs_shape: ${shape_meta.obs.rgb.shape}
|
||||
num_channel: ${eval:'${shape_meta.obs.rgb.shape[0]} * ${img_cond_steps}'} # each image patch is history concatenated
|
||||
num_channel: ${eval:'3 * ${img_cond_steps}'} # each image patch is history concatenated
|
||||
img_h: ${shape_meta.obs.rgb.shape[1]}
|
||||
img_w: ${shape_meta.obs.rgb.shape[2]}
|
||||
cfg:
|
||||
patch_size: 8
|
||||
depth: 1
|
||||
|
@ -126,7 +126,9 @@ model:
|
||||
backbone:
|
||||
_target_: model.common.vit.VitEncoder
|
||||
obs_shape: ${shape_meta.obs.rgb.shape}
|
||||
num_channel: ${eval:'${shape_meta.obs.rgb.shape[0]} * ${img_cond_steps}'} # each image patch is history concatenated
|
||||
num_channel: ${eval:'3 * ${img_cond_steps}'} # each image patch is history concatenated
|
||||
img_h: ${shape_meta.obs.rgb.shape[1]}
|
||||
img_w: ${shape_meta.obs.rgb.shape[2]}
|
||||
cfg:
|
||||
patch_size: 8
|
||||
depth: 1
|
||||
@ -152,7 +154,9 @@ model:
|
||||
backbone:
|
||||
_target_: model.common.vit.VitEncoder
|
||||
obs_shape: ${shape_meta.obs.rgb.shape}
|
||||
num_channel: ${eval:'${shape_meta.obs.rgb.shape[0]} * ${img_cond_steps}'} # each image patch is history concatenated
|
||||
num_channel: ${eval:'3 * ${img_cond_steps}'} # each image patch is history concatenated
|
||||
img_h: ${shape_meta.obs.rgb.shape[1]}
|
||||
img_w: ${shape_meta.obs.rgb.shape[2]}
|
||||
cfg:
|
||||
patch_size: 8
|
||||
depth: 1
|
||||
|
@ -105,7 +105,9 @@ model:
|
||||
backbone:
|
||||
_target_: model.common.vit.VitEncoder
|
||||
obs_shape: ${shape_meta.obs.rgb.shape}
|
||||
num_channel: ${eval:'${shape_meta.obs.rgb.shape[0]} * ${img_cond_steps}'} # each image patch is history concatenated
|
||||
num_channel: ${eval:'3 * ${img_cond_steps}'} # each image patch is history concatenated
|
||||
img_h: ${shape_meta.obs.rgb.shape[1]}
|
||||
img_w: ${shape_meta.obs.rgb.shape[2]}
|
||||
cfg:
|
||||
patch_size: 8
|
||||
depth: 1
|
||||
@ -134,7 +136,9 @@ model:
|
||||
backbone:
|
||||
_target_: model.common.vit.VitEncoder
|
||||
obs_shape: ${shape_meta.obs.rgb.shape}
|
||||
num_channel: ${eval:'${shape_meta.obs.rgb.shape[0]} * ${img_cond_steps}'} # each image patch is history concatenated
|
||||
num_channel: ${eval:'3 * ${img_cond_steps}'} # each image patch is history concatenated
|
||||
img_h: ${shape_meta.obs.rgb.shape[1]}
|
||||
img_w: ${shape_meta.obs.rgb.shape[2]}
|
||||
cfg:
|
||||
patch_size: 8
|
||||
depth: 1
|
||||
|
@ -122,7 +122,6 @@ class ViTCritic(CriticObs):
|
||||
cond_dim,
|
||||
img_cond_steps=1,
|
||||
spatial_emb=128,
|
||||
patch_repr_dim=128,
|
||||
dropout=0,
|
||||
augment=False,
|
||||
num_img=1,
|
||||
@ -136,8 +135,8 @@ class ViTCritic(CriticObs):
|
||||
self.img_cond_steps = img_cond_steps
|
||||
if num_img > 1:
|
||||
self.compress1 = SpatialEmb(
|
||||
num_patch=121, # TODO: repr_dim // patch_repr_dim,
|
||||
patch_dim=patch_repr_dim,
|
||||
num_patch=self.backbone.num_patch,
|
||||
patch_dim=self.backbone.patch_repr_dim,
|
||||
prop_dim=cond_dim,
|
||||
proj_dim=spatial_emb,
|
||||
dropout=dropout,
|
||||
@ -145,8 +144,8 @@ class ViTCritic(CriticObs):
|
||||
self.compress2 = deepcopy(self.compress1)
|
||||
else: # TODO: clean up
|
||||
self.compress = SpatialEmb(
|
||||
num_patch=121,
|
||||
patch_dim=patch_repr_dim,
|
||||
num_patch=self.backbone.num_patch,
|
||||
patch_dim=self.backbone.patch_repr_dim,
|
||||
prop_dim=cond_dim,
|
||||
proj_dim=spatial_emb,
|
||||
dropout=dropout,
|
||||
|
@ -32,8 +32,6 @@ class Gaussian_VisionMLP(nn.Module):
|
||||
std_max=1,
|
||||
spatial_emb=0,
|
||||
visual_feature_dim=128,
|
||||
repr_dim=96 * 96,
|
||||
patch_repr_dim=128,
|
||||
dropout=0,
|
||||
num_img=1,
|
||||
augment=False,
|
||||
@ -51,8 +49,8 @@ class Gaussian_VisionMLP(nn.Module):
|
||||
assert spatial_emb > 1, "this is the dimension"
|
||||
if num_img > 1:
|
||||
self.compress1 = SpatialEmb(
|
||||
num_patch=121, # TODO: repr_dim // patch_repr_dim,
|
||||
patch_dim=patch_repr_dim,
|
||||
num_patch=self.backbone.num_patch,
|
||||
patch_dim=self.backbone.patch_repr_dim,
|
||||
prop_dim=cond_dim,
|
||||
proj_dim=spatial_emb,
|
||||
dropout=dropout,
|
||||
@ -60,8 +58,8 @@ class Gaussian_VisionMLP(nn.Module):
|
||||
self.compress2 = deepcopy(self.compress1)
|
||||
else: # TODO: clean up
|
||||
self.compress = SpatialEmb(
|
||||
num_patch=121,
|
||||
patch_dim=patch_repr_dim,
|
||||
num_patch=self.backbone.num_patch,
|
||||
patch_dim=self.backbone.patch_repr_dim,
|
||||
prop_dim=cond_dim,
|
||||
proj_dim=spatial_emb,
|
||||
dropout=dropout,
|
||||
@ -69,7 +67,7 @@ class Gaussian_VisionMLP(nn.Module):
|
||||
visual_feature_dim = spatial_emb * num_img
|
||||
else:
|
||||
self.compress = nn.Sequential(
|
||||
nn.Linear(repr_dim, visual_feature_dim),
|
||||
nn.Linear(self.backbone.repr_dim, visual_feature_dim),
|
||||
nn.LayerNorm(visual_feature_dim),
|
||||
nn.Dropout(dropout),
|
||||
nn.ReLU(),
|
||||
|
@ -3,12 +3,13 @@ ViT image encoder implementation from IBRL, https://github.com/hengyuan-hu/ibrl
|
||||
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from dataclasses import dataclass
|
||||
from typing import List
|
||||
import einops
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn.init import trunc_normal_
|
||||
import math
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -29,6 +30,8 @@ class VitEncoder(nn.Module):
|
||||
obs_shape: List[int],
|
||||
cfg: VitEncoderConfig,
|
||||
num_channel=3,
|
||||
img_h=96,
|
||||
img_w=96,
|
||||
):
|
||||
super().__init__()
|
||||
self.obs_shape = obs_shape
|
||||
@ -40,8 +43,11 @@ class VitEncoder(nn.Module):
|
||||
num_head=cfg.num_heads,
|
||||
depth=cfg.depth,
|
||||
num_channel=num_channel,
|
||||
img_h=img_h,
|
||||
img_w=img_w,
|
||||
)
|
||||
|
||||
self.img_h = img_h
|
||||
self.img_w = img_w
|
||||
self.num_patch = self.vit.num_patches
|
||||
self.patch_repr_dim = self.cfg.embed_dim
|
||||
self.repr_dim = self.cfg.embed_dim * self.vit.num_patches
|
||||
@ -56,11 +62,11 @@ class VitEncoder(nn.Module):
|
||||
|
||||
|
||||
class PatchEmbed1(nn.Module):
|
||||
def __init__(self, embed_dim, num_channel=3):
|
||||
def __init__(self, embed_dim, num_channel=3, img_h=96, img_w=96):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv2d(num_channel, embed_dim, kernel_size=8, stride=8)
|
||||
|
||||
self.num_patch = 144
|
||||
self.num_patch = math.ceil(img_h / 8) * math.ceil(img_w / 8)
|
||||
self.patch_dim = embed_dim
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
@ -70,7 +76,7 @@ class PatchEmbed1(nn.Module):
|
||||
|
||||
|
||||
class PatchEmbed2(nn.Module):
|
||||
def __init__(self, embed_dim, use_norm, num_channel=3):
|
||||
def __init__(self, embed_dim, use_norm, num_channel=3, img_h=96, img_w=96):
|
||||
super().__init__()
|
||||
layers = [
|
||||
nn.Conv2d(num_channel, embed_dim, kernel_size=8, stride=4),
|
||||
@ -80,7 +86,11 @@ class PatchEmbed2(nn.Module):
|
||||
]
|
||||
self.embed = nn.Sequential(*layers)
|
||||
|
||||
self.num_patch = 121 # TODO: specifically for 96x96 set by Hengyuan?
|
||||
H1 = math.ceil((img_h - 8) / 4) + 1
|
||||
W1 = math.ceil((img_w - 8) / 4) + 1
|
||||
H2 = math.ceil((H1 - 3) / 2) + 1
|
||||
W2 = math.ceil((W1 - 3) / 2) + 1
|
||||
self.num_patch = H2 * W2
|
||||
self.patch_dim = embed_dim
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
@ -146,14 +156,25 @@ class MinVit(nn.Module):
|
||||
num_head,
|
||||
depth,
|
||||
num_channel=3,
|
||||
img_h=96,
|
||||
img_w=96,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if embed_style == "embed1":
|
||||
self.patch_embed = PatchEmbed1(embed_dim, num_channel=num_channel)
|
||||
self.patch_embed = PatchEmbed1(
|
||||
embed_dim,
|
||||
num_channel=num_channel,
|
||||
img_h=img_h,
|
||||
img_w=img_w,
|
||||
)
|
||||
elif embed_style == "embed2":
|
||||
self.patch_embed = PatchEmbed2(
|
||||
embed_dim, use_norm=embed_norm, num_channel=num_channel
|
||||
embed_dim,
|
||||
use_norm=embed_norm,
|
||||
num_channel=num_channel,
|
||||
img_h=img_h,
|
||||
img_w=img_w,
|
||||
)
|
||||
else:
|
||||
assert False
|
||||
@ -233,8 +254,14 @@ def test_transformer_layer():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
obs_shape = [6, 96, 96]
|
||||
enc = VitEncoder([6, 96, 96], VitEncoderConfig())
|
||||
obs_shape = [6, 128, 128]
|
||||
enc = VitEncoder(
|
||||
obs_shape,
|
||||
VitEncoderConfig(),
|
||||
num_channel=obs_shape[0],
|
||||
img_h=obs_shape[1],
|
||||
img_w=obs_shape[2],
|
||||
)
|
||||
|
||||
print(enc)
|
||||
x = torch.rand(1, *obs_shape) * 255
|
||||
|
@ -34,8 +34,6 @@ class VisionDiffusionMLP(nn.Module):
|
||||
residual_style=False,
|
||||
spatial_emb=0,
|
||||
visual_feature_dim=128,
|
||||
repr_dim=96 * 96,
|
||||
patch_repr_dim=128,
|
||||
dropout=0,
|
||||
num_img=1,
|
||||
augment=False,
|
||||
@ -53,8 +51,8 @@ class VisionDiffusionMLP(nn.Module):
|
||||
assert spatial_emb > 1, "this is the dimension"
|
||||
if num_img > 1:
|
||||
self.compress1 = SpatialEmb(
|
||||
num_patch=121, # TODO: repr_dim // patch_repr_dim,
|
||||
patch_dim=patch_repr_dim,
|
||||
num_patch=self.backbone.num_patch,
|
||||
patch_dim=self.backbone.patch_repr_dim,
|
||||
prop_dim=cond_dim,
|
||||
proj_dim=spatial_emb,
|
||||
dropout=dropout,
|
||||
@ -62,8 +60,8 @@ class VisionDiffusionMLP(nn.Module):
|
||||
self.compress2 = deepcopy(self.compress1)
|
||||
else: # TODO: clean up
|
||||
self.compress = SpatialEmb(
|
||||
num_patch=121,
|
||||
patch_dim=patch_repr_dim,
|
||||
num_patch=self.backbone.num_patch,
|
||||
patch_dim=self.backbone.patch_repr_dim,
|
||||
prop_dim=cond_dim,
|
||||
proj_dim=spatial_emb,
|
||||
dropout=dropout,
|
||||
@ -71,7 +69,7 @@ class VisionDiffusionMLP(nn.Module):
|
||||
visual_feature_dim = spatial_emb * num_img
|
||||
else:
|
||||
self.compress = nn.Sequential(
|
||||
nn.Linear(repr_dim, visual_feature_dim),
|
||||
nn.Linear(self.backbone.repr_dim, visual_feature_dim),
|
||||
nn.LayerNorm(visual_feature_dim),
|
||||
nn.Dropout(dropout),
|
||||
nn.ReLU(),
|
||||
|
Loading…
Reference in New Issue
Block a user