support varying img size

This commit is contained in:
allenzren 2024-09-16 17:55:31 -04:00
parent 64595baca9
commit 1aaa6c2302
18 changed files with 131 additions and 81 deletions

View File

@ -159,7 +159,7 @@ To use DDIM fine-tuning, set `denoising_steps=100` in pre-training and set `mode
## Adding your own dataset/environment
### Pre-training data
Pre-training script is at [`agent/pretrain/train_diffusion_agent.py`](agent/pretrain/train_diffusion_agent.py). The pre-training dataset [loader](agent/dataset/sequence.py) assumes a npz file containing numpy arrays `states`, `actions`, `images` (if using pixel) and `traj_lengths`, where `states` and `actions` have the shape of num_total_steps x obs_dim/act_dim, `images` num_total_steps x C (concatenated if multiple images) x H x W, and `traj_lengths` is a 1-D array for indexing across num_total_steps.
Pre-training script is at [`agent/pretrain/train_diffusion_agent.py`](agent/pretrain/train_diffusion_agent.py). The pre-training dataset [loader](agent/dataset/sequence.py) assumes a npz file containing numpy arrays `states`, `actions`, `images` (if using pixel; img_h = img_w and a multiple of 8) and `traj_lengths`, where `states` and `actions` have the shape of num_total_steps x obs_dim/act_dim, `images` num_total_steps x C (concatenated if multiple images) x H x W, and `traj_lengths` is a 1-D array for indexing across num_total_steps.
<!-- One pre-processing example can be found at [`script/process_robomimic_dataset.py`](script/process_robomimic_dataset.py). -->
<!-- **Note:** The current implementation does not support loading history observations (only using observation at the current timestep). If needed, you can modify [here](agent/dataset/sequence.py#L130-L131). -->

View File

@ -244,15 +244,16 @@ class TrainPPODiffusionAgent(TrainPPOAgent):
.float()
.to(self.device)
}
with torch.no_grad():
next_value = (
self.model.critic(obs_venv_ts).reshape(1, -1).cpu().numpy()
)
advantages_trajs = np.zeros_like(reward_trajs)
lastgaelam = 0
for t in reversed(range(self.n_steps)):
if t == self.n_steps - 1:
nextvalues = next_value
nextvalues = (
self.model.critic(obs_venv_ts)
.reshape(1, -1)
.cpu()
.numpy()
)
else:
nextvalues = values_trajs[t + 1]
nonterminal = 1.0 - dones_trajs[t]

View File

@ -240,18 +240,16 @@ class TrainPPOImgDiffusionAgent(TrainPPODiffusionAgent):
key: torch.from_numpy(obs_venv[key]).float().to(self.device)
for key in self.obs_dims
}
with torch.no_grad():
next_value = (
self.model.critic(obs_venv_ts, no_augment=True)
.reshape(1, -1)
.cpu()
.numpy()
)
advantages_trajs = np.zeros_like(reward_trajs)
lastgaelam = 0
for t in reversed(range(self.n_steps)):
if t == self.n_steps - 1:
nextvalues = next_value
nextvalues = (
self.model.critic(obs_venv_ts, no_augment=True)
.reshape(1, -1)
.cpu()
.numpy()
)
else:
nextvalues = values_trajs[t + 1]
nonterminal = 1.0 - dones_trajs[t]

View File

@ -220,15 +220,16 @@ class TrainPPOExactDiffusionAgent(TrainPPODiffusionAgent):
.float()
.to(self.device)
}
with torch.no_grad():
next_value = (
self.model.critic(obs_venv_ts).reshape(1, -1).cpu().numpy()
)
advantages_trajs = np.zeros_like(reward_trajs)
lastgaelam = 0
for t in reversed(range(self.n_steps)):
if t == self.n_steps - 1:
nextvalues = next_value
nextvalues = (
self.model.critic(obs_venv_ts)
.reshape(1, -1)
.cpu()
.numpy()
)
else:
nextvalues = values_trajs[t + 1]
nonterminal = 1.0 - dones_trajs[t]
@ -241,10 +242,7 @@ class TrainPPOExactDiffusionAgent(TrainPPODiffusionAgent):
# A = delta_t + gamma*lamdba*delta_{t+1} + ...
advantages_trajs[t] = lastgaelam = (
delta
+ self.gamma
* self.gae_lambda
* nonterminal
* lastgaelam
+ self.gamma * self.gae_lambda * nonterminal * lastgaelam
)
returns_trajs = advantages_trajs + values_trajs

View File

@ -209,15 +209,16 @@ class TrainPPOGaussianAgent(TrainPPOAgent):
.float()
.to(self.device)
}
with torch.no_grad():
next_value = (
self.model.critic(obs_venv_ts).reshape(1, -1).cpu().numpy()
)
advantages_trajs = np.zeros_like(reward_trajs)
lastgaelam = 0
for t in reversed(range(self.n_steps)):
if t == self.n_steps - 1:
nextvalues = next_value
nextvalues = (
self.model.critic(obs_venv_ts)
.reshape(1, -1)
.cpu()
.numpy()
)
else:
nextvalues = values_trajs[t + 1]
nonterminal = 1.0 - dones_trajs[t]

View File

@ -228,18 +228,16 @@ class TrainPPOImgGaussianAgent(TrainPPOGaussianAgent):
key: torch.from_numpy(obs_venv[key]).float().to(self.device)
for key in self.obs_dims
}
with torch.no_grad():
next_value = (
self.model.critic(obs_venv_ts, no_augment=True)
.reshape(1, -1)
.cpu()
.numpy()
)
advantages_trajs = np.zeros_like(reward_trajs)
lastgaelam = 0
for t in reversed(range(self.n_steps)):
if t == self.n_steps - 1:
nextvalues = next_value
nextvalues = (
self.model.critic(obs_venv_ts, no_augment=True)
.reshape(1, -1)
.cpu()
.numpy()
)
else:
nextvalues = values_trajs[t + 1]
nonterminal = 1.0 - dones_trajs[t]

View File

@ -122,7 +122,9 @@ model:
backbone:
_target_: model.common.vit.VitEncoder
obs_shape: ${shape_meta.obs.rgb.shape}
num_channel: ${eval:'${shape_meta.obs.rgb.shape[0]} * ${img_cond_steps}'} # each image patch is history concatenated
num_channel: ${eval:'3 * ${img_cond_steps}'} # each image patch is history concatenated
img_h: ${shape_meta.obs.rgb.shape[1]}
img_w: ${shape_meta.obs.rgb.shape[2]}
cfg:
patch_size: 8
depth: 1
@ -146,7 +148,9 @@ model:
backbone:
_target_: model.common.vit.VitEncoder
obs_shape: ${shape_meta.obs.rgb.shape}
num_channel: ${eval:'${shape_meta.obs.rgb.shape[0]} * ${img_cond_steps}'} # each image patch is history concatenated
num_channel: ${eval:'3 * ${img_cond_steps}'} # each image patch is history concatenated
img_h: ${shape_meta.obs.rgb.shape[1]}
img_w: ${shape_meta.obs.rgb.shape[2]}
cfg:
patch_size: 8
depth: 1

View File

@ -101,7 +101,9 @@ model:
backbone:
_target_: model.common.vit.VitEncoder
obs_shape: ${shape_meta.obs.rgb.shape}
num_channel: ${eval:'${shape_meta.obs.rgb.shape[0]} * ${img_cond_steps}'} # each image patch is history concatenated
num_channel: ${eval:'3 * ${img_cond_steps}'} # each image patch is history concatenated
img_h: ${shape_meta.obs.rgb.shape[1]}
img_w: ${shape_meta.obs.rgb.shape[2]}
cfg:
patch_size: 8
depth: 1
@ -128,7 +130,9 @@ model:
backbone:
_target_: model.common.vit.VitEncoder
obs_shape: ${shape_meta.obs.rgb.shape}
num_channel: ${eval:'${shape_meta.obs.rgb.shape[0]} * ${img_cond_steps}'} # each image patch is history concatenated
num_channel: ${eval:'3 * ${img_cond_steps}'} # each image patch is history concatenated
img_h: ${shape_meta.obs.rgb.shape[1]}
img_w: ${shape_meta.obs.rgb.shape[2]}
cfg:
patch_size: 8
depth: 1

View File

@ -122,7 +122,9 @@ model:
backbone:
_target_: model.common.vit.VitEncoder
obs_shape: ${shape_meta.obs.rgb.shape}
num_channel: ${eval:'${shape_meta.obs.rgb.shape[0]} * ${img_cond_steps}'} # each image patch is history concatenated
num_channel: ${eval:'3 * ${img_cond_steps}'} # each image patch is history concatenated
img_h: ${shape_meta.obs.rgb.shape[1]}
img_w: ${shape_meta.obs.rgb.shape[2]}
cfg:
patch_size: 8
depth: 1
@ -146,7 +148,9 @@ model:
backbone:
_target_: model.common.vit.VitEncoder
obs_shape: ${shape_meta.obs.rgb.shape}
num_channel: ${eval:'${shape_meta.obs.rgb.shape[0]} * ${img_cond_steps}'} # each image patch is history concatenated
num_channel: ${eval:'3 * ${img_cond_steps}'} # each image patch is history concatenated
img_h: ${shape_meta.obs.rgb.shape[1]}
img_w: ${shape_meta.obs.rgb.shape[2]}
cfg:
patch_size: 8
depth: 1

View File

@ -101,7 +101,9 @@ model:
backbone:
_target_: model.common.vit.VitEncoder
obs_shape: ${shape_meta.obs.rgb.shape}
num_channel: ${eval:'${shape_meta.obs.rgb.shape[0]} * ${img_cond_steps}'} # each image patch is history concatenated
num_channel: ${eval:'3 * ${img_cond_steps}'} # each image patch is history concatenated
img_h: ${shape_meta.obs.rgb.shape[1]}
img_w: ${shape_meta.obs.rgb.shape[2]}
cfg:
patch_size: 8
depth: 1
@ -128,7 +130,9 @@ model:
backbone:
_target_: model.common.vit.VitEncoder
obs_shape: ${shape_meta.obs.rgb.shape}
num_channel: ${eval:'${shape_meta.obs.rgb.shape[0]} * ${img_cond_steps}'} # each image patch is history concatenated
num_channel: ${eval:'3 * ${img_cond_steps}'} # each image patch is history concatenated
img_h: ${shape_meta.obs.rgb.shape[1]}
img_w: ${shape_meta.obs.rgb.shape[2]}
cfg:
patch_size: 8
depth: 1

View File

@ -122,7 +122,9 @@ model:
backbone:
_target_: model.common.vit.VitEncoder
obs_shape: ${shape_meta.obs.rgb.shape}
num_channel: ${eval:'${shape_meta.obs.rgb.shape[0]} * ${img_cond_steps}'} # each image patch is history concatenated
num_channel: ${eval:'3 * ${img_cond_steps}'} # each image patch is history concatenated
img_h: ${shape_meta.obs.rgb.shape[1]}
img_w: ${shape_meta.obs.rgb.shape[2]}
cfg:
patch_size: 8
depth: 1
@ -146,7 +148,9 @@ model:
backbone:
_target_: model.common.vit.VitEncoder
obs_shape: ${shape_meta.obs.rgb.shape}
num_channel: ${eval:'${shape_meta.obs.rgb.shape[0]} * ${img_cond_steps}'} # each image patch is history concatenated
num_channel: ${eval:'3 * ${img_cond_steps}'} # each image patch is history concatenated
img_h: ${shape_meta.obs.rgb.shape[1]}
img_w: ${shape_meta.obs.rgb.shape[2]}
cfg:
patch_size: 8
depth: 1

View File

@ -101,7 +101,9 @@ model:
backbone:
_target_: model.common.vit.VitEncoder
obs_shape: ${shape_meta.obs.rgb.shape}
num_channel: ${eval:'${shape_meta.obs.rgb.shape[0]} * ${img_cond_steps}'} # each image patch is history concatenated
num_channel: ${eval:'3 * ${img_cond_steps}'} # each image patch is history concatenated
img_h: ${shape_meta.obs.rgb.shape[1]}
img_w: ${shape_meta.obs.rgb.shape[2]}
cfg:
patch_size: 8
depth: 1
@ -128,7 +130,9 @@ model:
backbone:
_target_: model.common.vit.VitEncoder
obs_shape: ${shape_meta.obs.rgb.shape}
num_channel: ${eval:'${shape_meta.obs.rgb.shape[0]} * ${img_cond_steps}'} # each image patch is history concatenated
num_channel: ${eval:'3 * ${img_cond_steps}'} # each image patch is history concatenated
img_h: ${shape_meta.obs.rgb.shape[1]}
img_w: ${shape_meta.obs.rgb.shape[2]}
cfg:
patch_size: 8
depth: 1

View File

@ -126,7 +126,9 @@ model:
backbone:
_target_: model.common.vit.VitEncoder
obs_shape: ${shape_meta.obs.rgb.shape}
num_channel: ${eval:'${shape_meta.obs.rgb.shape[0]} * ${img_cond_steps}'} # each image patch is history concatenated
num_channel: ${eval:'3 * ${img_cond_steps}'} # each image patch is history concatenated
img_h: ${shape_meta.obs.rgb.shape[1]}
img_w: ${shape_meta.obs.rgb.shape[2]}
cfg:
patch_size: 8
depth: 1
@ -152,7 +154,9 @@ model:
backbone:
_target_: model.common.vit.VitEncoder
obs_shape: ${shape_meta.obs.rgb.shape}
num_channel: ${eval:'${shape_meta.obs.rgb.shape[0]} * ${img_cond_steps}'} # each image patch is history concatenated
num_channel: ${eval:'3 * ${img_cond_steps}'} # each image patch is history concatenated
img_h: ${shape_meta.obs.rgb.shape[1]}
img_w: ${shape_meta.obs.rgb.shape[2]}
cfg:
patch_size: 8
depth: 1

View File

@ -105,7 +105,9 @@ model:
backbone:
_target_: model.common.vit.VitEncoder
obs_shape: ${shape_meta.obs.rgb.shape}
num_channel: ${eval:'${shape_meta.obs.rgb.shape[0]} * ${img_cond_steps}'} # each image patch is history concatenated
num_channel: ${eval:'3 * ${img_cond_steps}'} # each image patch is history concatenated
img_h: ${shape_meta.obs.rgb.shape[1]}
img_w: ${shape_meta.obs.rgb.shape[2]}
cfg:
patch_size: 8
depth: 1
@ -134,7 +136,9 @@ model:
backbone:
_target_: model.common.vit.VitEncoder
obs_shape: ${shape_meta.obs.rgb.shape}
num_channel: ${eval:'${shape_meta.obs.rgb.shape[0]} * ${img_cond_steps}'} # each image patch is history concatenated
num_channel: ${eval:'3 * ${img_cond_steps}'} # each image patch is history concatenated
img_h: ${shape_meta.obs.rgb.shape[1]}
img_w: ${shape_meta.obs.rgb.shape[2]}
cfg:
patch_size: 8
depth: 1

View File

@ -122,7 +122,6 @@ class ViTCritic(CriticObs):
cond_dim,
img_cond_steps=1,
spatial_emb=128,
patch_repr_dim=128,
dropout=0,
augment=False,
num_img=1,
@ -136,8 +135,8 @@ class ViTCritic(CriticObs):
self.img_cond_steps = img_cond_steps
if num_img > 1:
self.compress1 = SpatialEmb(
num_patch=121, # TODO: repr_dim // patch_repr_dim,
patch_dim=patch_repr_dim,
num_patch=self.backbone.num_patch,
patch_dim=self.backbone.patch_repr_dim,
prop_dim=cond_dim,
proj_dim=spatial_emb,
dropout=dropout,
@ -145,8 +144,8 @@ class ViTCritic(CriticObs):
self.compress2 = deepcopy(self.compress1)
else: # TODO: clean up
self.compress = SpatialEmb(
num_patch=121,
patch_dim=patch_repr_dim,
num_patch=self.backbone.num_patch,
patch_dim=self.backbone.patch_repr_dim,
prop_dim=cond_dim,
proj_dim=spatial_emb,
dropout=dropout,

View File

@ -32,8 +32,6 @@ class Gaussian_VisionMLP(nn.Module):
std_max=1,
spatial_emb=0,
visual_feature_dim=128,
repr_dim=96 * 96,
patch_repr_dim=128,
dropout=0,
num_img=1,
augment=False,
@ -51,8 +49,8 @@ class Gaussian_VisionMLP(nn.Module):
assert spatial_emb > 1, "this is the dimension"
if num_img > 1:
self.compress1 = SpatialEmb(
num_patch=121, # TODO: repr_dim // patch_repr_dim,
patch_dim=patch_repr_dim,
num_patch=self.backbone.num_patch,
patch_dim=self.backbone.patch_repr_dim,
prop_dim=cond_dim,
proj_dim=spatial_emb,
dropout=dropout,
@ -60,8 +58,8 @@ class Gaussian_VisionMLP(nn.Module):
self.compress2 = deepcopy(self.compress1)
else: # TODO: clean up
self.compress = SpatialEmb(
num_patch=121,
patch_dim=patch_repr_dim,
num_patch=self.backbone.num_patch,
patch_dim=self.backbone.patch_repr_dim,
prop_dim=cond_dim,
proj_dim=spatial_emb,
dropout=dropout,
@ -69,7 +67,7 @@ class Gaussian_VisionMLP(nn.Module):
visual_feature_dim = spatial_emb * num_img
else:
self.compress = nn.Sequential(
nn.Linear(repr_dim, visual_feature_dim),
nn.Linear(self.backbone.repr_dim, visual_feature_dim),
nn.LayerNorm(visual_feature_dim),
nn.Dropout(dropout),
nn.ReLU(),

View File

@ -3,12 +3,13 @@ ViT image encoder implementation from IBRL, https://github.com/hengyuan-hu/ibrl
"""
from dataclasses import dataclass, field
from dataclasses import dataclass
from typing import List
import einops
import torch
from torch import nn
from torch.nn.init import trunc_normal_
import math
@dataclass
@ -29,6 +30,8 @@ class VitEncoder(nn.Module):
obs_shape: List[int],
cfg: VitEncoderConfig,
num_channel=3,
img_h=96,
img_w=96,
):
super().__init__()
self.obs_shape = obs_shape
@ -40,8 +43,11 @@ class VitEncoder(nn.Module):
num_head=cfg.num_heads,
depth=cfg.depth,
num_channel=num_channel,
img_h=img_h,
img_w=img_w,
)
self.img_h = img_h
self.img_w = img_w
self.num_patch = self.vit.num_patches
self.patch_repr_dim = self.cfg.embed_dim
self.repr_dim = self.cfg.embed_dim * self.vit.num_patches
@ -56,11 +62,11 @@ class VitEncoder(nn.Module):
class PatchEmbed1(nn.Module):
def __init__(self, embed_dim, num_channel=3):
def __init__(self, embed_dim, num_channel=3, img_h=96, img_w=96):
super().__init__()
self.conv = nn.Conv2d(num_channel, embed_dim, kernel_size=8, stride=8)
self.num_patch = 144
self.num_patch = math.ceil(img_h / 8) * math.ceil(img_w / 8)
self.patch_dim = embed_dim
def forward(self, x: torch.Tensor):
@ -70,7 +76,7 @@ class PatchEmbed1(nn.Module):
class PatchEmbed2(nn.Module):
def __init__(self, embed_dim, use_norm, num_channel=3):
def __init__(self, embed_dim, use_norm, num_channel=3, img_h=96, img_w=96):
super().__init__()
layers = [
nn.Conv2d(num_channel, embed_dim, kernel_size=8, stride=4),
@ -80,7 +86,11 @@ class PatchEmbed2(nn.Module):
]
self.embed = nn.Sequential(*layers)
self.num_patch = 121 # TODO: specifically for 96x96 set by Hengyuan?
H1 = math.ceil((img_h - 8) / 4) + 1
W1 = math.ceil((img_w - 8) / 4) + 1
H2 = math.ceil((H1 - 3) / 2) + 1
W2 = math.ceil((W1 - 3) / 2) + 1
self.num_patch = H2 * W2
self.patch_dim = embed_dim
def forward(self, x: torch.Tensor):
@ -146,14 +156,25 @@ class MinVit(nn.Module):
num_head,
depth,
num_channel=3,
img_h=96,
img_w=96,
):
super().__init__()
if embed_style == "embed1":
self.patch_embed = PatchEmbed1(embed_dim, num_channel=num_channel)
self.patch_embed = PatchEmbed1(
embed_dim,
num_channel=num_channel,
img_h=img_h,
img_w=img_w,
)
elif embed_style == "embed2":
self.patch_embed = PatchEmbed2(
embed_dim, use_norm=embed_norm, num_channel=num_channel
embed_dim,
use_norm=embed_norm,
num_channel=num_channel,
img_h=img_h,
img_w=img_w,
)
else:
assert False
@ -233,8 +254,14 @@ def test_transformer_layer():
if __name__ == "__main__":
obs_shape = [6, 96, 96]
enc = VitEncoder([6, 96, 96], VitEncoderConfig())
obs_shape = [6, 128, 128]
enc = VitEncoder(
obs_shape,
VitEncoderConfig(),
num_channel=obs_shape[0],
img_h=obs_shape[1],
img_w=obs_shape[2],
)
print(enc)
x = torch.rand(1, *obs_shape) * 255

View File

@ -34,8 +34,6 @@ class VisionDiffusionMLP(nn.Module):
residual_style=False,
spatial_emb=0,
visual_feature_dim=128,
repr_dim=96 * 96,
patch_repr_dim=128,
dropout=0,
num_img=1,
augment=False,
@ -53,8 +51,8 @@ class VisionDiffusionMLP(nn.Module):
assert spatial_emb > 1, "this is the dimension"
if num_img > 1:
self.compress1 = SpatialEmb(
num_patch=121, # TODO: repr_dim // patch_repr_dim,
patch_dim=patch_repr_dim,
num_patch=self.backbone.num_patch,
patch_dim=self.backbone.patch_repr_dim,
prop_dim=cond_dim,
proj_dim=spatial_emb,
dropout=dropout,
@ -62,8 +60,8 @@ class VisionDiffusionMLP(nn.Module):
self.compress2 = deepcopy(self.compress1)
else: # TODO: clean up
self.compress = SpatialEmb(
num_patch=121,
patch_dim=patch_repr_dim,
num_patch=self.backbone.num_patch,
patch_dim=self.backbone.patch_repr_dim,
prop_dim=cond_dim,
proj_dim=spatial_emb,
dropout=dropout,
@ -71,7 +69,7 @@ class VisionDiffusionMLP(nn.Module):
visual_feature_dim = spatial_emb * num_img
else:
self.compress = nn.Sequential(
nn.Linear(repr_dim, visual_feature_dim),
nn.Linear(self.backbone.repr_dim, visual_feature_dim),
nn.LayerNorm(visual_feature_dim),
nn.Dropout(dropout),
nn.ReLU(),