Support MTBench (#15)

This PR incorporates MTBench into the current codebase, as a good demonstration that shows how to use FastTD3 for multi-task setup.

- Add support for MTBench along with its wrapper
- Add support for per-task reward normalizer useful for multi-task RL, motivated by BRC paper (https://arxiv.org/abs/2505.23150v1)
This commit is contained in:
Younggyo Seo 2025-06-20 21:52:43 -07:00 committed by GitHub
parent 3facede77d
commit cef44108d8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 679 additions and 34 deletions

View File

@ -10,11 +10,13 @@ For more information, please see our [project webpage](https://younggyo.me/fast_
## ❗ Updates
- **[June/15/2025]** Added support for FastTD3 + SimbaV2! It's faster to train, and often achieves better asymptotic performance.
- **[Jun/20/2025]** Added support for [MTBench](https://github.com/Viraj-Joshi/MTBench)
- **[Jun/6/2025]** Thanks to [Antonin Raffin](https://araffin.github.io/) ([@araffin](https://github.com/araffin)), we fixed the issues when using `n_steps` > 1, which stabilizes training with n-step return quite a lot!
- **[Jun/15/2025]** Added support for FastTD3 + [SimbaV2](https://dojeon-ai.github.io/SimbaV2/)! It's faster to train, and often achieves better asymptotic performance. We recommend using FastTD3 + SimbaV2 for most cases.
- **[Jun/1/2025]** Updated the figures in the technical report to report deterministic evaluation for IsaacLab tasks.
- **[Jun/06/2025]** Thanks to [Antonin Raffin](https://araffin.github.io/) ([@araffin](https://github.com/araffin)), we fixed the issues when using `n_steps` > 1, which stabilizes training with n-step return quite a lot!
- **[Jun/01/2025]** Updated the figures in the technical report to report deterministic evaluation for IsaacLab tasks.
## ✨ Features
@ -81,6 +83,27 @@ cd ..
pip install -r requirements/requirements.txt
```
### Environment for MTBench
MTBench does not support humanoid experiments, but is a useful multi-task benchmark with massive parallel simulation. This could be useful for users who want to use FastTD3 for their multi-task experiments.
```bash
conda create -n fasttd3_mtbench -y python=3.8 # Note python version
conda activate fasttd3_mtbench
# Install IsaacGym -- recommend to follow instructions in https://github.com/BoosterRobotics/booster_gym
...
# Install MTBench
git clone https://github.com/Viraj-Joshi/MTBench.git
cd MTbench
pip install -e .
pip install skrl
cd ..
# Install project-specific requirements
pip install -r requirements/requirements_isaacgym.txt
```
### (Optional) Accelerate headless GPU rendering in cloud instances
In some cloud VM images the NVIDIA kernel driver is present but the user-space OpenGL/EGL/Vulkan libraries aren't, so MuJoCo falls back to CPU renderer. You can install just the NVIDIA user-space libraries (and skip rebuilding the kernel module) with:
@ -176,6 +199,32 @@ python fast_td3/train.py \
--seed 1
```
### MTBench Experiments
```bash
conda activate fasttd3_mtbench
# FastTD3
python fast_td3/train.py \
--env_name MTBench-meta-world-v2-mt10 \
--exp_name FastTD3 \
--render_interval 0 \
--seed 1
# FastTD3 + SimbaV2
python fast_td3/train.py \
--env_name MTBench-meta-world-v2-mt10 \
--exp_name FastTD3 \
--render_interval 0 \
--agent fasttd3_simbav2 \
--batch_size 8192 \
--critic_learning_rate_end 3e-5 \
--actor_learning_rate_end 3e-5 \
--weight_decay 0.0 \
--critic_hidden_dim 512 \
--critic_num_blocks 2 \
--actor_hidden_dim 256 \
--actor_num_blocks 1 \
--seed 1
```
**Quick note:** For boolean-based arguments, you can set them to False by adding `no_` in front each argument, for instance, if you want to disable Clipped Q Learning, you can specify `--no_use_cdq` in your command.
## 💡 Performance-Related Tips
@ -315,6 +364,18 @@ Following the [LeanRL](https://github.com/pytorch-labs/LeanRL)'s recommendation,
}
```
### MTBench
```bibtex
@inproceedings{
joshi2025benchmarking,
title={Benchmarking Massively Parallelized Multi-Task Reinforcement Learning for Robotics Tasks},
author={Viraj Joshi and Zifan Xu and Bo Liu and Peter Stone and Amy Zhang},
booktitle={Reinforcement Learning Conference},
year={2025},
url={https://openreview.net/forum?id=z0MM0y20I2}
}
```
### Getting SAC to Work on a Massive Parallel Simulator
```bibtex
@article{raffin2025isaacsim,

View File

@ -0,0 +1,148 @@
from __future__ import annotations
import torch
from omegaconf import OmegaConf
import isaacgym
import isaacgymenvs
class MTBenchEnv:
def __init__(
self,
task_name: str,
device_id: int,
num_envs: int,
seed: int,
):
# NOTE: Currently, we only support Meta-World-v2 MT-10/MT-50 in MTBench
task_config = MTBENCH_MW2_CONFIG.copy()
if task_name == "meta-world-v2-mt10":
# MT-10 Setup
assert num_envs == 4096, "MT-10 only supports 4096 environments (for now)"
self.num_tasks = 10
task_config["env"]["tasks"] = [4, 16, 17, 18, 28, 31, 38, 40, 48, 49]
task_config["env"]["taskEnvCount"] = [410] * 6 + [409] * 4
elif task_name == "meta-world-v2-mt50":
# MT-50 Setup
self.num_tasks = 50
assert num_envs == 8192, "MT-50 only supports 8192 environments (for now)"
task_config["env"]["tasks"] = list(range(50))
task_config["env"]["taskEnvCount"] = [164] * 42 + [163] * 8 # 6888 + 1304
else:
raise ValueError(f"Unsupported task name: {task_name}")
task_config["env"]["numEnvs"] = num_envs
task_config["env"]["numObservations"] = 39 + self.num_tasks
task_config["env"]["seed"] = seed
# Convert dictionary to OmegaConf object
env_cfg = {"task": task_config}
env_cfg = OmegaConf.create(env_cfg)
self.env = isaacgymenvs.make(
task=env_cfg.task.name,
num_envs=num_envs,
sim_device=f"cuda:{device_id}",
rl_device=f"cuda:{device_id}",
seed=seed,
headless=True,
cfg=env_cfg,
)
self.num_envs = num_envs
self.asymmetric_obs = False
self.num_obs = self.env.observation_space.shape[0]
assert (
self.num_obs == 39 + self.num_tasks
), "MTBench observation space is 39 + num_tasks (one-hot vector)"
self.num_privileged_obs = 0
self.num_actions = self.env.action_space.shape[0]
self.max_episode_steps = self.env.max_episode_length
def reset(self) -> torch.Tensor:
"""Reset the environment."""
# TODO: Check if we need no_grad and detach here
with torch.no_grad(): # do we need this?
self.env.reset_idx(torch.arange(self.num_envs, device=self.env.device))
obs_dict = self.env.reset()
return obs_dict["obs"].detach()
def step(
self, actions: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, dict]:
"""Step the environment."""
assert isinstance(actions, torch.Tensor)
# TODO: Check if we need no_grad and detach here
with torch.no_grad():
obs_dict, rew, dones, infos = self.env.step(actions.detach())
truncations = infos["time_outs"]
info_ret = {"time_outs": truncations.detach()}
if "episode" in infos:
info_ret["episode"] = infos["episode"]
# NOTE: There's really no way to get the raw observations from IsaacGym
# We just use the 'reset_obs' as next_obs, unfortunately.
info_ret["observations"] = {"raw": {"obs": obs_dict["obs"].detach()}}
return obs_dict["obs"].detach(), rew.detach(), dones.detach(), info_ret
def render(self):
raise NotImplementedError(
"We don't support rendering for IsaacLab environments"
)
MTBENCH_MW2_CONFIG = {
"name": "meta-world-v2",
"physics_engine": "physx",
"env": {
"numEnvs": 1,
"envSpacing": 1.5,
"episodeLength": 150,
"enableDebugVis": False,
"clipObservations": 5.0,
"clipActions": 1.0,
"aggregateMode": 3,
"actionScale": 0.01,
"resetNoise": 0.15,
"tasks": [0],
"taskEnvCount": [4096],
"init_at_random_progress": True,
"exemptedInitAtRandomProgressTasks": [],
"taskEmbedding": True,
"taskEmbeddingType": "one_hot",
"seed": 42,
"cameraRenderingInterval": 5000,
"cameraWidth": 1024,
"cameraHeight": 1024,
"sparse_reward": False,
"termination_on_success": False,
"reward_scale": 1.0,
"fixed": False,
"numObservations": None,
"numActions": 4,
},
"enableCameraSensors": False,
"sim": {
"dt": 0.01667,
"substeps": 2,
"up_axis": "z",
"use_gpu_pipeline": True,
"gravity": [0.0, 0.0, -9.81],
"physx": {
"num_threads": 4,
"solver_type": 1,
"use_gpu": True,
"num_position_iterations": 8,
"num_velocity_iterations": 1,
"contact_offset": 0.005,
"rest_offset": 0.0,
"bounce_threshold_velocity": 0.2,
"max_depenetration_velocity": 1000.0,
"default_buffer_size_multiplier": 10.0,
"max_gpu_contact_pairs": 1048576,
"num_subscenes": 4,
"contact_collection": 0,
},
},
"task": {"randomize": False},
}

View File

@ -114,6 +114,7 @@ class Critic(nn.Module):
self.register_buffer(
"q_support", torch.linspace(v_min, v_max, num_atoms, device=device)
)
self.device = device
def forward(self, obs: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
return self.qnet1(obs, actions), self.qnet2(obs, actions)
@ -189,6 +190,7 @@ class Actor(nn.Module):
self.register_buffer("std_min", torch.as_tensor(std_min, device=device))
self.register_buffer("std_max", torch.as_tensor(std_max, device=device))
self.n_envs = num_envs
self.device = device
def forward(self, obs: torch.Tensor) -> torch.Tensor:
x = obs
@ -218,3 +220,51 @@ class Actor(nn.Module):
noise = torch.randn_like(act) * self.noise_scales
return act + noise
class MultiTaskActor(Actor):
def __init__(self, num_tasks: int, task_embedding_dim: int, *args, **kwargs):
super().__init__(*args, **kwargs)
self.num_tasks = num_tasks
self.task_embedding_dim = task_embedding_dim
self.task_embedding = nn.Embedding(
num_tasks, task_embedding_dim, max_norm=1.0, device=self.device
)
def forward(self, obs: torch.Tensor) -> torch.Tensor:
task_ids_one_hot = obs[..., -self.num_tasks :]
task_indices = torch.argmax(task_ids_one_hot, dim=1)
task_embeddings = self.task_embedding(task_indices)
obs = torch.cat([obs[..., : -self.num_tasks], task_embeddings], dim=-1)
return super().forward(obs)
class MultiTaskCritic(Critic):
def __init__(self, num_tasks: int, task_embedding_dim: int, *args, **kwargs):
super().__init__(*args, **kwargs)
self.num_tasks = num_tasks
self.task_embedding_dim = task_embedding_dim
self.task_embedding = nn.Embedding(
num_tasks, task_embedding_dim, max_norm=1.0, device=self.device
)
def forward(self, obs: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
task_ids_one_hot = obs[..., -self.num_tasks :]
task_indices = torch.argmax(task_ids_one_hot, dim=1)
task_embeddings = self.task_embedding(task_indices)
obs = torch.cat([obs[..., : -self.num_tasks], task_embeddings], dim=-1)
return super().forward(obs, actions)
def projection(
self,
obs: torch.Tensor,
actions: torch.Tensor,
rewards: torch.Tensor,
bootstrap: torch.Tensor,
discount: float,
) -> torch.Tensor:
task_ids_one_hot = obs[..., -self.num_tasks :]
task_indices = torch.argmax(task_ids_one_hot, dim=1)
task_embeddings = self.task_embedding(task_indices)
obs = torch.cat([obs[..., : -self.num_tasks], task_embeddings], dim=-1)
return super().projection(obs, actions, rewards, bootstrap, discount)

View File

@ -365,6 +365,7 @@ class Critic(nn.Module):
self.register_buffer(
"q_support", torch.linspace(v_min, v_max, num_atoms, device=device)
)
self.device = device
def forward(self, obs: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
return self.qnet1(obs, actions), self.qnet2(obs, actions)
@ -449,8 +450,8 @@ class Actor(nn.Module):
self.predictor = HyperTanhPolicy(
hidden_dim=hidden_dim,
action_dim=n_act,
scaler_init=scaler_init,
scaler_scale=scaler_scale,
scaler_init=1.0,
scaler_scale=1.0,
device=device,
)
@ -462,6 +463,7 @@ class Actor(nn.Module):
self.register_buffer("std_min", torch.as_tensor(std_min, device=device))
self.register_buffer("std_max", torch.as_tensor(std_max, device=device))
self.n_envs = num_envs
self.device = device
def forward(self, obs: torch.Tensor) -> torch.Tensor:
x = obs
@ -492,3 +494,51 @@ class Actor(nn.Module):
noise = torch.randn_like(act) * self.noise_scales
return act + noise
class MultiTaskActor(Actor):
def __init__(self, num_tasks: int, task_embedding_dim: int, *args, **kwargs):
super().__init__(*args, **kwargs)
self.num_tasks = num_tasks
self.task_embedding_dim = task_embedding_dim
self.task_embedding = nn.Embedding(
num_tasks, task_embedding_dim, max_norm=1.0, device=self.device
)
def forward(self, obs: torch.Tensor) -> torch.Tensor:
task_ids_one_hot = obs[..., -self.num_tasks :]
task_indices = torch.argmax(task_ids_one_hot, dim=1)
task_embeddings = self.task_embedding(task_indices)
obs = torch.cat([obs[..., : -self.num_tasks], task_embeddings], dim=-1)
return super().forward(obs)
class MultiTaskCritic(Critic):
def __init__(self, num_tasks: int, task_embedding_dim: int, *args, **kwargs):
super().__init__(*args, **kwargs)
self.num_tasks = num_tasks
self.task_embedding_dim = task_embedding_dim
self.task_embedding = nn.Embedding(
num_tasks, task_embedding_dim, max_norm=1.0, device=self.device
)
def forward(self, obs: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
task_ids_one_hot = obs[..., -self.num_tasks :]
task_indices = torch.argmax(task_ids_one_hot, dim=1)
task_embeddings = self.task_embedding(task_indices)
obs = torch.cat([obs[..., : -self.num_tasks], task_embeddings], dim=-1)
return super().forward(obs, actions)
def projection(
self,
obs: torch.Tensor,
actions: torch.Tensor,
rewards: torch.Tensor,
bootstrap: torch.Tensor,
discount: float,
) -> torch.Tensor:
task_ids_one_hot = obs[..., -self.num_tasks :]
task_indices = torch.argmax(task_ids_one_hot, dim=1)
task_embeddings = self.task_embedding(task_indices)
obs = torch.cat([obs[..., : -self.num_tasks], task_embeddings], dim=-1)
return super().projection(obs, actions, rewards, bootstrap, discount)

View File

@ -512,6 +512,212 @@ class RewardNormalizer(nn.Module):
return self._scale_reward(rewards)
class PerTaskEmpiricalNormalization(nn.Module):
"""Normalize mean and variance of values based on empirical values for each task."""
def __init__(
self,
num_tasks: int,
shape: tuple,
device: torch.device,
eps: float = 1e-2,
until: int = None,
):
"""
Initialize PerTaskEmpiricalNormalization module.
Args:
num_tasks (int): The total number of tasks.
shape (int or tuple of int): Shape of input values except batch axis.
eps (float): Small value for stability.
until (int or None): If specified, learns until the sum of batch sizes
for a specific task exceeds this value.
"""
super().__init__()
if not isinstance(shape, tuple):
shape = (shape,)
self.num_tasks = num_tasks
self.shape = shape
self.eps = eps
self.until = until
self.device = device
# Buffers now have a leading dimension for tasks
self.register_buffer("_mean", torch.zeros(num_tasks, *shape).to(device))
self.register_buffer("_var", torch.ones(num_tasks, *shape).to(device))
self.register_buffer("_std", torch.ones(num_tasks, *shape).to(device))
self.register_buffer(
"count", torch.zeros(num_tasks, dtype=torch.long).to(device)
)
def forward(
self, x: torch.Tensor, task_ids: torch.Tensor, center: bool = True
) -> torch.Tensor:
"""
Normalize the input tensor `x` using statistics for the given `task_ids`.
Args:
x (torch.Tensor): Input tensor of shape [num_envs, *shape].
task_ids (torch.Tensor): Tensor of task indices, shape [num_envs].
center (bool): If True, center the data by subtracting the mean.
"""
if x.shape[1:] != self.shape:
raise ValueError(f"Expected input shape (*, {self.shape}), got {x.shape}")
if x.shape[0] != task_ids.shape[0]:
raise ValueError("Batch size of x and task_ids must match.")
# Gather the stats for the tasks in the current batch
# Reshape task_ids for broadcasting: [num_envs] -> [num_envs, 1, ...]
view_shape = (task_ids.shape[0],) + (1,) * len(self.shape)
task_ids_expanded = task_ids.view(view_shape).expand_as(x)
mean = self._mean.gather(0, task_ids_expanded)
std = self._std.gather(0, task_ids_expanded)
if self.training:
self.update(x, task_ids)
if center:
return (x - mean) / (std + self.eps)
else:
return x / (std + self.eps)
@torch.jit.unused
def update(self, x: torch.Tensor, task_ids: torch.Tensor):
"""Update running statistics for the tasks present in the batch."""
unique_tasks = torch.unique(task_ids)
for task_id in unique_tasks:
if self.until is not None and self.count[task_id] >= self.until:
continue
# Create a mask to select data for the current task
mask = task_ids == task_id
x_task = x[mask]
batch_size = x_task.shape[0]
if batch_size == 0:
continue
# Update count for this task
old_count = self.count[task_id].clone()
new_count = old_count + batch_size
# Update mean
task_mean = self._mean[task_id]
batch_mean = torch.mean(x_task, dim=0)
delta = batch_mean - task_mean
self._mean[task_id] = task_mean + (batch_size / new_count) * delta
# Update variance using Chan's parallel algorithm
if old_count > 0:
batch_var = torch.var(x_task, dim=0, unbiased=False)
m_a = self._var[task_id] * old_count
m_b = batch_var * batch_size
M2 = m_a + m_b + (delta**2) * (old_count * batch_size / new_count)
self._var[task_id] = M2 / new_count
else:
# For the first batch of this task
self._var[task_id] = torch.var(x_task, dim=0, unbiased=False)
self._std[task_id] = torch.sqrt(self._var[task_id])
self.count[task_id] = new_count
class PerTaskRewardNormalizer(nn.Module):
def __init__(
self,
num_tasks: int,
gamma: float,
device: torch.device,
g_max: float = 10.0,
epsilon: float = 1e-8,
):
"""
Per-task reward normalizer, motivation comes from BRC (https://arxiv.org/abs/2505.23150v1)
"""
super().__init__()
self.num_tasks = num_tasks
self.gamma = gamma
self.g_max = g_max
self.epsilon = epsilon
self.device = device
# Per-task running estimate of the discounted return
self.register_buffer("G", torch.zeros(num_tasks, device=device))
# Per-task running-max of the discounted return
self.register_buffer("G_r_max", torch.zeros(num_tasks, device=device))
# Use the new per-task normalizer for the statistics of G
self.G_rms = PerTaskEmpiricalNormalization(
num_tasks=num_tasks, shape=(1,), device=device
)
def _scale_reward(
self, rewards: torch.Tensor, task_ids: torch.Tensor
) -> torch.Tensor:
"""
Scales rewards using per-task statistics.
Args:
rewards (torch.Tensor): Reward tensor, shape [num_envs].
task_ids (torch.Tensor): Task indices, shape [num_envs].
"""
# Gather stats for the tasks in the batch
std_for_batch = self.G_rms._std.gather(0, task_ids.unsqueeze(-1)).squeeze(-1)
g_r_max_for_batch = self.G_r_max.gather(0, task_ids)
var_denominator = std_for_batch + self.epsilon
min_required_denominator = g_r_max_for_batch / self.g_max
denominator = torch.maximum(var_denominator, min_required_denominator)
# Add a small epsilon to the final denominator to prevent division by zero
# in case g_r_max is also zero.
return rewards / (denominator + self.epsilon)
def update_stats(
self, rewards: torch.Tensor, dones: torch.Tensor, task_ids: torch.Tensor
):
"""
Updates the running discounted return and its statistics for each task.
Args:
rewards (torch.Tensor): Reward tensor, shape [num_envs].
dones (torch.Tensor): Done tensor, shape [num_envs].
task_ids (torch.Tensor): Task indices, shape [num_envs].
"""
if not (rewards.shape == dones.shape == task_ids.shape):
raise ValueError("rewards, dones, and task_ids must have the same shape.")
# === Update G (running discounted return) ===
# Gather the previous G values for the tasks in the batch
prev_G = self.G.gather(0, task_ids)
# Update G for each environment based on its own reward and done signal
new_G = self.gamma * (1 - dones.float()) * prev_G + rewards
# Scatter the updated G values back to the main buffer
self.G.scatter_(0, task_ids, new_G)
# === Update G_rms (statistics of G) ===
# The update function handles the per-task logic internally
self.G_rms.update(new_G.unsqueeze(-1), task_ids)
# === Update G_r_max (running max of |G|) ===
prev_G_r_max = self.G_r_max.gather(0, task_ids)
# Update the max for each environment
updated_G_r_max = torch.maximum(prev_G_r_max, torch.abs(new_G))
# Scatter the new maxes back to the main buffer
self.G_r_max.scatter_(0, task_ids, updated_G_r_max)
def forward(self, rewards: torch.Tensor, task_ids: torch.Tensor) -> torch.Tensor:
"""
Normalizes rewards. During training, it also updates the running statistics.
Args:
rewards (torch.Tensor): Reward tensor, shape [num_envs].
task_ids (torch.Tensor): Task indices, shape [num_envs].
"""
return self._scale_reward(rewards, task_ids)
def cpu_state(sd):
# detach & move to host without locking the compute stream
return {k: v.detach().to("cpu", non_blocking=True) for k, v in sd.items()}

View File

@ -95,7 +95,7 @@ class BaseArgs:
obs_normalization: bool = True
"""whether to enable observation normalization"""
reward_normalization: bool = False
"""whether to enable reward normalization (Not recommended for now, it's unstable.)"""
"""whether to enable reward normalization"""
max_grad_norm: float = 0.0
"""the maximum gradient norm"""
amp: bool = True
@ -113,6 +113,8 @@ class BaseArgs:
"""(Playground only) Use tuned reward for G1"""
action_bounds: float = 1.0
"""(IsaacLab only) the bounds of the action space (-action_bounds, action_bounds)"""
task_embedding_dim: int = 32
"""the dimension of the task embedding"""
weight_decay: float = 0.1
"""the weight decay of the optimizer"""
@ -169,6 +171,9 @@ def get_args():
"Isaac-Velocity-Rough-G1-v0": IsaacVelocityRoughG1Args,
"Isaac-Repose-Cube-Allegro-Direct-v0": IsaacReposeCubeAllegroDirectArgs,
"Isaac-Repose-Cube-Shadow-Direct-v0": IsaacReposeCubeShadowDirectArgs,
# MTBench
"MTBench-meta-world-v2-mt10": MetaWorldMT10Args,
"MTBench-meta-world-v2-mt50": MetaWorldMT50Args,
}
# If the provided env_name has a specific Args class, use it
if base_args.env_name in env_to_args_class:
@ -183,6 +188,9 @@ def get_args():
elif base_args.env_name.startswith("Isaac-"):
# IsaacLab
specific_args = tyro.cli(IsaacLabArgs)
elif base_args.env_name.startswith("MTBench-"):
# MTBench
specific_args = tyro.cli(MTBenchArgs)
else:
# MuJoCo Playground
specific_args = tyro.cli(MuJoCoPlaygroundArgs)
@ -280,6 +288,38 @@ class MuJoCoPlaygroundArgs(BaseArgs):
gamma: float = 0.97
@dataclass
class MTBenchArgs(BaseArgs):
# Default hyperparameters for MTBench
reward_normalization: bool = True
v_min: float = -10.0
v_max: float = 10.0
buffer_size: int = 2048 # 2K is usually enough for MTBench
num_envs: int = 4096
num_eval_envs: int = 4096
gamma: float = 0.99
num_steps: int = 8
@dataclass
class MetaWorldMT10Args(MTBenchArgs):
# This config achieves 97 ~ 98% success rate within 10k steps (15-20 mins on A100)
env_name: str = "MTBench-meta-world-v2-mt10"
num_envs: int = 4096
num_eval_envs: int = 4096
num_steps: int = 8
@dataclass
class MetaWorldMT50Args(MTBenchArgs):
# FastTD3 + SimbaV2 achieves >90% success rate within 20k steps (80 mins on A100)
# Performance further improves with more training steps, slowly.
env_name: str = "MTBench-meta-world-v2-mt50"
num_envs: int = 8192
num_eval_envs: int = 8192
num_steps: int = 8
@dataclass
class G1JoystickFlatTerrainArgs(MuJoCoPlaygroundArgs):
env_name: str = "G1JoystickFlatTerrain"

View File

@ -18,17 +18,24 @@ import tqdm
import wandb
import numpy as np
try:
# Required for avoiding IsaacGym import error
import isaacgym
except ImportError:
pass
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.amp import autocast, GradScaler
from tensordict import TensorDict, from_module
from tensordict import TensorDict
from fast_td3_utils import (
EmpiricalNormalization,
RewardNormalizer,
PerTaskRewardNormalizer,
SimpleReplayBuffer,
save_params,
)
@ -103,6 +110,14 @@ def main():
)
eval_envs = envs
render_env = envs
elif args.env_name.startswith("MTBench-"):
from environments.mtbench_env import MTBenchEnv
env_name = "-".join(args.env_name.split("-")[1:])
env_type = "mtbench"
envs = MTBenchEnv(env_name, args.device_rank, args.num_envs, args.seed)
eval_envs = envs
render_env = envs
else:
from environments.mujoco_playground_env import make_env
@ -141,8 +156,18 @@ def main():
critic_obs_normalizer = nn.Identity()
if args.reward_normalization:
if env_type in ["mtbench"]:
reward_normalizer = PerTaskRewardNormalizer(
num_tasks=envs.num_tasks,
gamma=args.gamma,
device=device,
g_max=min(abs(args.v_min), abs(args.v_max)),
)
else:
reward_normalizer = RewardNormalizer(
gamma=args.gamma, device=device, g_max=min(abs(args.v_min), abs(args.v_max))
gamma=args.gamma,
device=device,
g_max=min(abs(args.v_min), abs(args.v_max)),
)
else:
reward_normalizer = nn.Identity()
@ -165,13 +190,38 @@ def main():
"device": device,
}
if env_type == "mtbench":
actor_kwargs["n_obs"] = n_obs - envs.num_tasks + args.task_embedding_dim
critic_kwargs["n_obs"] = n_critic_obs - envs.num_tasks + args.task_embedding_dim
actor_kwargs["num_tasks"] = envs.num_tasks
actor_kwargs["task_embedding_dim"] = args.task_embedding_dim
critic_kwargs["num_tasks"] = envs.num_tasks
critic_kwargs["task_embedding_dim"] = args.task_embedding_dim
if args.agent == "fasttd3":
if env_type in ["mtbench"]:
from fast_td3 import MultiTaskActor, MultiTaskCritic
actor_cls = MultiTaskActor
critic_cls = MultiTaskCritic
else:
from fast_td3 import Actor, Critic
actor_cls = Actor
critic_cls = Critic
print("Using FastTD3")
elif args.agent == "fasttd3_simbav2":
if env_type in ["mtbench"]:
from fast_td3_simbav2 import MultiTaskActor, MultiTaskCritic
actor_cls = MultiTaskActor
critic_cls = MultiTaskCritic
else:
from fast_td3_simbav2 import Actor, Critic
actor_cls = Actor
print("Using FastTD3 + SimbaV2")
actor_kwargs.pop("init_scale")
actor_kwargs.update(
@ -199,25 +249,31 @@ def main():
else:
raise ValueError(f"Agent {args.agent} not supported")
actor = Actor(**actor_kwargs)
actor_detach = Actor(**actor_kwargs)
actor = actor_cls(**actor_kwargs)
if env_type in ["mtbench"]:
# Python 3.8 doesn't support 'from_module' in tensordict
policy = actor.explore
else:
from tensordict import from_module
actor_detach = actor_cls(**actor_kwargs)
# Copy params to actor_detach without grad
from_module(actor).data.to_module(actor_detach)
policy = actor_detach.explore
qnet = Critic(**critic_kwargs)
qnet_target = Critic(**critic_kwargs)
qnet = critic_cls(**critic_kwargs)
qnet_target = critic_cls(**critic_kwargs)
qnet_target.load_state_dict(qnet.state_dict())
q_optimizer = optim.AdamW(
list(qnet.parameters()),
lr=args.critic_learning_rate,
lr=torch.tensor(args.critic_learning_rate, device=device),
weight_decay=args.weight_decay,
)
actor_optimizer = optim.AdamW(
list(actor.parameters()),
lr=args.actor_learning_rate,
lr=torch.tensor(args.actor_learning_rate, device=device),
weight_decay=args.weight_decay,
)
@ -225,12 +281,12 @@ def main():
q_scheduler = optim.lr_scheduler.CosineAnnealingLR(
q_optimizer,
T_max=args.total_timesteps,
eta_min=args.critic_learning_rate_end, # Decay to 10% of initial lr
eta_min=torch.tensor(args.critic_learning_rate_end, device=device),
)
actor_scheduler = optim.lr_scheduler.CosineAnnealingLR(
actor_optimizer,
T_max=args.total_timesteps,
eta_min=args.actor_learning_rate_end, # Decay to 10% of initial lr
eta_min=torch.tensor(args.actor_learning_rate_end, device=device),
)
rb = SimpleReplayBuffer(
@ -262,20 +318,28 @@ def main():
obs = eval_envs.reset()
# Run for a fixed number of steps
for _ in range(eval_envs.max_episode_steps):
for i in range(eval_envs.max_episode_steps):
with torch.no_grad(), autocast(
device_type=amp_device_type, dtype=amp_dtype, enabled=amp_enabled
):
obs = normalize_obs(obs)
actions = actor(obs)
next_obs, rewards, dones, _ = eval_envs.step(actions.float())
next_obs, rewards, dones, infos = eval_envs.step(actions.float())
if env_type == "mtbench":
# We only report success rate in MTBench evaluation
rewards = (
infos["episode"]["success"].float() if "episode" in infos else 0.0
)
episode_returns = torch.where(
~done_masks, episode_returns + rewards, episode_returns
)
episode_lengths = torch.where(
~done_masks, episode_lengths + 1, episode_lengths
)
if env_type == "mtbench" and "episode" in infos:
dones = dones | infos["episode"]["success"]
done_masks = torch.logical_or(done_masks, dones)
if done_masks.all():
break
@ -291,9 +355,9 @@ def main():
if env_type == "humanoid_bench":
obs = render_env.reset()
renders = [render_env.render()]
elif env_type == "isaaclab":
elif env_type in ["isaaclab", "mtbench"]:
raise NotImplementedError(
"We don't support rendering for IsaacLab environments"
"We don't support rendering for IsaacLab and MTBench environments"
)
else:
obs = render_env.reset()
@ -399,6 +463,7 @@ def main():
)
scaler.step(q_optimizer)
scaler.update()
q_scheduler.step()
logs_dict["critic_grad_norm"] = critic_grad_norm.detach()
logs_dict["qf_loss"] = qf_loss.detach()
@ -434,6 +499,7 @@ def main():
)
scaler.step(actor_optimizer)
scaler.update()
actor_scheduler.step()
logs_dict["actor_grad_norm"] = actor_grad_norm.detach()
logs_dict["actor_loss"] = actor_loss.detach()
return logs_dict
@ -500,6 +566,11 @@ def main():
truncations = infos["time_outs"]
if args.reward_normalization:
if env_type == "mtbench":
task_ids_one_hot = obs[..., -envs.num_tasks :]
task_indices = torch.argmax(task_ids_one_hot, dim=1)
update_stats(rewards, dones.float(), task_ids=task_indices)
else:
update_stats(rewards, dones.float())
if envs.asymmetric_obs:
@ -550,6 +621,14 @@ def main():
data["next"]["observations"]
)
raw_rewards = data["next"]["rewards"]
if env_type in ["mtbench"] and args.reward_normalization:
# Multi-task reward normalization
task_ids_one_hot = data["observations"][..., -envs.num_tasks :]
task_indices = torch.argmax(task_ids_one_hot, dim=1)
data["next"]["rewards"] = normalize_reward(
raw_rewards, task_ids=task_indices
)
else:
data["next"]["rewards"] = normalize_reward(raw_rewards)
if envs.asymmetric_obs:
data["critic_observations"] = normalize_critic_obs(
@ -591,7 +670,7 @@ def main():
if args.eval_interval > 0 and global_step % args.eval_interval == 0:
print(f"Evaluating at global step {global_step}")
eval_avg_return, eval_avg_length = evaluate()
if env_type in ["humanoid_bench", "isaaclab"]:
if env_type in ["humanoid_bench", "isaaclab", "mtbench"]:
# NOTE: Hacky way of evaluating performance, but just works
obs = envs.reset()
logs["eval_avg_return"] = eval_avg_return
@ -644,10 +723,6 @@ def main():
f"models/{run_name}_{global_step}.pt",
)
# Update learning rates
q_scheduler.step()
actor_scheduler.step()
global_step += 1
pbar.update(1)

View File

@ -0,0 +1,15 @@
gymnasium<1.0.0
jax-jumpy==1.0.0 ; python_version >= "3.8" and python_version < "3.11"
matplotlib
moviepy
numpy<2.0
pandas
protobuf
pygame
stable-baselines3
tqdm
wandb
torchrl==0.5.0
tensordict==0.5.0
tyro
loguru