diff --git a/README.md b/README.md
index 0a75b8f..c0c44a7 100644
--- a/README.md
+++ b/README.md
@@ -10,6 +10,8 @@ For more information, please see our [project webpage](https://younggyo.me/fast_
 
 ## ❗ Updates
 
+- **[Jun/15/2025]** Added support for FastTD3 + SimbaV2! It's faster to train and often achieves better asymptotic performance.
+
 - **[Jun/6/2025]** Thanks to [Antonin Raffin](https://araffin.github.io/) ([@araffin](https://github.com/araffin)), we fixed the issues when using `n_steps` > 1, which stabilizes training with n-step return quite a lot!
 
 - **[Jun/1/2025]** Updated the figures in the technical report to report deterministic evaluation for IsaacLab tasks.
@@ -99,21 +101,29 @@ Please see `fast_td3/hyperparams.py` for information regarding hyperparameters!
 
 ### HumanoidBench Experiments
 ```bash
 conda activate fasttd3_hb
+# FastTD3
 python fast_td3/train.py --env_name h1hand-hurdle-v0 --exp_name FastTD3 --render_interval 5000 --seed 1
+# FastTD3 + SimbaV2
+python fast_td3/train.py --env_name h1hand-hurdle-v0 --exp_name FastTD3 --render_interval 5000 --agent fasttd3_simbav2 --batch_size 8192 --critic_learning_rate_end 3e-5 --actor_learning_rate_end 3e-5 --weight_decay 0.0 --critic_hidden_dim 512 --critic_num_blocks 2 --actor_hidden_dim 256 --actor_num_blocks 1 --seed 1
 ```
 
 ### MuJoCo Playground Experiments
 ```bash
 conda activate fasttd3_playground
+# FastTD3
 python fast_td3/train.py --env_name T1JoystickFlatTerrain --exp_name FastTD3 --render_interval 5000 --seed 1
 python fast_td3/train.py --env_name G1JoystickFlatTerrain --exp_name FastTD3 --render_interval 5000 --seed 1
+# FastTD3 + SimbaV2
+python fast_td3/train.py --env_name T1JoystickFlatTerrain --exp_name FastTD3 --render_interval 5000 --agent fasttd3_simbav2 --batch_size 8192 --critic_learning_rate_end 3e-5 --actor_learning_rate_end 3e-5 --weight_decay 0.0 --critic_hidden_dim 512 --critic_num_blocks 2 --actor_hidden_dim 256 --actor_num_blocks 1 --seed 1
 ```
 
 ### IsaacLab Experiments
 ```bash
 conda activate fasttd3_isaaclab
+# FastTD3
 python fast_td3/train.py --env_name Isaac-Velocity-Flat-G1-v0 --exp_name FastTD3 --render_interval 0 --seed 1
-python fast_td3/train.py --env_name Isaac-Repose-Cube-Allegro-Direct-v0 --exp_name FastTD3 --render_interval 0 --seed 1
+# FastTD3 + SimbaV2
+python fast_td3/train.py --env_name Isaac-Repose-Cube-Allegro-Direct-v0 --exp_name FastTD3 --render_interval 0 --agent fasttd3_simbav2 --batch_size 8192 --critic_learning_rate_end 3e-5 --actor_learning_rate_end 3e-5 --weight_decay 0.0 --critic_hidden_dim 512 --critic_num_blocks 2 --actor_hidden_dim 256 --actor_num_blocks 1 --seed 1
 ```
 
 **Quick note:** For boolean-based arguments, you can set them to `False` by adding `no_` in front of each argument. For instance, if you want to disable Clipped Double Q-learning, you can specify `--no_use_cdq` in your command.
diff --git a/fast_td3/fast_td3_deploy.py b/fast_td3/fast_td3_deploy.py
index 15fa73a..c8a8f66 100644
--- a/fast_td3/fast_td3_deploy.py
+++ b/fast_td3/fast_td3_deploy.py
@@ -1,7 +1,10 @@
+import math
+
 import torch
 import torch.nn as nn
 
 from .fast_td3_utils import EmpiricalNormalization
 from .fast_td3 import Actor
+from .fast_td3_simbav2 import Actor as ActorSimbaV2
 
 
 class Policy(nn.Module):
@@ -9,12 +12,18 @@ class Policy(nn.Module):
         self,
         n_obs: int,
         n_act: int,
-        num_envs: int,
-        init_scale: float,
-        actor_hidden_dim: int,
+        args: dict,
+        agent: str = "fasttd3",
     ):
         super().__init__()
-        self.actor = Actor(
+
+        self.args = args
+
+        num_envs = args["num_envs"]
+        init_scale = args["init_scale"]
+        actor_hidden_dim = args["actor_hidden_dim"]
+
+        actor_kwargs = dict(
             n_obs=n_obs,
             n_act=n_act,
             num_envs=num_envs,
@@ -22,6 +31,31 @@ class Policy(nn.Module):
             init_scale=init_scale,
             hidden_dim=actor_hidden_dim,
         )
+
+        if agent == "fasttd3":
+            actor_cls = Actor
+        elif agent == "fasttd3_simbav2":
+            actor_cls = ActorSimbaV2
+
+            actor_num_blocks = args["actor_num_blocks"]
+            actor_kwargs.pop("init_scale")
+            actor_kwargs.update(
+                {
+                    "scaler_init": math.sqrt(2.0 / actor_hidden_dim),
+                    "scaler_scale": math.sqrt(2.0 / actor_hidden_dim),
+                    "alpha_init": 1.0 / (actor_num_blocks + 1),
+                    "alpha_scale": 1.0 / math.sqrt(actor_hidden_dim),
+                    "expansion": 4,
+                    "c_shift": 3.0,
+                    "num_blocks": actor_num_blocks,
+                }
+            )
+        else:
+            raise ValueError(f"Agent {agent} not supported")
+
+        self.actor = actor_cls(
+            **actor_kwargs,
+        )
         self.obs_normalizer = EmpiricalNormalization(shape=n_obs, device="cpu")
 
         self.actor.eval()
@@ -45,17 +79,25 @@ def load_policy(checkpoint_path):
     )
     args = torch_checkpoint["args"]
 
-    n_obs = torch_checkpoint["actor_state_dict"]["net.0.weight"].shape[-1]
-    n_act = torch_checkpoint["actor_state_dict"]["fc_mu.0.weight"].shape[0]
+    agent = args.get("agent", "fasttd3")
+    if agent == "fasttd3":
+        n_obs = torch_checkpoint["actor_state_dict"]["net.0.weight"].shape[-1]
+        n_act = torch_checkpoint["actor_state_dict"]["fc_mu.0.weight"].shape[0]
+    elif agent == "fasttd3_simbav2":
+        # TODO: Too hard-coded, maybe save n_obs and n_act in the checkpoint?
+        n_obs = (
+            torch_checkpoint["actor_state_dict"]["embedder.w.w.weight"].shape[-1] - 1
+        )
+        n_act = torch_checkpoint["actor_state_dict"]["predictor.mean_bias"].shape[0]
+    else:
+        raise ValueError(f"Agent {agent} not supported")
 
     policy = Policy(
         n_obs=n_obs,
         n_act=n_act,
-        num_envs=args["num_envs"],
-        init_scale=args["init_scale"],
-        actor_hidden_dim=args["actor_hidden_dim"],
+        args=args,
+        agent=agent,
    )
-
     policy.actor.load_state_dict(torch_checkpoint["actor_state_dict"])
 
     if len(torch_checkpoint["obs_normalizer_state"]) == 0:
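For reference, a minimal usage sketch of the updated loader (the checkpoint path and observation width below are hypothetical; `load_policy` reads the agent type from the checkpoint's saved `args` and builds the matching actor):

```python
import torch

from fast_td3.fast_td3_deploy import load_policy

policy = load_policy("models/h1hand-hurdle-v0_final.pt")  # hypothetical path
obs = torch.zeros(1, 151)  # dummy batch; the width must match the env's n_obs
with torch.no_grad():
    action = policy.actor(policy.obs_normalizer(obs))  # tanh action in [-1, 1]
```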
+ """ + + def __init__( + self, + dim: int, + init: float = 1.0, + scale: float = 1.0, + device: torch.device = None, + ): + super().__init__() + self.scaler = nn.Parameter(torch.full((dim,), init * scale, device=device)) + self.forward_scaler = init / scale + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.scaler * self.forward_scaler * x + + +class HyperDense(nn.Module): + """ + A dense layer without bias and with orthogonal initialization. + """ + + def __init__(self, in_dim: int, hidden_dim: int, device: torch.device = None): + super().__init__() + self.w = nn.Linear(in_dim, hidden_dim, bias=False, device=device) + nn.init.orthogonal_(self.w.weight, gain=1.0) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.w(x) + + +class HyperMLP(nn.Module): + """ + A small MLP with a specific architecture using HyperDense and Scaler. + """ + + def __init__( + self, + in_dim: int, + hidden_dim: int, + out_dim: int, + scaler_init: float, + scaler_scale: float, + eps: float = 1e-8, + device: torch.device = None, + ): + super().__init__() + self.w1 = HyperDense(in_dim, hidden_dim, device=device) + self.scaler = Scaler(hidden_dim, scaler_init, scaler_scale, device=device) + self.w2 = HyperDense(hidden_dim, out_dim, device=device) + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.w1(x) + x = self.scaler(x) + # `eps` is required to prevent zero vector. + x = F.relu(x) + self.eps + x = self.w2(x) + x = l2normalize(x, axis=-1) + return x + + +class HyperEmbedder(nn.Module): + """ + Embeds input by concatenating a constant, normalizing, and applying layers. + """ + + def __init__( + self, + in_dim: int, + hidden_dim: int, + scaler_init: float, + scaler_scale: float, + c_shift: float, + device: torch.device = None, + ): + super().__init__() + # The input dimension to the dense layer is in_dim + 1 + self.w = HyperDense(in_dim + 1, hidden_dim, device=device) + self.scaler = Scaler(hidden_dim, scaler_init, scaler_scale, device=device) + self.c_shift = c_shift + + def forward(self, x: torch.Tensor) -> torch.Tensor: + new_axis = torch.full((*x.shape[:-1], 1), self.c_shift, device=x.device) + x = torch.cat([x, new_axis], dim=-1) + x = l2normalize(x, axis=-1) + x = self.w(x) + x = self.scaler(x) + x = l2normalize(x, axis=-1) + return x + + +class HyperLERPBlock(nn.Module): + """ + A residual block using Linear Interpolation (LERP). + """ + + def __init__( + self, + hidden_dim: int, + scaler_init: float, + scaler_scale: float, + alpha_init: float, + alpha_scale: float, + expansion: int = 4, + device: torch.device = None, + ): + super().__init__() + self.mlp = HyperMLP( + in_dim=hidden_dim, + hidden_dim=hidden_dim * expansion, + out_dim=hidden_dim, + scaler_init=scaler_init / math.sqrt(expansion), + scaler_scale=scaler_scale / math.sqrt(expansion), + device=device, + ) + self.alpha_scaler = Scaler( + dim=hidden_dim, + init=alpha_init, + scale=alpha_scale, + device=device, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + residual = x + mlp_out = self.mlp(x) + # The original paper uses (x - residual) but x is the residual here. + # This is interpreted as alpha * (mlp_output - residual_input) + x = residual + self.alpha_scaler(mlp_out - residual) + x = l2normalize(x, axis=-1) + return x + + +class HyperTanhPolicy(nn.Module): + """ + A policy that outputs a Tanh action. 
+ """ + + def __init__( + self, + hidden_dim: int, + action_dim: int, + scaler_init: float, + scaler_scale: float, + device: torch.device = None, + ): + super().__init__() + self.mean_w1 = HyperDense(hidden_dim, hidden_dim, device=device) + self.mean_scaler = Scaler(hidden_dim, scaler_init, scaler_scale, device=device) + self.mean_w2 = HyperDense(hidden_dim, action_dim, device=device) + self.mean_bias = nn.Parameter(torch.zeros(action_dim, device=device)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # Mean path + mean = self.mean_w1(x) + mean = self.mean_scaler(mean) + mean = self.mean_w2(mean) + self.mean_bias + mean = torch.tanh(mean) + return mean + + +class HyperCategoricalValue(nn.Module): + """ + A value function that predicts a categorical distribution over a range of values. + """ + + def __init__( + self, + hidden_dim: int, + num_bins: int, + scaler_init: float, + scaler_scale: float, + device: torch.device = None, + ): + super().__init__() + self.w1 = HyperDense(hidden_dim, hidden_dim, device=device) + self.scaler = Scaler(hidden_dim, scaler_init, scaler_scale, device=device) + self.w2 = HyperDense(hidden_dim, num_bins, device=device) + self.bias = nn.Parameter(torch.zeros(num_bins, device=device)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + logits = self.w1(x) + logits = self.scaler(logits) + logits = self.w2(logits) + self.bias + return logits + + +class DistributionalQNetwork(nn.Module): + def __init__( + self, + n_obs: int, + n_act: int, + num_atoms: int, + v_min: float, + v_max: float, + hidden_dim: int, + scaler_init: float, + scaler_scale: float, + alpha_init: float, + alpha_scale: float, + num_blocks: int, + c_shift: float, + expansion: int, + device: torch.device = None, + ): + super().__init__() + + self.embedder = HyperEmbedder( + in_dim=n_obs + n_act, + hidden_dim=hidden_dim, + scaler_init=scaler_init, + scaler_scale=scaler_scale, + c_shift=c_shift, + device=device, + ) + + self.encoder = nn.Sequential( + *[ + HyperLERPBlock( + hidden_dim=hidden_dim, + scaler_init=scaler_init, + scaler_scale=scaler_scale, + alpha_init=alpha_init, + alpha_scale=alpha_scale, + expansion=expansion, + device=device, + ) + for _ in range(num_blocks) + ] + ) + + self.predictor = HyperCategoricalValue( + hidden_dim=hidden_dim, + num_bins=num_atoms, + scaler_init=1.0, + scaler_scale=1.0, + device=device, + ) + self.v_min = v_min + self.v_max = v_max + self.num_atoms = num_atoms + + def forward(self, obs: torch.Tensor, actions: torch.Tensor) -> torch.Tensor: + x = torch.cat([obs, actions], 1) + x = self.embedder(x) + x = self.encoder(x) + x = self.predictor(x) + return x + + def projection( + self, + obs: torch.Tensor, + actions: torch.Tensor, + rewards: torch.Tensor, + bootstrap: torch.Tensor, + discount: float, + q_support: torch.Tensor, + device: torch.device, + ) -> torch.Tensor: + delta_z = (self.v_max - self.v_min) / (self.num_atoms - 1) + batch_size = rewards.shape[0] + + target_z = ( + rewards.unsqueeze(1) + + bootstrap.unsqueeze(1) * discount.unsqueeze(1) * q_support + ) + target_z = target_z.clamp(self.v_min, self.v_max) + b = (target_z - self.v_min) / delta_z + l = torch.floor(b).long() + u = torch.ceil(b).long() + + l_mask = torch.logical_and((u > 0), (l == u)) + u_mask = torch.logical_and((l < (self.num_atoms - 1)), (l == u)) + + l = torch.where(l_mask, l - 1, l) + u = torch.where(u_mask, u + 1, u) + + next_dist = F.softmax(self.forward(obs, actions), dim=1) + proj_dist = torch.zeros_like(next_dist) + offset = ( + torch.linspace( + 0, (batch_size - 1) * 
+
+class Critic(nn.Module):
+    def __init__(
+        self,
+        n_obs: int,
+        n_act: int,
+        num_atoms: int,
+        v_min: float,
+        v_max: float,
+        hidden_dim: int,
+        scaler_init: float,
+        scaler_scale: float,
+        alpha_init: float,
+        alpha_scale: float,
+        num_blocks: int,
+        c_shift: float,
+        expansion: int,
+        device: torch.device = None,
+    ):
+        super().__init__()
+        self.qnet1 = DistributionalQNetwork(
+            n_obs=n_obs,
+            n_act=n_act,
+            num_atoms=num_atoms,
+            v_min=v_min,
+            v_max=v_max,
+            scaler_init=scaler_init,
+            scaler_scale=scaler_scale,
+            alpha_init=alpha_init,
+            alpha_scale=alpha_scale,
+            num_blocks=num_blocks,
+            c_shift=c_shift,
+            expansion=expansion,
+            hidden_dim=hidden_dim,
+            device=device,
+        )
+        self.qnet2 = DistributionalQNetwork(
+            n_obs=n_obs,
+            n_act=n_act,
+            num_atoms=num_atoms,
+            v_min=v_min,
+            v_max=v_max,
+            scaler_init=scaler_init,
+            scaler_scale=scaler_scale,
+            alpha_init=alpha_init,
+            alpha_scale=alpha_scale,
+            num_blocks=num_blocks,
+            c_shift=c_shift,
+            expansion=expansion,
+            hidden_dim=hidden_dim,
+            device=device,
+        )
+
+        self.register_buffer(
+            "q_support", torch.linspace(v_min, v_max, num_atoms, device=device)
+        )
+
+    def forward(self, obs: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
+        return self.qnet1(obs, actions), self.qnet2(obs, actions)
+
+    def projection(
+        self,
+        obs: torch.Tensor,
+        actions: torch.Tensor,
+        rewards: torch.Tensor,
+        bootstrap: torch.Tensor,
+        discount: torch.Tensor,
+    ) -> torch.Tensor:
+        """Projection operation that includes q_support directly."""
+        q1_proj = self.qnet1.projection(
+            obs,
+            actions,
+            rewards,
+            bootstrap,
+            discount,
+            self.q_support,
+            self.q_support.device,
+        )
+        q2_proj = self.qnet2.projection(
+            obs,
+            actions,
+            rewards,
+            bootstrap,
+            discount,
+            self.q_support,
+            self.q_support.device,
+        )
+        return q1_proj, q2_proj
+
+    def get_value(self, probs: torch.Tensor) -> torch.Tensor:
+        """Calculate the scalar value as the expectation of `probs` over the support."""
+        return torch.sum(probs * self.q_support, dim=1)
+
+
+class Actor(nn.Module):
+    def __init__(
+        self,
+        n_obs: int,
+        n_act: int,
+        num_envs: int,
+        hidden_dim: int,
+        scaler_init: float,
+        scaler_scale: float,
+        alpha_init: float,
+        alpha_scale: float,
+        expansion: int,
+        c_shift: float,
+        num_blocks: int,
+        std_min: float = 0.05,
+        std_max: float = 0.8,
+        device: torch.device = None,
+    ):
+        super().__init__()
+        self.n_act = n_act
+
+        self.embedder = HyperEmbedder(
+            in_dim=n_obs,
+            hidden_dim=hidden_dim,
+            scaler_init=scaler_init,
+            scaler_scale=scaler_scale,
+            c_shift=c_shift,
+            device=device,
+        )
+        self.encoder = nn.Sequential(
+            *[
+                HyperLERPBlock(
+                    hidden_dim=hidden_dim,
+                    scaler_init=scaler_init,
+                    scaler_scale=scaler_scale,
+                    alpha_init=alpha_init,
+                    alpha_scale=alpha_scale,
+                    expansion=expansion,
+                    device=device,
+                )
+                for _ in range(num_blocks)
+            ]
+        )
+        self.predictor = HyperTanhPolicy(
+            hidden_dim=hidden_dim,
+            action_dim=n_act,
+            scaler_init=scaler_init,
+            scaler_scale=scaler_scale,
+            device=device,
+        )
+
+        # Per-environment exploration noise scales, sampled uniformly from [std_min, std_max].
+        noise_scales = (
+            torch.rand(num_envs, 1, device=device) * (std_max - std_min) + std_min
+        )
+        self.register_buffer("noise_scales", noise_scales)
+
+        self.register_buffer("std_min", torch.as_tensor(std_min, device=device))
+        self.register_buffer("std_max", torch.as_tensor(std_max, device=device))
+        self.n_envs = num_envs
+
+    def forward(self, obs: torch.Tensor) -> torch.Tensor:
+        x = obs
+        x = self.embedder(x)
+        x = self.encoder(x)
+        x = self.predictor(x)
+        return x
+
+    def explore(
+        self, obs: torch.Tensor, dones: torch.Tensor = None, deterministic: bool = False
+    ) -> torch.Tensor:
+        # If dones is provided, resample noise for environments that are done
+        if dones is not None and dones.sum() > 0:
+            # Generate new noise scales for done environments (one per environment)
+            new_scales = (
+                torch.rand(self.n_envs, 1, device=obs.device)
+                * (self.std_max - self.std_min)
+                + self.std_min
+            )
+
+            # Update only the noise scales for environments that are done
+            dones_view = dones.view(-1, 1) > 0
+            self.noise_scales = torch.where(dones_view, new_scales, self.noise_scales)
+
+        act = self(obs)
+        if deterministic:
+            return act
+
+        noise = torch.randn_like(act) * self.noise_scales
+        return act + noise
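The `projection` method above is the standard categorical (C51-style) projection: each target atom `r + γ·z` is clamped to `[v_min, v_max]`, and its probability mass is split between the two nearest atoms of the fixed support. A small self-contained sketch of the same rule for a single transition, using a sequential variant of the `l_mask` / `u_mask` shift and illustrative values:

```python
import torch

v_min, v_max, num_atoms = -2.0, 2.0, 5
q_support = torch.linspace(v_min, v_max, num_atoms)  # [-2, -1, 0, 1, 2]
delta_z = (v_max - v_min) / (num_atoms - 1)

next_dist = torch.tensor([0.1, 0.2, 0.4, 0.2, 0.1])  # toy next-state distribution
reward, discount = 0.5, 0.9

target_z = (reward + discount * q_support).clamp(v_min, v_max)
b = (target_z - v_min) / delta_z                     # fractional atom indices
l, u = b.floor().long(), b.ceil().long()

# Shift indices on exact hits (l == u) so no probability mass is dropped.
l = torch.where((u > 0) & (l == u), l - 1, l)
u = torch.where((l < num_atoms - 1) & (l == u), u + 1, u)

proj = torch.zeros(num_atoms)
proj.index_add_(0, l, next_dist * (u.float() - b))   # mass to the lower atom
proj.index_add_(0, u, next_dist * (b - l.float()))   # mass to the upper atom
print(proj.sum())  # ~1.0: the projection preserves total probability mass
```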
dimension of the critic network""" actor_hidden_dim: int = 512 """the hidden dimension of the actor network""" + critic_num_blocks: int = 2 + """(SimbaV2 only) the number of blocks in the critic network""" + actor_num_blocks: int = 1 + """(SimbaV2 only) the number of blocks in the actor network""" use_cdq: bool = True """whether to use Clipped Double Q-learning""" measure_burnin: int = 3 @@ -84,6 +94,8 @@ class BaseArgs: """whether to use torch.compile.""" obs_normalization: bool = True """whether to enable observation normalization""" + reward_normalization: bool = False + """whether to enable reward normalization (Not recommended for now, it's unstable.)""" max_grad_norm: float = 0.0 """the maximum gradient norm""" amp: bool = True @@ -350,6 +362,7 @@ class Go1GetupArgs(MuJoCoPlaygroundArgs): class LeapCubeReorientArgs(MuJoCoPlaygroundArgs): env_name: str = "LeapCubeReorient" num_steps: int = 3 + gamma: float = 0.99 policy_noise: float = 0.2 v_min: float = -50.0 v_max: float = 50.0 @@ -361,6 +374,7 @@ class LeapCubeRotateZAxisArgs(MuJoCoPlaygroundArgs): env_name: str = "LeapCubeRotateZAxis" num_steps: int = 1 policy_noise: float = 0.2 + gamma: float = 0.99 v_min: float = -10.0 v_max: float = 10.0 use_cdq: bool = False diff --git a/fast_td3/train.py b/fast_td3/train.py index 672ab79..baf6212 100644 --- a/fast_td3/train.py +++ b/fast_td3/train.py @@ -12,6 +12,7 @@ os.environ["JAX_DEFAULT_MATMUL_PRECISION"] = "highest" import random import time +import math import tqdm import wandb @@ -25,9 +26,13 @@ from torch.amp import autocast, GradScaler from tensordict import TensorDict, from_module -from fast_td3_utils import EmpiricalNormalization, SimpleReplayBuffer, save_params +from fast_td3_utils import ( + EmpiricalNormalization, + RewardNormalizer, + SimpleReplayBuffer, + save_params, +) from hyperparams import get_args -from fast_td3 import Actor, Critic torch.set_float32_matmul_precision("high") @@ -135,44 +140,74 @@ def main(): obs_normalizer = nn.Identity() critic_obs_normalizer = nn.Identity() - actor = Actor( - n_obs=n_obs, - n_act=n_act, - num_envs=args.num_envs, - device=device, - init_scale=args.init_scale, - hidden_dim=args.actor_hidden_dim, - ) - actor_detach = Actor( - n_obs=n_obs, - n_act=n_act, - num_envs=args.num_envs, - device=device, - init_scale=args.init_scale, - hidden_dim=args.actor_hidden_dim, - ) + if args.reward_normalization: + reward_normalizer = RewardNormalizer( + gamma=args.gamma, device=device, g_max=min(abs(args.v_min), abs(args.v_max)) + ) + else: + reward_normalizer = nn.Identity() + + actor_kwargs = { + "n_obs": n_obs, + "n_act": n_act, + "num_envs": args.num_envs, + "device": device, + "init_scale": args.init_scale, + "hidden_dim": args.actor_hidden_dim, + } + critic_kwargs = { + "n_obs": n_critic_obs, + "n_act": n_act, + "num_atoms": args.num_atoms, + "v_min": args.v_min, + "v_max": args.v_max, + "hidden_dim": args.critic_hidden_dim, + "device": device, + } + + if args.agent == "fasttd3": + from fast_td3 import Actor, Critic + + print("Using FastTD3") + elif args.agent == "fasttd3_simbav2": + from fast_td3_simbav2 import Actor, Critic + + print("Using FastTD3 + SimbaV2") + actor_kwargs.pop("init_scale") + actor_kwargs.update( + { + "scaler_init": math.sqrt(2.0 / args.actor_hidden_dim), + "scaler_scale": math.sqrt(2.0 / args.actor_hidden_dim), + "alpha_init": 1.0 / (args.actor_num_blocks + 1), + "alpha_scale": 1.0 / math.sqrt(args.actor_hidden_dim), + "expansion": 4, + "c_shift": 3.0, + "num_blocks": args.actor_num_blocks, + } + ) + critic_kwargs.update( + { 
+ "scaler_init": math.sqrt(2.0 / args.critic_hidden_dim), + "scaler_scale": math.sqrt(2.0 / args.critic_hidden_dim), + "alpha_init": 1.0 / (args.critic_num_blocks + 1), + "alpha_scale": 1.0 / math.sqrt(args.critic_hidden_dim), + "num_blocks": args.critic_num_blocks, + "expansion": 4, + "c_shift": 3.0, + } + ) + else: + raise ValueError(f"Agent {args.agent} not supported") + + actor = Actor(**actor_kwargs) + actor_detach = Actor(**actor_kwargs) + # Copy params to actor_detach without grad from_module(actor).data.to_module(actor_detach) policy = actor_detach.explore - qnet = Critic( - n_obs=n_critic_obs, - n_act=n_act, - num_atoms=args.num_atoms, - v_min=args.v_min, - v_max=args.v_max, - hidden_dim=args.critic_hidden_dim, - device=device, - ) - qnet_target = Critic( - n_obs=n_critic_obs, - n_act=n_act, - num_atoms=args.num_atoms, - v_min=args.v_min, - v_max=args.v_max, - hidden_dim=args.critic_hidden_dim, - device=device, - ) + qnet = Critic(**critic_kwargs) + qnet_target = Critic(**critic_kwargs) qnet_target.load_state_dict(qnet.state_dict()) q_optimizer = optim.AdamW( @@ -186,6 +221,18 @@ def main(): weight_decay=args.weight_decay, ) + # Add learning rate schedulers + q_scheduler = optim.lr_scheduler.CosineAnnealingLR( + q_optimizer, + T_max=args.total_timesteps, + eta_min=args.critic_learning_rate_end, # Decay to 10% of initial lr + ) + actor_scheduler = optim.lr_scheduler.CosineAnnealingLR( + actor_optimizer, + T_max=args.total_timesteps, + eta_min=args.actor_learning_rate_end, # Decay to 10% of initial lr + ) + rb = SimpleReplayBuffer( n_env=args.num_envs, buffer_size=args.buffer_size, @@ -353,7 +400,6 @@ def main(): scaler.step(q_optimizer) scaler.update() - logs_dict["buffer_rewards"] = rewards.mean() logs_dict["critic_grad_norm"] = critic_grad_norm.detach() logs_dict["qf_loss"] = qf_loss.detach() logs_dict["qf_max"] = qf1_next_target_value.max().detach() @@ -399,9 +445,15 @@ def main(): policy = torch.compile(policy, mode=mode) normalize_obs = torch.compile(obs_normalizer.forward, mode=mode) normalize_critic_obs = torch.compile(critic_obs_normalizer.forward, mode=mode) + if args.reward_normalization: + update_stats = torch.compile(reward_normalizer.update_stats, mode=mode) + normalize_reward = torch.compile(reward_normalizer.forward, mode=mode) else: normalize_obs = obs_normalizer.forward normalize_critic_obs = critic_obs_normalizer.forward + if args.reward_normalization: + update_stats = reward_normalizer.update_stats + normalize_reward = reward_normalizer.forward if envs.asymmetric_obs: obs, critic_obs = envs.reset_with_critic_obs() @@ -447,6 +499,9 @@ def main(): next_obs, rewards, dones, infos = envs.step(actions.float()) truncations = infos["time_outs"] + if args.reward_normalization: + update_stats(rewards, dones.float()) + if envs.asymmetric_obs: next_critic_obs = infos["observations"]["critic"] @@ -494,6 +549,8 @@ def main(): data["next"]["observations"] = normalize_obs( data["next"]["observations"] ) + raw_rewards = data["next"]["rewards"] + data["next"]["rewards"] = normalize_reward(raw_rewards) if envs.asymmetric_obs: data["critic_observations"] = normalize_critic_obs( data["critic_observations"] @@ -527,8 +584,8 @@ def main(): "qf_min": logs_dict["qf_min"].mean(), "actor_grad_norm": logs_dict["actor_grad_norm"].mean(), "critic_grad_norm": logs_dict["critic_grad_norm"].mean(), - "buffer_rewards": logs_dict["buffer_rewards"].mean(), "env_rewards": rewards.mean(), + "buffer_rewards": raw_rewards.mean(), } if args.eval_interval > 0 and global_step % args.eval_interval == 0: 
@@ -563,6 +620,8 @@
                 {
                     "speed": speed,
                     "frame": global_step * args.num_envs,
+                    "critic_lr": q_scheduler.get_last_lr()[0],
+                    "actor_lr": actor_scheduler.get_last_lr()[0],
                     **logs,
                 },
                 step=global_step,
@@ -585,6 +644,10 @@
                 f"models/{run_name}_{global_step}.pt",
             )
 
+        # Update learning rates
+        q_scheduler.step()
+        actor_scheduler.step()
+
         global_step += 1
         pbar.update(1)
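Finally, a standalone sketch of what the new scheduler wiring does: with `CosineAnnealingLR`, the learning rate anneals from the initial value down to `*_learning_rate_end` (passed as `eta_min`) over `total_timesteps` calls to `scheduler.step()`. The values below are illustrative:

```python
import torch

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.AdamW([param], lr=3e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=1_000, eta_min=3e-5  # T_max plays the role of total_timesteps
)

for _ in range(1_000):
    optimizer.step()   # one training update
    scheduler.step()   # cosine-decay the learning rate
print(scheduler.get_last_lr()[0])  # ~3e-5 at the end of training
```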