""" ViT image encoder implementation from IBRL, https://github.com/hengyuan-hu/ibrl """ from dataclasses import dataclass, field from typing import List import einops import torch from torch import nn from torch.nn.init import trunc_normal_ @dataclass class VitEncoderConfig: patch_size: int = 8 depth: int = 1 embed_dim: int = 128 num_heads: int = 4 act_layer = nn.GELU stride: int = -1 embed_style: str = "embed2" embed_norm: int = 0 class VitEncoder(nn.Module): def __init__(self, obs_shape: List[int], cfg: VitEncoderConfig): super().__init__() self.obs_shape = obs_shape self.cfg = cfg self.vit = MinVit( embed_style=cfg.embed_style, embed_dim=cfg.embed_dim, embed_norm=cfg.embed_norm, num_head=cfg.num_heads, depth=cfg.depth, ) self.num_patch = self.vit.num_patches self.patch_repr_dim = self.cfg.embed_dim self.repr_dim = self.cfg.embed_dim * self.vit.num_patches def forward(self, obs, flatten=False) -> torch.Tensor: # assert obs.max() > 5 obs = obs / 255.0 - 0.5 feats: torch.Tensor = self.vit.forward(obs) if flatten: feats = feats.flatten(1, 2) return feats class PatchEmbed1(nn.Module): def __init__(self, embed_dim): super().__init__() self.conv = nn.Conv2d(3, embed_dim, kernel_size=8, stride=8) self.num_patch = 144 self.patch_dim = embed_dim def forward(self, x: torch.Tensor): y = self.conv(x) y = einops.rearrange(y, "b c h w -> b (h w) c") return y class PatchEmbed2(nn.Module): def __init__(self, embed_dim, use_norm): super().__init__() layers = [ nn.Conv2d(3, embed_dim, kernel_size=8, stride=4), nn.GroupNorm(embed_dim, embed_dim) if use_norm else nn.Identity(), nn.ReLU(), nn.Conv2d(embed_dim, embed_dim, kernel_size=3, stride=2), ] self.embed = nn.Sequential(*layers) self.num_patch = 121 # TODO: specifically for 96x96 set by Hengyuan? self.patch_dim = embed_dim def forward(self, x: torch.Tensor): y = self.embed(x) y = einops.rearrange(y, "b c h w -> b (h w) c") return y class MultiHeadAttention(nn.Module): def __init__(self, embed_dim, num_head): super().__init__() assert embed_dim % num_head == 0 self.num_head = num_head self.qkv_proj = nn.Linear(embed_dim, 3 * embed_dim) self.out_proj = nn.Linear(embed_dim, embed_dim) def forward(self, x, attn_mask): """ x: [batch, seq, embed_dim] """ qkv = self.qkv_proj(x) q, k, v = einops.rearrange( qkv, "b t (k h d) -> b k h t d", k=3, h=self.num_head ).unbind(1) # force flash/mem-eff attention, it will raise error if flash cannot be applied with torch.backends.cuda.sdp_kernel(enable_math=False): attn_v = torch.nn.functional.scaled_dot_product_attention( q, k, v, dropout_p=0.0, attn_mask=attn_mask ) attn_v = einops.rearrange(attn_v, "b h t d -> b t (h d)") return self.out_proj(attn_v) class TransformerLayer(nn.Module): def __init__(self, embed_dim, num_head, dropout): super().__init__() self.layer_norm1 = nn.LayerNorm(embed_dim) self.mha = MultiHeadAttention(embed_dim, num_head) self.layer_norm2 = nn.LayerNorm(embed_dim) self.linear1 = nn.Linear(embed_dim, 4 * embed_dim) self.linear2 = nn.Linear(4 * embed_dim, embed_dim) self.dropout = nn.Dropout(dropout) def forward(self, x, attn_mask=None): x = x + self.dropout(self.mha(self.layer_norm1(x), attn_mask)) x = x + self.dropout(self._ff_block(self.layer_norm2(x))) return x def _ff_block(self, x): x = self.linear2(nn.functional.gelu(self.linear1(x))) return x class MinVit(nn.Module): def __init__(self, embed_style, embed_dim, embed_norm, num_head, depth): super().__init__() if embed_style == "embed1": self.patch_embed = PatchEmbed1(embed_dim) elif embed_style == "embed2": self.patch_embed = PatchEmbed2(embed_dim, use_norm=embed_norm) else: assert False self.pos_embed = nn.Parameter( torch.zeros(1, self.patch_embed.num_patch, embed_dim) ) layers = [ TransformerLayer(embed_dim, num_head, dropout=0) for _ in range(depth) ] self.net = nn.Sequential(*layers) self.norm = nn.LayerNorm(embed_dim) self.num_patches = self.patch_embed.num_patch # weight init trunc_normal_(self.pos_embed, std=0.02) named_apply(init_weights_vit_timm, self) def forward(self, x): x = self.patch_embed(x) x = x + self.pos_embed x = self.net(x) return self.norm(x) def init_weights_vit_timm(module: nn.Module, name: str = ""): """ViT weight initialization, original timm impl (for reproducibility)""" if isinstance(module, nn.Linear): trunc_normal_(module.weight, std=0.02) if module.bias is not None: nn.init.zeros_(module.bias) def named_apply( fn, module: nn.Module, name="", depth_first=True, include_root=False ) -> nn.Module: if not depth_first and include_root: fn(module=module, name=name) for child_name, child_module in module.named_children(): child_name = ".".join((name, child_name)) if name else child_name named_apply( fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True, ) if depth_first and include_root: fn(module=module, name=name) return module def test_patch_embed(): print("embed 1") embed = PatchEmbed1(128) x = torch.rand(10, 3, 96, 96) y = embed(x) print(y.size()) print("embed 2") embed = PatchEmbed2(128, True) x = torch.rand(10, 3, 96, 96) y = embed(x) print(y.size()) def test_transformer_layer(): embed = PatchEmbed1(128) x = torch.rand(10, 3, 96, 96) y = embed(x) print(y.size()) transformer = TransformerLayer(128, 4, False, 0) z = transformer(y) print(z.size()) if __name__ == "__main__": import rich.traceback import pyrallis @dataclass class MainConfig: net_type: str = "vit" obs_shape: list[int] = field(default_factory=lambda: [3, 96, 96]) vit: VitEncoderConfig = field(default_factory=lambda: VitEncoderConfig()) rich.traceback.install() cfg = pyrallis.parse(config_class=MainConfig) # type: ignore enc = VitEncoder(cfg.obs_shape, cfg.vit) print(enc) x = torch.rand(1, *cfg.obs_shape) * 255 print("output size:", enc(x, flatten=False).size()) print("repr dim:", enc.repr_dim, ", real dim:", enc(x, flatten=True).size())