From cef44108d8eccd3daddb1d35ba1350f6dc65558a Mon Sep 17 00:00:00 2001 From: Younggyo Seo Date: Fri, 20 Jun 2025 21:52:43 -0700 Subject: [PATCH] Support MTBench (#15) This PR incorporates MTBench into the current codebase, as a good demonstration that shows how to use FastTD3 for multi-task setup. - Add support for MTBench along with its wrapper - Add support for per-task reward normalizer useful for multi-task RL, motivated by BRC paper (https://arxiv.org/abs/2505.23150v1) --- README.md | 67 +++++++- fast_td3/environments/mtbench_env.py | 148 ++++++++++++++++++ fast_td3/fast_td3.py | 50 ++++++ fast_td3/fast_td3_simbav2.py | 54 ++++++- fast_td3/fast_td3_utils.py | 206 +++++++++++++++++++++++++ fast_td3/hyperparams.py | 42 ++++- fast_td3/train.py | 131 ++++++++++++---- requirements/requirements_isaacgym.txt | 15 ++ 8 files changed, 679 insertions(+), 34 deletions(-) create mode 100644 fast_td3/environments/mtbench_env.py create mode 100644 requirements/requirements_isaacgym.txt diff --git a/README.md b/README.md index 2584f1e..78cd5d2 100644 --- a/README.md +++ b/README.md @@ -10,11 +10,13 @@ For more information, please see our [project webpage](https://younggyo.me/fast_ ## ❗ Updates -- **[June/15/2025]** Added support for FastTD3 + SimbaV2! It's faster to train, and often achieves better asymptotic performance. +- **[Jun/20/2025]** Added support for [MTBench](https://github.com/Viraj-Joshi/MTBench) -- **[Jun/6/2025]** Thanks to [Antonin Raffin](https://araffin.github.io/) ([@araffin](https://github.com/araffin)), we fixed the issues when using `n_steps` > 1, which stabilizes training with n-step return quite a lot! +- **[Jun/15/2025]** Added support for FastTD3 + [SimbaV2](https://dojeon-ai.github.io/SimbaV2/)! It's faster to train, and often achieves better asymptotic performance. We recommend using FastTD3 + SimbaV2 for most cases. -- **[Jun/1/2025]** Updated the figures in the technical report to report deterministic evaluation for IsaacLab tasks. 
+- **[Jun/06/2025]** Thanks to [Antonin Raffin](https://araffin.github.io/) ([@araffin](https://github.com/araffin)), we fixed the issues when using `n_steps` > 1, which stabilizes training with n-step return quite a lot! + +- **[Jun/01/2025]** Updated the figures in the technical report to report deterministic evaluation for IsaacLab tasks. ## ✨ Features @@ -81,6 +83,27 @@ cd .. pip install -r requirements/requirements.txt ``` +### Environment for MTBench +MTBench does not support humanoid experiments, but is a useful multi-task benchmark with massively parallel simulation. This could be useful for users who want to use FastTD3 for their multi-task experiments. + +```bash +conda create -n fasttd3_mtbench -y python=3.8 # Note python version +conda activate fasttd3_mtbench + +# Install IsaacGym -- we recommend following the instructions in https://github.com/BoosterRobotics/booster_gym +... + +# Install MTBench +git clone https://github.com/Viraj-Joshi/MTBench.git +cd MTBench +pip install -e . +pip install skrl +cd .. + +# Install project-specific requirements +pip install -r requirements/requirements_isaacgym.txt +``` + ### (Optional) Accelerate headless GPU rendering in cloud instances In some cloud VM images the NVIDIA kernel driver is present but the user-space OpenGL/EGL/Vulkan libraries aren't, so MuJoCo falls back to CPU renderer. 
You can install just the NVIDIA user-space libraries (and skip rebuilding the kernel module) with: @@ -176,6 +199,32 @@ python fast_td3/train.py \ --seed 1 ``` +### MTBench Experiments +```bash +conda activate fasttd3_mtbench +# FastTD3 +python fast_td3/train.py \ + --env_name MTBench-meta-world-v2-mt10 \ + --exp_name FastTD3 \ + --render_interval 0 \ + --seed 1 +# FastTD3 + SimbaV2 +python fast_td3/train.py \ + --env_name MTBench-meta-world-v2-mt10 \ + --exp_name FastTD3 \ + --render_interval 0 \ + --agent fasttd3_simbav2 \ + --batch_size 8192 \ + --critic_learning_rate_end 3e-5 \ + --actor_learning_rate_end 3e-5 \ + --weight_decay 0.0 \ + --critic_hidden_dim 512 \ + --critic_num_blocks 2 \ + --actor_hidden_dim 256 \ + --actor_num_blocks 1 \ + --seed 1 +``` + **Quick note:** For boolean-based arguments, you can set them to False by adding `no_` in front each argument, for instance, if you want to disable Clipped Q Learning, you can specify `--no_use_cdq` in your command. ## 💡 Performance-Related Tips @@ -315,6 +364,18 @@ Following the [LeanRL](https://github.com/pytorch-labs/LeanRL)'s recommendation, } ``` +### MTBench +```bibtex +@inproceedings{ +joshi2025benchmarking, +title={Benchmarking Massively Parallelized Multi-Task Reinforcement Learning for Robotics Tasks}, +author={Viraj Joshi and Zifan Xu and Bo Liu and Peter Stone and Amy Zhang}, +booktitle={Reinforcement Learning Conference}, +year={2025}, +url={https://openreview.net/forum?id=z0MM0y20I2} +} +``` + ### Getting SAC to Work on a Massive Parallel Simulator ```bibtex @article{raffin2025isaacsim, diff --git a/fast_td3/environments/mtbench_env.py b/fast_td3/environments/mtbench_env.py new file mode 100644 index 0000000..5e04310 --- /dev/null +++ b/fast_td3/environments/mtbench_env.py @@ -0,0 +1,148 @@ +from __future__ import annotations + +import torch +from omegaconf import OmegaConf + +import isaacgym +import isaacgymenvs + + +class MTBenchEnv: + def __init__( + self, + task_name: str, + device_id: int, 
+ num_envs: int, + seed: int, + ): + # NOTE: Currently, we only support Meta-World-v2 MT-10/MT-50 in MTBench + task_config = MTBENCH_MW2_CONFIG.copy() + if task_name == "meta-world-v2-mt10": + # MT-10 Setup + assert num_envs == 4096, "MT-10 only supports 4096 environments (for now)" + self.num_tasks = 10 + task_config["env"]["tasks"] = [4, 16, 17, 18, 28, 31, 38, 40, 48, 49] + task_config["env"]["taskEnvCount"] = [410] * 6 + [409] * 4 + elif task_name == "meta-world-v2-mt50": + # MT-50 Setup + self.num_tasks = 50 + assert num_envs == 8192, "MT-50 only supports 8192 environments (for now)" + task_config["env"]["tasks"] = list(range(50)) + task_config["env"]["taskEnvCount"] = [164] * 42 + [163] * 8 # 6888 + 1304 + else: + raise ValueError(f"Unsupported task name: {task_name}") + task_config["env"]["numEnvs"] = num_envs + task_config["env"]["numObservations"] = 39 + self.num_tasks + task_config["env"]["seed"] = seed + + # Convert dictionary to OmegaConf object + env_cfg = {"task": task_config} + env_cfg = OmegaConf.create(env_cfg) + + self.env = isaacgymenvs.make( + task=env_cfg.task.name, + num_envs=num_envs, + sim_device=f"cuda:{device_id}", + rl_device=f"cuda:{device_id}", + seed=seed, + headless=True, + cfg=env_cfg, + ) + + self.num_envs = num_envs + self.asymmetric_obs = False + self.num_obs = self.env.observation_space.shape[0] + assert ( + self.num_obs == 39 + self.num_tasks + ), "MTBench observation space is 39 + num_tasks (one-hot vector)" + self.num_privileged_obs = 0 + self.num_actions = self.env.action_space.shape[0] + self.max_episode_steps = self.env.max_episode_length + + def reset(self) -> torch.Tensor: + """Reset the environment.""" + # TODO: Check if we need no_grad and detach here + with torch.no_grad(): # do we need this? 
+ self.env.reset_idx(torch.arange(self.num_envs, device=self.env.device)) + obs_dict = self.env.reset() + return obs_dict["obs"].detach() + + def step( + self, actions: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, dict]: + """Step the environment.""" + assert isinstance(actions, torch.Tensor) + + # TODO: Check if we need no_grad and detach here + with torch.no_grad(): + obs_dict, rew, dones, infos = self.env.step(actions.detach()) + truncations = infos["time_outs"] + info_ret = {"time_outs": truncations.detach()} + if "episode" in infos: + info_ret["episode"] = infos["episode"] + # NOTE: There's really no way to get the raw observations from IsaacGym + # We just use the 'reset_obs' as next_obs, unfortunately. + info_ret["observations"] = {"raw": {"obs": obs_dict["obs"].detach()}} + return obs_dict["obs"].detach(), rew.detach(), dones.detach(), info_ret + + def render(self): + raise NotImplementedError( + "We don't support rendering for IsaacLab environments" + ) + + +MTBENCH_MW2_CONFIG = { + "name": "meta-world-v2", + "physics_engine": "physx", + "env": { + "numEnvs": 1, + "envSpacing": 1.5, + "episodeLength": 150, + "enableDebugVis": False, + "clipObservations": 5.0, + "clipActions": 1.0, + "aggregateMode": 3, + "actionScale": 0.01, + "resetNoise": 0.15, + "tasks": [0], + "taskEnvCount": [4096], + "init_at_random_progress": True, + "exemptedInitAtRandomProgressTasks": [], + "taskEmbedding": True, + "taskEmbeddingType": "one_hot", + "seed": 42, + "cameraRenderingInterval": 5000, + "cameraWidth": 1024, + "cameraHeight": 1024, + "sparse_reward": False, + "termination_on_success": False, + "reward_scale": 1.0, + "fixed": False, + "numObservations": None, + "numActions": 4, + }, + "enableCameraSensors": False, + "sim": { + "dt": 0.01667, + "substeps": 2, + "up_axis": "z", + "use_gpu_pipeline": True, + "gravity": [0.0, 0.0, -9.81], + "physx": { + "num_threads": 4, + "solver_type": 1, + "use_gpu": True, + "num_position_iterations": 8, + 
"num_velocity_iterations": 1, + "contact_offset": 0.005, + "rest_offset": 0.0, + "bounce_threshold_velocity": 0.2, + "max_depenetration_velocity": 1000.0, + "default_buffer_size_multiplier": 10.0, + "max_gpu_contact_pairs": 1048576, + "num_subscenes": 4, + "contact_collection": 0, + }, + }, + "task": {"randomize": False}, +} diff --git a/fast_td3/fast_td3.py b/fast_td3/fast_td3.py index 4805ccc..b7d4ff1 100644 --- a/fast_td3/fast_td3.py +++ b/fast_td3/fast_td3.py @@ -114,6 +114,7 @@ class Critic(nn.Module): self.register_buffer( "q_support", torch.linspace(v_min, v_max, num_atoms, device=device) ) + self.device = device def forward(self, obs: torch.Tensor, actions: torch.Tensor) -> torch.Tensor: return self.qnet1(obs, actions), self.qnet2(obs, actions) @@ -189,6 +190,7 @@ class Actor(nn.Module): self.register_buffer("std_min", torch.as_tensor(std_min, device=device)) self.register_buffer("std_max", torch.as_tensor(std_max, device=device)) self.n_envs = num_envs + self.device = device def forward(self, obs: torch.Tensor) -> torch.Tensor: x = obs @@ -218,3 +220,51 @@ class Actor(nn.Module): noise = torch.randn_like(act) * self.noise_scales return act + noise + + +class MultiTaskActor(Actor): + def __init__(self, num_tasks: int, task_embedding_dim: int, *args, **kwargs): + super().__init__(*args, **kwargs) + self.num_tasks = num_tasks + self.task_embedding_dim = task_embedding_dim + self.task_embedding = nn.Embedding( + num_tasks, task_embedding_dim, max_norm=1.0, device=self.device + ) + + def forward(self, obs: torch.Tensor) -> torch.Tensor: + task_ids_one_hot = obs[..., -self.num_tasks :] + task_indices = torch.argmax(task_ids_one_hot, dim=1) + task_embeddings = self.task_embedding(task_indices) + obs = torch.cat([obs[..., : -self.num_tasks], task_embeddings], dim=-1) + return super().forward(obs) + + +class MultiTaskCritic(Critic): + def __init__(self, num_tasks: int, task_embedding_dim: int, *args, **kwargs): + super().__init__(*args, **kwargs) + self.num_tasks = 
num_tasks + self.task_embedding_dim = task_embedding_dim + self.task_embedding = nn.Embedding( + num_tasks, task_embedding_dim, max_norm=1.0, device=self.device + ) + + def forward(self, obs: torch.Tensor, actions: torch.Tensor) -> torch.Tensor: + task_ids_one_hot = obs[..., -self.num_tasks :] + task_indices = torch.argmax(task_ids_one_hot, dim=1) + task_embeddings = self.task_embedding(task_indices) + obs = torch.cat([obs[..., : -self.num_tasks], task_embeddings], dim=-1) + return super().forward(obs, actions) + + def projection( + self, + obs: torch.Tensor, + actions: torch.Tensor, + rewards: torch.Tensor, + bootstrap: torch.Tensor, + discount: float, + ) -> torch.Tensor: + task_ids_one_hot = obs[..., -self.num_tasks :] + task_indices = torch.argmax(task_ids_one_hot, dim=1) + task_embeddings = self.task_embedding(task_indices) + obs = torch.cat([obs[..., : -self.num_tasks], task_embeddings], dim=-1) + return super().projection(obs, actions, rewards, bootstrap, discount) diff --git a/fast_td3/fast_td3_simbav2.py b/fast_td3/fast_td3_simbav2.py index 3a36a27..f0b17eb 100644 --- a/fast_td3/fast_td3_simbav2.py +++ b/fast_td3/fast_td3_simbav2.py @@ -365,6 +365,7 @@ class Critic(nn.Module): self.register_buffer( "q_support", torch.linspace(v_min, v_max, num_atoms, device=device) ) + self.device = device def forward(self, obs: torch.Tensor, actions: torch.Tensor) -> torch.Tensor: return self.qnet1(obs, actions), self.qnet2(obs, actions) @@ -449,8 +450,8 @@ class Actor(nn.Module): self.predictor = HyperTanhPolicy( hidden_dim=hidden_dim, action_dim=n_act, - scaler_init=scaler_init, - scaler_scale=scaler_scale, + scaler_init=1.0, + scaler_scale=1.0, device=device, ) @@ -462,6 +463,7 @@ class Actor(nn.Module): self.register_buffer("std_min", torch.as_tensor(std_min, device=device)) self.register_buffer("std_max", torch.as_tensor(std_max, device=device)) self.n_envs = num_envs + self.device = device def forward(self, obs: torch.Tensor) -> torch.Tensor: x = obs @@ -492,3 
+494,51 @@ class Actor(nn.Module): noise = torch.randn_like(act) * self.noise_scales return act + noise + + +class MultiTaskActor(Actor): + def __init__(self, num_tasks: int, task_embedding_dim: int, *args, **kwargs): + super().__init__(*args, **kwargs) + self.num_tasks = num_tasks + self.task_embedding_dim = task_embedding_dim + self.task_embedding = nn.Embedding( + num_tasks, task_embedding_dim, max_norm=1.0, device=self.device + ) + + def forward(self, obs: torch.Tensor) -> torch.Tensor: + task_ids_one_hot = obs[..., -self.num_tasks :] + task_indices = torch.argmax(task_ids_one_hot, dim=1) + task_embeddings = self.task_embedding(task_indices) + obs = torch.cat([obs[..., : -self.num_tasks], task_embeddings], dim=-1) + return super().forward(obs) + + +class MultiTaskCritic(Critic): + def __init__(self, num_tasks: int, task_embedding_dim: int, *args, **kwargs): + super().__init__(*args, **kwargs) + self.num_tasks = num_tasks + self.task_embedding_dim = task_embedding_dim + self.task_embedding = nn.Embedding( + num_tasks, task_embedding_dim, max_norm=1.0, device=self.device + ) + + def forward(self, obs: torch.Tensor, actions: torch.Tensor) -> torch.Tensor: + task_ids_one_hot = obs[..., -self.num_tasks :] + task_indices = torch.argmax(task_ids_one_hot, dim=1) + task_embeddings = self.task_embedding(task_indices) + obs = torch.cat([obs[..., : -self.num_tasks], task_embeddings], dim=-1) + return super().forward(obs, actions) + + def projection( + self, + obs: torch.Tensor, + actions: torch.Tensor, + rewards: torch.Tensor, + bootstrap: torch.Tensor, + discount: float, + ) -> torch.Tensor: + task_ids_one_hot = obs[..., -self.num_tasks :] + task_indices = torch.argmax(task_ids_one_hot, dim=1) + task_embeddings = self.task_embedding(task_indices) + obs = torch.cat([obs[..., : -self.num_tasks], task_embeddings], dim=-1) + return super().projection(obs, actions, rewards, bootstrap, discount) diff --git a/fast_td3/fast_td3_utils.py b/fast_td3/fast_td3_utils.py index 
0f997fd..336ba1c 100644 --- a/fast_td3/fast_td3_utils.py +++ b/fast_td3/fast_td3_utils.py @@ -512,6 +512,212 @@ class RewardNormalizer(nn.Module): return self._scale_reward(rewards) +class PerTaskEmpiricalNormalization(nn.Module): + """Normalize mean and variance of values based on empirical values for each task.""" + + def __init__( + self, + num_tasks: int, + shape: tuple, + device: torch.device, + eps: float = 1e-2, + until: int = None, + ): + """ + Initialize PerTaskEmpiricalNormalization module. + + Args: + num_tasks (int): The total number of tasks. + shape (int or tuple of int): Shape of input values except batch axis. + eps (float): Small value for stability. + until (int or None): If specified, learns until the sum of batch sizes + for a specific task exceeds this value. + """ + super().__init__() + if not isinstance(shape, tuple): + shape = (shape,) + self.num_tasks = num_tasks + self.shape = shape + self.eps = eps + self.until = until + self.device = device + + # Buffers now have a leading dimension for tasks + self.register_buffer("_mean", torch.zeros(num_tasks, *shape).to(device)) + self.register_buffer("_var", torch.ones(num_tasks, *shape).to(device)) + self.register_buffer("_std", torch.ones(num_tasks, *shape).to(device)) + self.register_buffer( + "count", torch.zeros(num_tasks, dtype=torch.long).to(device) + ) + + def forward( + self, x: torch.Tensor, task_ids: torch.Tensor, center: bool = True + ) -> torch.Tensor: + """ + Normalize the input tensor `x` using statistics for the given `task_ids`. + + Args: + x (torch.Tensor): Input tensor of shape [num_envs, *shape]. + task_ids (torch.Tensor): Tensor of task indices, shape [num_envs]. + center (bool): If True, center the data by subtracting the mean. 
+ """ + if x.shape[1:] != self.shape: + raise ValueError(f"Expected input shape (*, {self.shape}), got {x.shape}") + if x.shape[0] != task_ids.shape[0]: + raise ValueError("Batch size of x and task_ids must match.") + + # Gather the stats for the tasks in the current batch + # Reshape task_ids for broadcasting: [num_envs] -> [num_envs, 1, ...] + view_shape = (task_ids.shape[0],) + (1,) * len(self.shape) + task_ids_expanded = task_ids.view(view_shape).expand_as(x) + + mean = self._mean.gather(0, task_ids_expanded) + std = self._std.gather(0, task_ids_expanded) + + if self.training: + self.update(x, task_ids) + + if center: + return (x - mean) / (std + self.eps) + else: + return x / (std + self.eps) + + @torch.jit.unused + def update(self, x: torch.Tensor, task_ids: torch.Tensor): + """Update running statistics for the tasks present in the batch.""" + unique_tasks = torch.unique(task_ids) + + for task_id in unique_tasks: + if self.until is not None and self.count[task_id] >= self.until: + continue + + # Create a mask to select data for the current task + mask = task_ids == task_id + x_task = x[mask] + batch_size = x_task.shape[0] + + if batch_size == 0: + continue + + # Update count for this task + old_count = self.count[task_id].clone() + new_count = old_count + batch_size + + # Update mean + task_mean = self._mean[task_id] + batch_mean = torch.mean(x_task, dim=0) + delta = batch_mean - task_mean + self._mean[task_id] = task_mean + (batch_size / new_count) * delta + + # Update variance using Chan's parallel algorithm + if old_count > 0: + batch_var = torch.var(x_task, dim=0, unbiased=False) + m_a = self._var[task_id] * old_count + m_b = batch_var * batch_size + M2 = m_a + m_b + (delta**2) * (old_count * batch_size / new_count) + self._var[task_id] = M2 / new_count + else: + # For the first batch of this task + self._var[task_id] = torch.var(x_task, dim=0, unbiased=False) + + self._std[task_id] = torch.sqrt(self._var[task_id]) + self.count[task_id] = new_count + + 
+class PerTaskRewardNormalizer(nn.Module): + def __init__( + self, + num_tasks: int, + gamma: float, + device: torch.device, + g_max: float = 10.0, + epsilon: float = 1e-8, + ): + """ + Per-task reward normalizer, motivation comes from BRC (https://arxiv.org/abs/2505.23150v1) + """ + super().__init__() + self.num_tasks = num_tasks + self.gamma = gamma + self.g_max = g_max + self.epsilon = epsilon + self.device = device + + # Per-task running estimate of the discounted return + self.register_buffer("G", torch.zeros(num_tasks, device=device)) + # Per-task running-max of the discounted return + self.register_buffer("G_r_max", torch.zeros(num_tasks, device=device)) + # Use the new per-task normalizer for the statistics of G + self.G_rms = PerTaskEmpiricalNormalization( + num_tasks=num_tasks, shape=(1,), device=device + ) + + def _scale_reward( + self, rewards: torch.Tensor, task_ids: torch.Tensor + ) -> torch.Tensor: + """ + Scales rewards using per-task statistics. + + Args: + rewards (torch.Tensor): Reward tensor, shape [num_envs]. + task_ids (torch.Tensor): Task indices, shape [num_envs]. + """ + # Gather stats for the tasks in the batch + std_for_batch = self.G_rms._std.gather(0, task_ids.unsqueeze(-1)).squeeze(-1) + g_r_max_for_batch = self.G_r_max.gather(0, task_ids) + + var_denominator = std_for_batch + self.epsilon + min_required_denominator = g_r_max_for_batch / self.g_max + denominator = torch.maximum(var_denominator, min_required_denominator) + + # Add a small epsilon to the final denominator to prevent division by zero + # in case g_r_max is also zero. + return rewards / (denominator + self.epsilon) + + def update_stats( + self, rewards: torch.Tensor, dones: torch.Tensor, task_ids: torch.Tensor + ): + """ + Updates the running discounted return and its statistics for each task. + + Args: + rewards (torch.Tensor): Reward tensor, shape [num_envs]. + dones (torch.Tensor): Done tensor, shape [num_envs]. + task_ids (torch.Tensor): Task indices, shape [num_envs]. 
+ """ + if not (rewards.shape == dones.shape == task_ids.shape): + raise ValueError("rewards, dones, and task_ids must have the same shape.") + + # === Update G (running discounted return) === + # Gather the previous G values for the tasks in the batch + prev_G = self.G.gather(0, task_ids) + # Update G for each environment based on its own reward and done signal + new_G = self.gamma * (1 - dones.float()) * prev_G + rewards + # Scatter the updated G values back to the main buffer + self.G.scatter_(0, task_ids, new_G) + + # === Update G_rms (statistics of G) === + # The update function handles the per-task logic internally + self.G_rms.update(new_G.unsqueeze(-1), task_ids) + + # === Update G_r_max (running max of |G|) === + prev_G_r_max = self.G_r_max.gather(0, task_ids) + # Update the max for each environment + updated_G_r_max = torch.maximum(prev_G_r_max, torch.abs(new_G)) + # Scatter the new maxes back to the main buffer + self.G_r_max.scatter_(0, task_ids, updated_G_r_max) + + def forward(self, rewards: torch.Tensor, task_ids: torch.Tensor) -> torch.Tensor: + """ + Normalizes rewards. During training, it also updates the running statistics. + + Args: + rewards (torch.Tensor): Reward tensor, shape [num_envs]. + task_ids (torch.Tensor): Task indices, shape [num_envs]. 
+ """ + return self._scale_reward(rewards, task_ids) + + def cpu_state(sd): # detach & move to host without locking the compute stream return {k: v.detach().to("cpu", non_blocking=True) for k, v in sd.items()} diff --git a/fast_td3/hyperparams.py b/fast_td3/hyperparams.py index f8fc824..c42bef1 100644 --- a/fast_td3/hyperparams.py +++ b/fast_td3/hyperparams.py @@ -95,7 +95,7 @@ class BaseArgs: obs_normalization: bool = True """whether to enable observation normalization""" reward_normalization: bool = False - """whether to enable reward normalization (Not recommended for now, it's unstable.)""" + """whether to enable reward normalization""" max_grad_norm: float = 0.0 """the maximum gradient norm""" amp: bool = True @@ -113,6 +113,8 @@ class BaseArgs: """(Playground only) Use tuned reward for G1""" action_bounds: float = 1.0 """(IsaacLab only) the bounds of the action space (-action_bounds, action_bounds)""" + task_embedding_dim: int = 32 + """the dimension of the task embedding""" weight_decay: float = 0.1 """the weight decay of the optimizer""" @@ -169,6 +171,9 @@ def get_args(): "Isaac-Velocity-Rough-G1-v0": IsaacVelocityRoughG1Args, "Isaac-Repose-Cube-Allegro-Direct-v0": IsaacReposeCubeAllegroDirectArgs, "Isaac-Repose-Cube-Shadow-Direct-v0": IsaacReposeCubeShadowDirectArgs, + # MTBench + "MTBench-meta-world-v2-mt10": MetaWorldMT10Args, + "MTBench-meta-world-v2-mt50": MetaWorldMT50Args, } # If the provided env_name has a specific Args class, use it if base_args.env_name in env_to_args_class: @@ -183,6 +188,9 @@ def get_args(): elif base_args.env_name.startswith("Isaac-"): # IsaacLab specific_args = tyro.cli(IsaacLabArgs) + elif base_args.env_name.startswith("MTBench-"): + # MTBench + specific_args = tyro.cli(MTBenchArgs) else: # MuJoCo Playground specific_args = tyro.cli(MuJoCoPlaygroundArgs) @@ -280,6 +288,38 @@ class MuJoCoPlaygroundArgs(BaseArgs): gamma: float = 0.97 +@dataclass +class MTBenchArgs(BaseArgs): + # Default hyperparameters for MTBench + 
reward_normalization: bool = True + v_min: float = -10.0 + v_max: float = 10.0 + buffer_size: int = 2048 # 2K is usually enough for MTBench + num_envs: int = 4096 + num_eval_envs: int = 4096 + gamma: float = 0.99 + num_steps: int = 8 + + +@dataclass +class MetaWorldMT10Args(MTBenchArgs): + # This config achieves 97 ~ 98% success rate within 10k steps (15-20 mins on A100) + env_name: str = "MTBench-meta-world-v2-mt10" + num_envs: int = 4096 + num_eval_envs: int = 4096 + num_steps: int = 8 + + +@dataclass +class MetaWorldMT50Args(MTBenchArgs): + # FastTD3 + SimbaV2 achieves >90% success rate within 20k steps (80 mins on A100) + # Performance further improves with more training steps, slowly. + env_name: str = "MTBench-meta-world-v2-mt50" + num_envs: int = 8192 + num_eval_envs: int = 8192 + num_steps: int = 8 + + @dataclass class G1JoystickFlatTerrainArgs(MuJoCoPlaygroundArgs): env_name: str = "G1JoystickFlatTerrain" diff --git a/fast_td3/train.py b/fast_td3/train.py index baf6212..e568846 100644 --- a/fast_td3/train.py +++ b/fast_td3/train.py @@ -18,17 +18,24 @@ import tqdm import wandb import numpy as np +try: + # Required for avoiding IsaacGym import error + import isaacgym +except ImportError: + pass + import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim from torch.amp import autocast, GradScaler -from tensordict import TensorDict, from_module +from tensordict import TensorDict from fast_td3_utils import ( EmpiricalNormalization, RewardNormalizer, + PerTaskRewardNormalizer, SimpleReplayBuffer, save_params, ) @@ -103,6 +110,14 @@ def main(): ) eval_envs = envs render_env = envs + elif args.env_name.startswith("MTBench-"): + from environments.mtbench_env import MTBenchEnv + + env_name = "-".join(args.env_name.split("-")[1:]) + env_type = "mtbench" + envs = MTBenchEnv(env_name, args.device_rank, args.num_envs, args.seed) + eval_envs = envs + render_env = envs else: from environments.mujoco_playground_env import make_env @@ 
-141,9 +156,19 @@ def main(): critic_obs_normalizer = nn.Identity() if args.reward_normalization: - reward_normalizer = RewardNormalizer( - gamma=args.gamma, device=device, g_max=min(abs(args.v_min), abs(args.v_max)) - ) + if env_type in ["mtbench"]: + reward_normalizer = PerTaskRewardNormalizer( + num_tasks=envs.num_tasks, + gamma=args.gamma, + device=device, + g_max=min(abs(args.v_min), abs(args.v_max)), + ) + else: + reward_normalizer = RewardNormalizer( + gamma=args.gamma, + device=device, + g_max=min(abs(args.v_min), abs(args.v_max)), + ) else: reward_normalizer = nn.Identity() @@ -165,12 +190,37 @@ def main(): "device": device, } + if env_type == "mtbench": + actor_kwargs["n_obs"] = n_obs - envs.num_tasks + args.task_embedding_dim + critic_kwargs["n_obs"] = n_critic_obs - envs.num_tasks + args.task_embedding_dim + actor_kwargs["num_tasks"] = envs.num_tasks + actor_kwargs["task_embedding_dim"] = args.task_embedding_dim + critic_kwargs["num_tasks"] = envs.num_tasks + critic_kwargs["task_embedding_dim"] = args.task_embedding_dim + if args.agent == "fasttd3": - from fast_td3 import Actor, Critic + if env_type in ["mtbench"]: + from fast_td3 import MultiTaskActor, MultiTaskCritic + + actor_cls = MultiTaskActor + critic_cls = MultiTaskCritic + else: + from fast_td3 import Actor, Critic + + actor_cls = Actor + critic_cls = Critic print("Using FastTD3") elif args.agent == "fasttd3_simbav2": - from fast_td3_simbav2 import Actor, Critic + if env_type in ["mtbench"]: + from fast_td3_simbav2 import MultiTaskActor, MultiTaskCritic + + actor_cls = MultiTaskActor + critic_cls = MultiTaskCritic + else: + from fast_td3_simbav2 import Actor, Critic + + actor_cls = Actor print("Using FastTD3 + SimbaV2") actor_kwargs.pop("init_scale") @@ -199,25 +249,31 @@ def main(): else: raise ValueError(f"Agent {args.agent} not supported") - actor = Actor(**actor_kwargs) - actor_detach = Actor(**actor_kwargs) + actor = actor_cls(**actor_kwargs) - # Copy params to actor_detach without grad - 
from_module(actor).data.to_module(actor_detach) - policy = actor_detach.explore + if env_type in ["mtbench"]: + # Python 3.8 doesn't support 'from_module' in tensordict + policy = actor.explore + else: + from tensordict import from_module - qnet = Critic(**critic_kwargs) - qnet_target = Critic(**critic_kwargs) + actor_detach = actor_cls(**actor_kwargs) + # Copy params to actor_detach without grad + from_module(actor).data.to_module(actor_detach) + policy = actor_detach.explore + + qnet = critic_cls(**critic_kwargs) + qnet_target = critic_cls(**critic_kwargs) qnet_target.load_state_dict(qnet.state_dict()) q_optimizer = optim.AdamW( list(qnet.parameters()), - lr=args.critic_learning_rate, + lr=torch.tensor(args.critic_learning_rate, device=device), weight_decay=args.weight_decay, ) actor_optimizer = optim.AdamW( list(actor.parameters()), - lr=args.actor_learning_rate, + lr=torch.tensor(args.actor_learning_rate, device=device), weight_decay=args.weight_decay, ) @@ -225,12 +281,12 @@ def main(): q_scheduler = optim.lr_scheduler.CosineAnnealingLR( q_optimizer, T_max=args.total_timesteps, - eta_min=args.critic_learning_rate_end, # Decay to 10% of initial lr + eta_min=torch.tensor(args.critic_learning_rate_end, device=device), ) actor_scheduler = optim.lr_scheduler.CosineAnnealingLR( actor_optimizer, T_max=args.total_timesteps, - eta_min=args.actor_learning_rate_end, # Decay to 10% of initial lr + eta_min=torch.tensor(args.actor_learning_rate_end, device=device), ) rb = SimpleReplayBuffer( @@ -262,20 +318,28 @@ def main(): obs = eval_envs.reset() # Run for a fixed number of steps - for _ in range(eval_envs.max_episode_steps): + for i in range(eval_envs.max_episode_steps): with torch.no_grad(), autocast( device_type=amp_device_type, dtype=amp_dtype, enabled=amp_enabled ): obs = normalize_obs(obs) actions = actor(obs) - next_obs, rewards, dones, _ = eval_envs.step(actions.float()) + next_obs, rewards, dones, infos = eval_envs.step(actions.float()) + + if env_type == 
"mtbench": + # We only report success rate in MTBench evaluation + rewards = ( + infos["episode"]["success"].float() if "episode" in infos else 0.0 + ) episode_returns = torch.where( ~done_masks, episode_returns + rewards, episode_returns ) episode_lengths = torch.where( ~done_masks, episode_lengths + 1, episode_lengths ) + if env_type == "mtbench" and "episode" in infos: + dones = dones | infos["episode"]["success"] done_masks = torch.logical_or(done_masks, dones) if done_masks.all(): break @@ -291,9 +355,9 @@ def main(): if env_type == "humanoid_bench": obs = render_env.reset() renders = [render_env.render()] - elif env_type == "isaaclab": + elif env_type in ["isaaclab", "mtbench"]: raise NotImplementedError( - "We don't support rendering for IsaacLab environments" + "We don't support rendering for IsaacLab and MTBench environments" ) else: obs = render_env.reset() @@ -399,6 +463,7 @@ def main(): ) scaler.step(q_optimizer) scaler.update() + q_scheduler.step() logs_dict["critic_grad_norm"] = critic_grad_norm.detach() logs_dict["qf_loss"] = qf_loss.detach() @@ -434,6 +499,7 @@ def main(): ) scaler.step(actor_optimizer) scaler.update() + actor_scheduler.step() logs_dict["actor_grad_norm"] = actor_grad_norm.detach() logs_dict["actor_loss"] = actor_loss.detach() return logs_dict @@ -500,7 +566,12 @@ def main(): truncations = infos["time_outs"] if args.reward_normalization: - update_stats(rewards, dones.float()) + if env_type == "mtbench": + task_ids_one_hot = obs[..., -envs.num_tasks :] + task_indices = torch.argmax(task_ids_one_hot, dim=1) + update_stats(rewards, dones.float(), task_ids=task_indices) + else: + update_stats(rewards, dones.float()) if envs.asymmetric_obs: next_critic_obs = infos["observations"]["critic"] @@ -550,7 +621,15 @@ def main(): data["next"]["observations"] ) raw_rewards = data["next"]["rewards"] - data["next"]["rewards"] = normalize_reward(raw_rewards) + if env_type in ["mtbench"] and args.reward_normalization: + # Multi-task reward 
normalization + task_ids_one_hot = data["observations"][..., -envs.num_tasks :] + task_indices = torch.argmax(task_ids_one_hot, dim=1) + data["next"]["rewards"] = normalize_reward( + raw_rewards, task_ids=task_indices + ) + else: + data["next"]["rewards"] = normalize_reward(raw_rewards) if envs.asymmetric_obs: data["critic_observations"] = normalize_critic_obs( data["critic_observations"] @@ -591,7 +670,7 @@ def main(): if args.eval_interval > 0 and global_step % args.eval_interval == 0: print(f"Evaluating at global step {global_step}") eval_avg_return, eval_avg_length = evaluate() - if env_type in ["humanoid_bench", "isaaclab"]: + if env_type in ["humanoid_bench", "isaaclab", "mtbench"]: # NOTE: Hacky way of evaluating performance, but just works obs = envs.reset() logs["eval_avg_return"] = eval_avg_return @@ -644,10 +723,6 @@ def main(): f"models/{run_name}_{global_step}.pt", ) - # Update learning rates - q_scheduler.step() - actor_scheduler.step() - global_step += 1 pbar.update(1) diff --git a/requirements/requirements_isaacgym.txt b/requirements/requirements_isaacgym.txt new file mode 100644 index 0000000..0e3c6e7 --- /dev/null +++ b/requirements/requirements_isaacgym.txt @@ -0,0 +1,15 @@ +gymnasium<1.0.0 +jax-jumpy==1.0.0 ; python_version >= "3.8" and python_version < "3.11" +matplotlib +moviepy +numpy<2.0 +pandas +protobuf +pygame +stable-baselines3 +tqdm +wandb +torchrl==0.5.0 +tensordict==0.5.0 +tyro +loguru \ No newline at end of file