88 lines
2.8 KiB
Python
88 lines
2.8 KiB
Python
"""
|
|
Additional implementation of the ViT image encoder from https://github.com/hengyuan-hu/ibrl/tree/main
|
|
|
|
"""
|
|
|
|
from typing import Optional

import torch
import torch.nn as nn
|
|
|
|
|
|
class SpatialEmb(nn.Module):
    """Pool a grid of patch features into a single embedding vector.

    After the transpose in ``forward`` every feature channel becomes one
    token whose content is that channel's value at each patch location,
    optionally concatenated with an auxiliary vector ``prop`` (presumably
    proprioceptive state -- confirm against callers). Each token is projected
    to ``proj_dim`` and the tokens are combined with a learned per-channel
    weighted sum.

    Args:
        num_patch: Number of patch locations per sample.
        patch_dim: Feature size of each patch; this is the number of tokens
            that are summed over.
        prop_dim: Size of the auxiliary ``prop`` vector; ``0`` disables it.
        proj_dim: Size of the returned embedding.
        dropout: Dropout probability applied to the returned embedding.
    """

    def __init__(
        self,
        num_patch: int,
        patch_dim: int,
        prop_dim: int,
        proj_dim: int,
        dropout: float,
    ) -> None:
        super().__init__()

        # The projection consumes one channel across all patches, plus the
        # auxiliary vector when it is enabled.
        proj_in_dim = num_patch + prop_dim
        num_proj = patch_dim
        self.patch_dim = patch_dim
        self.prop_dim = prop_dim

        self.input_proj = nn.Sequential(
            nn.Linear(proj_in_dim, proj_dim),
            nn.LayerNorm(proj_dim),
            nn.ReLU(inplace=True),
        )
        # One weight vector per channel token; broadcast over the batch.
        self.weight = nn.Parameter(torch.zeros(1, num_proj, proj_dim))
        self.dropout = nn.Dropout(dropout)
        nn.init.normal_(self.weight)

    def extra_repr(self) -> str:
        return f"weight: nn.Parameter ({self.weight.size()})"

    def forward(
        self, feat: torch.Tensor, prop: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Embed patch features (and optionally ``prop``) into one vector.

        Args:
            feat: Patch features, expected as ``(batch, num_patch, patch_dim)``.
            prop: Auxiliary vector, expected as ``(batch, prop_dim)``. Only
                read when ``prop_dim > 0``; may be omitted otherwise.

        Returns:
            Tensor of shape ``(batch, proj_dim)``.

        Raises:
            ValueError: If the module was built with ``prop_dim > 0`` but
                ``prop`` is ``None``.
        """
        use_prop = self.prop_dim > 0
        if use_prop and prop is None:
            raise ValueError(
                f"prop is required when prop_dim > 0 (prop_dim={self.prop_dim})"
            )

        # (batch, num_patch, patch_dim) -> (batch, patch_dim, num_patch)
        tokens = feat.transpose(1, 2)

        if use_prop:
            # Share the same prop vector across every channel token. expand()
            # is a view; torch.cat materialises the result anyway.
            tiled_prop = prop.unsqueeze(1).expand(-1, tokens.size(1), -1)
            tokens = torch.cat((tokens, tiled_prop), dim=-1)

        projected = self.input_proj(tokens)
        # Learned weighted sum over the channel-token axis.
        pooled = (self.weight * projected).sum(1)
        return self.dropout(pooled)
|
|
|
|
|
|
class RandomShiftsAug:
    """Random integer-pixel shift augmentation for batches of images.

    Each image is replicate-padded by ``pad`` pixels on every side and an
    ``h x w`` window is then resampled from the padded image at a random
    offset drawn independently per sample and per axis. The net effect is a
    translation of up to ``pad`` pixels in any direction.

    Args:
        pad: Number of pixels padded on each side (the maximum shift).
    """

    def __init__(self, pad):
        self.pad = pad

    @staticmethod
    def _pixel_centers(padded_len, out_len, like):
        """Return normalised centres of the first ``out_len`` padded pixels."""
        # With align_corners=False, pixel i of a length-L axis is centred at
        # (2 * i + 1) / L - 1 in grid_sample's [-1, 1] coordinate system.
        eps = 1.0 / padded_len
        return torch.linspace(
            -1.0 + eps, 1.0 - eps, padded_len, device=like.device, dtype=like.dtype
        )[:out_len]

    def __call__(self, x):
        """Return a randomly shifted version of ``x``.

        Args:
            x: Floating-point image batch of shape ``(n, c, h, w)``. Height
                and width may differ.

        Returns:
            Tensor of the same shape as ``x``.
        """
        n, _, h, w = x.size()
        padded_h = h + 2 * self.pad
        padded_w = w + 2 * self.pad
        x = nn.functional.pad(x, (self.pad,) * 4, "replicate")

        # Sampling grid addressing the top-left h x w window of the padded
        # image; last axis is (x, y) as grid_sample expects.
        xs = self._pixel_centers(padded_w, w, x).unsqueeze(0).expand(h, w)
        ys = self._pixel_centers(padded_h, h, x).unsqueeze(1).expand(h, w)
        base_grid = torch.stack((xs, ys), dim=-1)

        # Integer pixel offsets in [0, 2 * pad], one (dx, dy) pair per sample,
        # converted to normalised units separately for each axis.
        shift = torch.randint(
            0, 2 * self.pad + 1, size=(n, 1, 1, 2), device=x.device, dtype=x.dtype
        )
        shift[..., 0] *= 2.0 / padded_w
        shift[..., 1] *= 2.0 / padded_h

        # Broadcasting (h, w, 2) + (n, 1, 1, 2) yields the (n, h, w, 2) grid.
        grid = base_grid + shift
        return nn.functional.grid_sample(
            x, grid, padding_mode="zeros", align_corners=False
        )
|
|
|
|
|
|
# test random shift
|
|
if __name__ == "__main__":
    # Manual visual check: download one sample frame, apply a random shift
    # and display the augmented image.
    from io import BytesIO

    from PIL import Image
    import requests
    import numpy as np

    image_url = "https://rail.eecs.berkeley.edu/datasets/bridge_release/raw/bridge_data_v2/datacol2_toykitchen7/drawer_pnp/01/2023-04-19_09-18-15/raw/traj_group0/traj0/images0/im_30.jpg"
    # Bound the wait and fail loudly on HTTP errors rather than handing an
    # error page to the image decoder.
    response = requests.get(image_url, timeout=30)
    response.raise_for_status()
    # Force 3 channels so the (H, W, C) -> (C, H, W) permute below is valid
    # regardless of the source image mode.
    image = Image.open(BytesIO(response.content)).convert("RGB")
    image = image.resize((96, 96))

    image = torch.tensor(np.array(image)).permute(2, 0, 1).unsqueeze(0).float()
    aug = RandomShiftsAug(pad=4)
    image_aug = aug(image)
    # Drop only the batch axis, then go back to (H, W, C) for PIL.
    image_aug = image_aug.squeeze(0).permute(1, 2, 0).numpy()
    # Bilinear resampling can return e.g. 254.99998; round and clip so the
    # uint8 cast does not truncate or wrap.
    image_aug = Image.fromarray(image_aug.round().clip(0, 255).astype(np.uint8))
    image_aug.show()
|