Support MTBench (#15)
This PR incorporates MTBench into the current codebase, as a good demonstration that shows how to use FastTD3 for multi-task setup. - Add support for MTBench along with its wrapper - Add support for per-task reward normalizer useful for multi-task RL, motivated by BRC paper (https://arxiv.org/abs/2505.23150v1)
This commit is contained in:
parent
3facede77d
commit
cef44108d8
67
README.md
67
README.md
@ -10,11 +10,13 @@ For more information, please see our [project webpage](https://younggyo.me/fast_
|
||||
|
||||
|
||||
## ❗ Updates
|
||||
- **[June/15/2025]** Added support for FastTD3 + SimbaV2! It's faster to train, and often achieves better asymptotic performance.
|
||||
- **[Jun/20/2025]** Added support for [MTBench](https://github.com/Viraj-Joshi/MTBench)
|
||||
|
||||
- **[Jun/6/2025]** Thanks to [Antonin Raffin](https://araffin.github.io/) ([@araffin](https://github.com/araffin)), we fixed the issues when using `n_steps` > 1, which stabilizes training with n-step return quite a lot!
|
||||
- **[Jun/15/2025]** Added support for FastTD3 + [SimbaV2](https://dojeon-ai.github.io/SimbaV2/)! It's faster to train, and often achieves better asymptotic performance. We recommend using FastTD3 + SimbaV2 for most cases.
|
||||
|
||||
- **[Jun/1/2025]** Updated the figures in the technical report to report deterministic evaluation for IsaacLab tasks.
|
||||
- **[Jun/06/2025]** Thanks to [Antonin Raffin](https://araffin.github.io/) ([@araffin](https://github.com/araffin)), we fixed the issues when using `n_steps` > 1, which stabilizes training with n-step return quite a lot!
|
||||
|
||||
- **[Jun/01/2025]** Updated the figures in the technical report to report deterministic evaluation for IsaacLab tasks.
|
||||
|
||||
|
||||
## ✨ Features
|
||||
@ -81,6 +83,27 @@ cd ..
|
||||
pip install -r requirements/requirements.txt
|
||||
```
|
||||
|
||||
### Environment for MTBench
|
||||
MTBench does not support humanoid experiments, but is a useful multi-task benchmark with massive parallel simulation. This could be useful for users who want to use FastTD3 for their multi-task experiments.
|
||||
|
||||
```bash
|
||||
conda create -n fasttd3_mtbench -y python=3.8 # Note python version
|
||||
conda activate fasttd3_mtbench
|
||||
|
||||
# Install IsaacGym -- recommend to follow instructions in https://github.com/BoosterRobotics/booster_gym
|
||||
...
|
||||
|
||||
# Install MTBench
|
||||
git clone https://github.com/Viraj-Joshi/MTBench.git
|
||||
cd MTbench
|
||||
pip install -e .
|
||||
pip install skrl
|
||||
cd ..
|
||||
|
||||
# Install project-specific requirements
|
||||
pip install -r requirements/requirements_isaacgym.txt
|
||||
```
|
||||
|
||||
### (Optional) Accelerate headless GPU rendering in cloud instances
|
||||
|
||||
In some cloud VM images the NVIDIA kernel driver is present but the user-space OpenGL/EGL/Vulkan libraries aren't, so MuJoCo falls back to CPU renderer. You can install just the NVIDIA user-space libraries (and skip rebuilding the kernel module) with:
|
||||
@ -176,6 +199,32 @@ python fast_td3/train.py \
|
||||
--seed 1
|
||||
```
|
||||
|
||||
### MTBench Experiments
|
||||
```bash
|
||||
conda activate fasttd3_mtbench
|
||||
# FastTD3
|
||||
python fast_td3/train.py \
|
||||
--env_name MTBench-meta-world-v2-mt10 \
|
||||
--exp_name FastTD3 \
|
||||
--render_interval 0 \
|
||||
--seed 1
|
||||
# FastTD3 + SimbaV2
|
||||
python fast_td3/train.py \
|
||||
--env_name MTBench-meta-world-v2-mt10 \
|
||||
--exp_name FastTD3 \
|
||||
--render_interval 0 \
|
||||
--agent fasttd3_simbav2 \
|
||||
--batch_size 8192 \
|
||||
--critic_learning_rate_end 3e-5 \
|
||||
--actor_learning_rate_end 3e-5 \
|
||||
--weight_decay 0.0 \
|
||||
--critic_hidden_dim 512 \
|
||||
--critic_num_blocks 2 \
|
||||
--actor_hidden_dim 256 \
|
||||
--actor_num_blocks 1 \
|
||||
--seed 1
|
||||
```
|
||||
|
||||
**Quick note:** For boolean-based arguments, you can set them to False by adding `no_` in front each argument, for instance, if you want to disable Clipped Q Learning, you can specify `--no_use_cdq` in your command.
|
||||
|
||||
## 💡 Performance-Related Tips
|
||||
@ -315,6 +364,18 @@ Following the [LeanRL](https://github.com/pytorch-labs/LeanRL)'s recommendation,
|
||||
}
|
||||
```
|
||||
|
||||
### MTBench
|
||||
```bibtex
|
||||
@inproceedings{
|
||||
joshi2025benchmarking,
|
||||
title={Benchmarking Massively Parallelized Multi-Task Reinforcement Learning for Robotics Tasks},
|
||||
author={Viraj Joshi and Zifan Xu and Bo Liu and Peter Stone and Amy Zhang},
|
||||
booktitle={Reinforcement Learning Conference},
|
||||
year={2025},
|
||||
url={https://openreview.net/forum?id=z0MM0y20I2}
|
||||
}
|
||||
```
|
||||
|
||||
### Getting SAC to Work on a Massive Parallel Simulator
|
||||
```bibtex
|
||||
@article{raffin2025isaacsim,
|
||||
|
148
fast_td3/environments/mtbench_env.py
Normal file
148
fast_td3/environments/mtbench_env.py
Normal file
@ -0,0 +1,148 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
import isaacgym
|
||||
import isaacgymenvs
|
||||
|
||||
|
||||
class MTBenchEnv:
|
||||
def __init__(
|
||||
self,
|
||||
task_name: str,
|
||||
device_id: int,
|
||||
num_envs: int,
|
||||
seed: int,
|
||||
):
|
||||
# NOTE: Currently, we only support Meta-World-v2 MT-10/MT-50 in MTBench
|
||||
task_config = MTBENCH_MW2_CONFIG.copy()
|
||||
if task_name == "meta-world-v2-mt10":
|
||||
# MT-10 Setup
|
||||
assert num_envs == 4096, "MT-10 only supports 4096 environments (for now)"
|
||||
self.num_tasks = 10
|
||||
task_config["env"]["tasks"] = [4, 16, 17, 18, 28, 31, 38, 40, 48, 49]
|
||||
task_config["env"]["taskEnvCount"] = [410] * 6 + [409] * 4
|
||||
elif task_name == "meta-world-v2-mt50":
|
||||
# MT-50 Setup
|
||||
self.num_tasks = 50
|
||||
assert num_envs == 8192, "MT-50 only supports 8192 environments (for now)"
|
||||
task_config["env"]["tasks"] = list(range(50))
|
||||
task_config["env"]["taskEnvCount"] = [164] * 42 + [163] * 8 # 6888 + 1304
|
||||
else:
|
||||
raise ValueError(f"Unsupported task name: {task_name}")
|
||||
task_config["env"]["numEnvs"] = num_envs
|
||||
task_config["env"]["numObservations"] = 39 + self.num_tasks
|
||||
task_config["env"]["seed"] = seed
|
||||
|
||||
# Convert dictionary to OmegaConf object
|
||||
env_cfg = {"task": task_config}
|
||||
env_cfg = OmegaConf.create(env_cfg)
|
||||
|
||||
self.env = isaacgymenvs.make(
|
||||
task=env_cfg.task.name,
|
||||
num_envs=num_envs,
|
||||
sim_device=f"cuda:{device_id}",
|
||||
rl_device=f"cuda:{device_id}",
|
||||
seed=seed,
|
||||
headless=True,
|
||||
cfg=env_cfg,
|
||||
)
|
||||
|
||||
self.num_envs = num_envs
|
||||
self.asymmetric_obs = False
|
||||
self.num_obs = self.env.observation_space.shape[0]
|
||||
assert (
|
||||
self.num_obs == 39 + self.num_tasks
|
||||
), "MTBench observation space is 39 + num_tasks (one-hot vector)"
|
||||
self.num_privileged_obs = 0
|
||||
self.num_actions = self.env.action_space.shape[0]
|
||||
self.max_episode_steps = self.env.max_episode_length
|
||||
|
||||
def reset(self) -> torch.Tensor:
|
||||
"""Reset the environment."""
|
||||
# TODO: Check if we need no_grad and detach here
|
||||
with torch.no_grad(): # do we need this?
|
||||
self.env.reset_idx(torch.arange(self.num_envs, device=self.env.device))
|
||||
obs_dict = self.env.reset()
|
||||
return obs_dict["obs"].detach()
|
||||
|
||||
def step(
|
||||
self, actions: torch.Tensor
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, dict]:
|
||||
"""Step the environment."""
|
||||
assert isinstance(actions, torch.Tensor)
|
||||
|
||||
# TODO: Check if we need no_grad and detach here
|
||||
with torch.no_grad():
|
||||
obs_dict, rew, dones, infos = self.env.step(actions.detach())
|
||||
truncations = infos["time_outs"]
|
||||
info_ret = {"time_outs": truncations.detach()}
|
||||
if "episode" in infos:
|
||||
info_ret["episode"] = infos["episode"]
|
||||
# NOTE: There's really no way to get the raw observations from IsaacGym
|
||||
# We just use the 'reset_obs' as next_obs, unfortunately.
|
||||
info_ret["observations"] = {"raw": {"obs": obs_dict["obs"].detach()}}
|
||||
return obs_dict["obs"].detach(), rew.detach(), dones.detach(), info_ret
|
||||
|
||||
def render(self):
|
||||
raise NotImplementedError(
|
||||
"We don't support rendering for IsaacLab environments"
|
||||
)
|
||||
|
||||
|
||||
MTBENCH_MW2_CONFIG = {
|
||||
"name": "meta-world-v2",
|
||||
"physics_engine": "physx",
|
||||
"env": {
|
||||
"numEnvs": 1,
|
||||
"envSpacing": 1.5,
|
||||
"episodeLength": 150,
|
||||
"enableDebugVis": False,
|
||||
"clipObservations": 5.0,
|
||||
"clipActions": 1.0,
|
||||
"aggregateMode": 3,
|
||||
"actionScale": 0.01,
|
||||
"resetNoise": 0.15,
|
||||
"tasks": [0],
|
||||
"taskEnvCount": [4096],
|
||||
"init_at_random_progress": True,
|
||||
"exemptedInitAtRandomProgressTasks": [],
|
||||
"taskEmbedding": True,
|
||||
"taskEmbeddingType": "one_hot",
|
||||
"seed": 42,
|
||||
"cameraRenderingInterval": 5000,
|
||||
"cameraWidth": 1024,
|
||||
"cameraHeight": 1024,
|
||||
"sparse_reward": False,
|
||||
"termination_on_success": False,
|
||||
"reward_scale": 1.0,
|
||||
"fixed": False,
|
||||
"numObservations": None,
|
||||
"numActions": 4,
|
||||
},
|
||||
"enableCameraSensors": False,
|
||||
"sim": {
|
||||
"dt": 0.01667,
|
||||
"substeps": 2,
|
||||
"up_axis": "z",
|
||||
"use_gpu_pipeline": True,
|
||||
"gravity": [0.0, 0.0, -9.81],
|
||||
"physx": {
|
||||
"num_threads": 4,
|
||||
"solver_type": 1,
|
||||
"use_gpu": True,
|
||||
"num_position_iterations": 8,
|
||||
"num_velocity_iterations": 1,
|
||||
"contact_offset": 0.005,
|
||||
"rest_offset": 0.0,
|
||||
"bounce_threshold_velocity": 0.2,
|
||||
"max_depenetration_velocity": 1000.0,
|
||||
"default_buffer_size_multiplier": 10.0,
|
||||
"max_gpu_contact_pairs": 1048576,
|
||||
"num_subscenes": 4,
|
||||
"contact_collection": 0,
|
||||
},
|
||||
},
|
||||
"task": {"randomize": False},
|
||||
}
|
@ -114,6 +114,7 @@ class Critic(nn.Module):
|
||||
self.register_buffer(
|
||||
"q_support", torch.linspace(v_min, v_max, num_atoms, device=device)
|
||||
)
|
||||
self.device = device
|
||||
|
||||
def forward(self, obs: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
|
||||
return self.qnet1(obs, actions), self.qnet2(obs, actions)
|
||||
@ -189,6 +190,7 @@ class Actor(nn.Module):
|
||||
self.register_buffer("std_min", torch.as_tensor(std_min, device=device))
|
||||
self.register_buffer("std_max", torch.as_tensor(std_max, device=device))
|
||||
self.n_envs = num_envs
|
||||
self.device = device
|
||||
|
||||
def forward(self, obs: torch.Tensor) -> torch.Tensor:
|
||||
x = obs
|
||||
@ -218,3 +220,51 @@ class Actor(nn.Module):
|
||||
|
||||
noise = torch.randn_like(act) * self.noise_scales
|
||||
return act + noise
|
||||
|
||||
|
||||
class MultiTaskActor(Actor):
|
||||
def __init__(self, num_tasks: int, task_embedding_dim: int, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.num_tasks = num_tasks
|
||||
self.task_embedding_dim = task_embedding_dim
|
||||
self.task_embedding = nn.Embedding(
|
||||
num_tasks, task_embedding_dim, max_norm=1.0, device=self.device
|
||||
)
|
||||
|
||||
def forward(self, obs: torch.Tensor) -> torch.Tensor:
|
||||
task_ids_one_hot = obs[..., -self.num_tasks :]
|
||||
task_indices = torch.argmax(task_ids_one_hot, dim=1)
|
||||
task_embeddings = self.task_embedding(task_indices)
|
||||
obs = torch.cat([obs[..., : -self.num_tasks], task_embeddings], dim=-1)
|
||||
return super().forward(obs)
|
||||
|
||||
|
||||
class MultiTaskCritic(Critic):
|
||||
def __init__(self, num_tasks: int, task_embedding_dim: int, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.num_tasks = num_tasks
|
||||
self.task_embedding_dim = task_embedding_dim
|
||||
self.task_embedding = nn.Embedding(
|
||||
num_tasks, task_embedding_dim, max_norm=1.0, device=self.device
|
||||
)
|
||||
|
||||
def forward(self, obs: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
|
||||
task_ids_one_hot = obs[..., -self.num_tasks :]
|
||||
task_indices = torch.argmax(task_ids_one_hot, dim=1)
|
||||
task_embeddings = self.task_embedding(task_indices)
|
||||
obs = torch.cat([obs[..., : -self.num_tasks], task_embeddings], dim=-1)
|
||||
return super().forward(obs, actions)
|
||||
|
||||
def projection(
|
||||
self,
|
||||
obs: torch.Tensor,
|
||||
actions: torch.Tensor,
|
||||
rewards: torch.Tensor,
|
||||
bootstrap: torch.Tensor,
|
||||
discount: float,
|
||||
) -> torch.Tensor:
|
||||
task_ids_one_hot = obs[..., -self.num_tasks :]
|
||||
task_indices = torch.argmax(task_ids_one_hot, dim=1)
|
||||
task_embeddings = self.task_embedding(task_indices)
|
||||
obs = torch.cat([obs[..., : -self.num_tasks], task_embeddings], dim=-1)
|
||||
return super().projection(obs, actions, rewards, bootstrap, discount)
|
||||
|
@ -365,6 +365,7 @@ class Critic(nn.Module):
|
||||
self.register_buffer(
|
||||
"q_support", torch.linspace(v_min, v_max, num_atoms, device=device)
|
||||
)
|
||||
self.device = device
|
||||
|
||||
def forward(self, obs: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
|
||||
return self.qnet1(obs, actions), self.qnet2(obs, actions)
|
||||
@ -449,8 +450,8 @@ class Actor(nn.Module):
|
||||
self.predictor = HyperTanhPolicy(
|
||||
hidden_dim=hidden_dim,
|
||||
action_dim=n_act,
|
||||
scaler_init=scaler_init,
|
||||
scaler_scale=scaler_scale,
|
||||
scaler_init=1.0,
|
||||
scaler_scale=1.0,
|
||||
device=device,
|
||||
)
|
||||
|
||||
@ -462,6 +463,7 @@ class Actor(nn.Module):
|
||||
self.register_buffer("std_min", torch.as_tensor(std_min, device=device))
|
||||
self.register_buffer("std_max", torch.as_tensor(std_max, device=device))
|
||||
self.n_envs = num_envs
|
||||
self.device = device
|
||||
|
||||
def forward(self, obs: torch.Tensor) -> torch.Tensor:
|
||||
x = obs
|
||||
@ -492,3 +494,51 @@ class Actor(nn.Module):
|
||||
|
||||
noise = torch.randn_like(act) * self.noise_scales
|
||||
return act + noise
|
||||
|
||||
|
||||
class MultiTaskActor(Actor):
|
||||
def __init__(self, num_tasks: int, task_embedding_dim: int, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.num_tasks = num_tasks
|
||||
self.task_embedding_dim = task_embedding_dim
|
||||
self.task_embedding = nn.Embedding(
|
||||
num_tasks, task_embedding_dim, max_norm=1.0, device=self.device
|
||||
)
|
||||
|
||||
def forward(self, obs: torch.Tensor) -> torch.Tensor:
|
||||
task_ids_one_hot = obs[..., -self.num_tasks :]
|
||||
task_indices = torch.argmax(task_ids_one_hot, dim=1)
|
||||
task_embeddings = self.task_embedding(task_indices)
|
||||
obs = torch.cat([obs[..., : -self.num_tasks], task_embeddings], dim=-1)
|
||||
return super().forward(obs)
|
||||
|
||||
|
||||
class MultiTaskCritic(Critic):
|
||||
def __init__(self, num_tasks: int, task_embedding_dim: int, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.num_tasks = num_tasks
|
||||
self.task_embedding_dim = task_embedding_dim
|
||||
self.task_embedding = nn.Embedding(
|
||||
num_tasks, task_embedding_dim, max_norm=1.0, device=self.device
|
||||
)
|
||||
|
||||
def forward(self, obs: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
|
||||
task_ids_one_hot = obs[..., -self.num_tasks :]
|
||||
task_indices = torch.argmax(task_ids_one_hot, dim=1)
|
||||
task_embeddings = self.task_embedding(task_indices)
|
||||
obs = torch.cat([obs[..., : -self.num_tasks], task_embeddings], dim=-1)
|
||||
return super().forward(obs, actions)
|
||||
|
||||
def projection(
|
||||
self,
|
||||
obs: torch.Tensor,
|
||||
actions: torch.Tensor,
|
||||
rewards: torch.Tensor,
|
||||
bootstrap: torch.Tensor,
|
||||
discount: float,
|
||||
) -> torch.Tensor:
|
||||
task_ids_one_hot = obs[..., -self.num_tasks :]
|
||||
task_indices = torch.argmax(task_ids_one_hot, dim=1)
|
||||
task_embeddings = self.task_embedding(task_indices)
|
||||
obs = torch.cat([obs[..., : -self.num_tasks], task_embeddings], dim=-1)
|
||||
return super().projection(obs, actions, rewards, bootstrap, discount)
|
||||
|
@ -512,6 +512,212 @@ class RewardNormalizer(nn.Module):
|
||||
return self._scale_reward(rewards)
|
||||
|
||||
|
||||
class PerTaskEmpiricalNormalization(nn.Module):
|
||||
"""Normalize mean and variance of values based on empirical values for each task."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_tasks: int,
|
||||
shape: tuple,
|
||||
device: torch.device,
|
||||
eps: float = 1e-2,
|
||||
until: int = None,
|
||||
):
|
||||
"""
|
||||
Initialize PerTaskEmpiricalNormalization module.
|
||||
|
||||
Args:
|
||||
num_tasks (int): The total number of tasks.
|
||||
shape (int or tuple of int): Shape of input values except batch axis.
|
||||
eps (float): Small value for stability.
|
||||
until (int or None): If specified, learns until the sum of batch sizes
|
||||
for a specific task exceeds this value.
|
||||
"""
|
||||
super().__init__()
|
||||
if not isinstance(shape, tuple):
|
||||
shape = (shape,)
|
||||
self.num_tasks = num_tasks
|
||||
self.shape = shape
|
||||
self.eps = eps
|
||||
self.until = until
|
||||
self.device = device
|
||||
|
||||
# Buffers now have a leading dimension for tasks
|
||||
self.register_buffer("_mean", torch.zeros(num_tasks, *shape).to(device))
|
||||
self.register_buffer("_var", torch.ones(num_tasks, *shape).to(device))
|
||||
self.register_buffer("_std", torch.ones(num_tasks, *shape).to(device))
|
||||
self.register_buffer(
|
||||
"count", torch.zeros(num_tasks, dtype=torch.long).to(device)
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, x: torch.Tensor, task_ids: torch.Tensor, center: bool = True
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Normalize the input tensor `x` using statistics for the given `task_ids`.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor of shape [num_envs, *shape].
|
||||
task_ids (torch.Tensor): Tensor of task indices, shape [num_envs].
|
||||
center (bool): If True, center the data by subtracting the mean.
|
||||
"""
|
||||
if x.shape[1:] != self.shape:
|
||||
raise ValueError(f"Expected input shape (*, {self.shape}), got {x.shape}")
|
||||
if x.shape[0] != task_ids.shape[0]:
|
||||
raise ValueError("Batch size of x and task_ids must match.")
|
||||
|
||||
# Gather the stats for the tasks in the current batch
|
||||
# Reshape task_ids for broadcasting: [num_envs] -> [num_envs, 1, ...]
|
||||
view_shape = (task_ids.shape[0],) + (1,) * len(self.shape)
|
||||
task_ids_expanded = task_ids.view(view_shape).expand_as(x)
|
||||
|
||||
mean = self._mean.gather(0, task_ids_expanded)
|
||||
std = self._std.gather(0, task_ids_expanded)
|
||||
|
||||
if self.training:
|
||||
self.update(x, task_ids)
|
||||
|
||||
if center:
|
||||
return (x - mean) / (std + self.eps)
|
||||
else:
|
||||
return x / (std + self.eps)
|
||||
|
||||
@torch.jit.unused
|
||||
def update(self, x: torch.Tensor, task_ids: torch.Tensor):
|
||||
"""Update running statistics for the tasks present in the batch."""
|
||||
unique_tasks = torch.unique(task_ids)
|
||||
|
||||
for task_id in unique_tasks:
|
||||
if self.until is not None and self.count[task_id] >= self.until:
|
||||
continue
|
||||
|
||||
# Create a mask to select data for the current task
|
||||
mask = task_ids == task_id
|
||||
x_task = x[mask]
|
||||
batch_size = x_task.shape[0]
|
||||
|
||||
if batch_size == 0:
|
||||
continue
|
||||
|
||||
# Update count for this task
|
||||
old_count = self.count[task_id].clone()
|
||||
new_count = old_count + batch_size
|
||||
|
||||
# Update mean
|
||||
task_mean = self._mean[task_id]
|
||||
batch_mean = torch.mean(x_task, dim=0)
|
||||
delta = batch_mean - task_mean
|
||||
self._mean[task_id] = task_mean + (batch_size / new_count) * delta
|
||||
|
||||
# Update variance using Chan's parallel algorithm
|
||||
if old_count > 0:
|
||||
batch_var = torch.var(x_task, dim=0, unbiased=False)
|
||||
m_a = self._var[task_id] * old_count
|
||||
m_b = batch_var * batch_size
|
||||
M2 = m_a + m_b + (delta**2) * (old_count * batch_size / new_count)
|
||||
self._var[task_id] = M2 / new_count
|
||||
else:
|
||||
# For the first batch of this task
|
||||
self._var[task_id] = torch.var(x_task, dim=0, unbiased=False)
|
||||
|
||||
self._std[task_id] = torch.sqrt(self._var[task_id])
|
||||
self.count[task_id] = new_count
|
||||
|
||||
|
||||
class PerTaskRewardNormalizer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
num_tasks: int,
|
||||
gamma: float,
|
||||
device: torch.device,
|
||||
g_max: float = 10.0,
|
||||
epsilon: float = 1e-8,
|
||||
):
|
||||
"""
|
||||
Per-task reward normalizer, motivation comes from BRC (https://arxiv.org/abs/2505.23150v1)
|
||||
"""
|
||||
super().__init__()
|
||||
self.num_tasks = num_tasks
|
||||
self.gamma = gamma
|
||||
self.g_max = g_max
|
||||
self.epsilon = epsilon
|
||||
self.device = device
|
||||
|
||||
# Per-task running estimate of the discounted return
|
||||
self.register_buffer("G", torch.zeros(num_tasks, device=device))
|
||||
# Per-task running-max of the discounted return
|
||||
self.register_buffer("G_r_max", torch.zeros(num_tasks, device=device))
|
||||
# Use the new per-task normalizer for the statistics of G
|
||||
self.G_rms = PerTaskEmpiricalNormalization(
|
||||
num_tasks=num_tasks, shape=(1,), device=device
|
||||
)
|
||||
|
||||
def _scale_reward(
|
||||
self, rewards: torch.Tensor, task_ids: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Scales rewards using per-task statistics.
|
||||
|
||||
Args:
|
||||
rewards (torch.Tensor): Reward tensor, shape [num_envs].
|
||||
task_ids (torch.Tensor): Task indices, shape [num_envs].
|
||||
"""
|
||||
# Gather stats for the tasks in the batch
|
||||
std_for_batch = self.G_rms._std.gather(0, task_ids.unsqueeze(-1)).squeeze(-1)
|
||||
g_r_max_for_batch = self.G_r_max.gather(0, task_ids)
|
||||
|
||||
var_denominator = std_for_batch + self.epsilon
|
||||
min_required_denominator = g_r_max_for_batch / self.g_max
|
||||
denominator = torch.maximum(var_denominator, min_required_denominator)
|
||||
|
||||
# Add a small epsilon to the final denominator to prevent division by zero
|
||||
# in case g_r_max is also zero.
|
||||
return rewards / (denominator + self.epsilon)
|
||||
|
||||
def update_stats(
|
||||
self, rewards: torch.Tensor, dones: torch.Tensor, task_ids: torch.Tensor
|
||||
):
|
||||
"""
|
||||
Updates the running discounted return and its statistics for each task.
|
||||
|
||||
Args:
|
||||
rewards (torch.Tensor): Reward tensor, shape [num_envs].
|
||||
dones (torch.Tensor): Done tensor, shape [num_envs].
|
||||
task_ids (torch.Tensor): Task indices, shape [num_envs].
|
||||
"""
|
||||
if not (rewards.shape == dones.shape == task_ids.shape):
|
||||
raise ValueError("rewards, dones, and task_ids must have the same shape.")
|
||||
|
||||
# === Update G (running discounted return) ===
|
||||
# Gather the previous G values for the tasks in the batch
|
||||
prev_G = self.G.gather(0, task_ids)
|
||||
# Update G for each environment based on its own reward and done signal
|
||||
new_G = self.gamma * (1 - dones.float()) * prev_G + rewards
|
||||
# Scatter the updated G values back to the main buffer
|
||||
self.G.scatter_(0, task_ids, new_G)
|
||||
|
||||
# === Update G_rms (statistics of G) ===
|
||||
# The update function handles the per-task logic internally
|
||||
self.G_rms.update(new_G.unsqueeze(-1), task_ids)
|
||||
|
||||
# === Update G_r_max (running max of |G|) ===
|
||||
prev_G_r_max = self.G_r_max.gather(0, task_ids)
|
||||
# Update the max for each environment
|
||||
updated_G_r_max = torch.maximum(prev_G_r_max, torch.abs(new_G))
|
||||
# Scatter the new maxes back to the main buffer
|
||||
self.G_r_max.scatter_(0, task_ids, updated_G_r_max)
|
||||
|
||||
def forward(self, rewards: torch.Tensor, task_ids: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Normalizes rewards. During training, it also updates the running statistics.
|
||||
|
||||
Args:
|
||||
rewards (torch.Tensor): Reward tensor, shape [num_envs].
|
||||
task_ids (torch.Tensor): Task indices, shape [num_envs].
|
||||
"""
|
||||
return self._scale_reward(rewards, task_ids)
|
||||
|
||||
|
||||
def cpu_state(sd):
|
||||
# detach & move to host without locking the compute stream
|
||||
return {k: v.detach().to("cpu", non_blocking=True) for k, v in sd.items()}
|
||||
|
@ -95,7 +95,7 @@ class BaseArgs:
|
||||
obs_normalization: bool = True
|
||||
"""whether to enable observation normalization"""
|
||||
reward_normalization: bool = False
|
||||
"""whether to enable reward normalization (Not recommended for now, it's unstable.)"""
|
||||
"""whether to enable reward normalization"""
|
||||
max_grad_norm: float = 0.0
|
||||
"""the maximum gradient norm"""
|
||||
amp: bool = True
|
||||
@ -113,6 +113,8 @@ class BaseArgs:
|
||||
"""(Playground only) Use tuned reward for G1"""
|
||||
action_bounds: float = 1.0
|
||||
"""(IsaacLab only) the bounds of the action space (-action_bounds, action_bounds)"""
|
||||
task_embedding_dim: int = 32
|
||||
"""the dimension of the task embedding"""
|
||||
|
||||
weight_decay: float = 0.1
|
||||
"""the weight decay of the optimizer"""
|
||||
@ -169,6 +171,9 @@ def get_args():
|
||||
"Isaac-Velocity-Rough-G1-v0": IsaacVelocityRoughG1Args,
|
||||
"Isaac-Repose-Cube-Allegro-Direct-v0": IsaacReposeCubeAllegroDirectArgs,
|
||||
"Isaac-Repose-Cube-Shadow-Direct-v0": IsaacReposeCubeShadowDirectArgs,
|
||||
# MTBench
|
||||
"MTBench-meta-world-v2-mt10": MetaWorldMT10Args,
|
||||
"MTBench-meta-world-v2-mt50": MetaWorldMT50Args,
|
||||
}
|
||||
# If the provided env_name has a specific Args class, use it
|
||||
if base_args.env_name in env_to_args_class:
|
||||
@ -183,6 +188,9 @@ def get_args():
|
||||
elif base_args.env_name.startswith("Isaac-"):
|
||||
# IsaacLab
|
||||
specific_args = tyro.cli(IsaacLabArgs)
|
||||
elif base_args.env_name.startswith("MTBench-"):
|
||||
# MTBench
|
||||
specific_args = tyro.cli(MTBenchArgs)
|
||||
else:
|
||||
# MuJoCo Playground
|
||||
specific_args = tyro.cli(MuJoCoPlaygroundArgs)
|
||||
@ -280,6 +288,38 @@ class MuJoCoPlaygroundArgs(BaseArgs):
|
||||
gamma: float = 0.97
|
||||
|
||||
|
||||
@dataclass
|
||||
class MTBenchArgs(BaseArgs):
|
||||
# Default hyperparameters for MTBench
|
||||
reward_normalization: bool = True
|
||||
v_min: float = -10.0
|
||||
v_max: float = 10.0
|
||||
buffer_size: int = 2048 # 2K is usually enough for MTBench
|
||||
num_envs: int = 4096
|
||||
num_eval_envs: int = 4096
|
||||
gamma: float = 0.99
|
||||
num_steps: int = 8
|
||||
|
||||
|
||||
@dataclass
|
||||
class MetaWorldMT10Args(MTBenchArgs):
|
||||
# This config achieves 97 ~ 98% success rate within 10k steps (15-20 mins on A100)
|
||||
env_name: str = "MTBench-meta-world-v2-mt10"
|
||||
num_envs: int = 4096
|
||||
num_eval_envs: int = 4096
|
||||
num_steps: int = 8
|
||||
|
||||
|
||||
@dataclass
|
||||
class MetaWorldMT50Args(MTBenchArgs):
|
||||
# FastTD3 + SimbaV2 achieves >90% success rate within 20k steps (80 mins on A100)
|
||||
# Performance further improves with more training steps, slowly.
|
||||
env_name: str = "MTBench-meta-world-v2-mt50"
|
||||
num_envs: int = 8192
|
||||
num_eval_envs: int = 8192
|
||||
num_steps: int = 8
|
||||
|
||||
|
||||
@dataclass
|
||||
class G1JoystickFlatTerrainArgs(MuJoCoPlaygroundArgs):
|
||||
env_name: str = "G1JoystickFlatTerrain"
|
||||
|
@ -18,17 +18,24 @@ import tqdm
|
||||
import wandb
|
||||
import numpy as np
|
||||
|
||||
try:
|
||||
# Required for avoiding IsaacGym import error
|
||||
import isaacgym
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.optim as optim
|
||||
from torch.amp import autocast, GradScaler
|
||||
|
||||
from tensordict import TensorDict, from_module
|
||||
from tensordict import TensorDict
|
||||
|
||||
from fast_td3_utils import (
|
||||
EmpiricalNormalization,
|
||||
RewardNormalizer,
|
||||
PerTaskRewardNormalizer,
|
||||
SimpleReplayBuffer,
|
||||
save_params,
|
||||
)
|
||||
@ -103,6 +110,14 @@ def main():
|
||||
)
|
||||
eval_envs = envs
|
||||
render_env = envs
|
||||
elif args.env_name.startswith("MTBench-"):
|
||||
from environments.mtbench_env import MTBenchEnv
|
||||
|
||||
env_name = "-".join(args.env_name.split("-")[1:])
|
||||
env_type = "mtbench"
|
||||
envs = MTBenchEnv(env_name, args.device_rank, args.num_envs, args.seed)
|
||||
eval_envs = envs
|
||||
render_env = envs
|
||||
else:
|
||||
from environments.mujoco_playground_env import make_env
|
||||
|
||||
@ -141,9 +156,19 @@ def main():
|
||||
critic_obs_normalizer = nn.Identity()
|
||||
|
||||
if args.reward_normalization:
|
||||
reward_normalizer = RewardNormalizer(
|
||||
gamma=args.gamma, device=device, g_max=min(abs(args.v_min), abs(args.v_max))
|
||||
)
|
||||
if env_type in ["mtbench"]:
|
||||
reward_normalizer = PerTaskRewardNormalizer(
|
||||
num_tasks=envs.num_tasks,
|
||||
gamma=args.gamma,
|
||||
device=device,
|
||||
g_max=min(abs(args.v_min), abs(args.v_max)),
|
||||
)
|
||||
else:
|
||||
reward_normalizer = RewardNormalizer(
|
||||
gamma=args.gamma,
|
||||
device=device,
|
||||
g_max=min(abs(args.v_min), abs(args.v_max)),
|
||||
)
|
||||
else:
|
||||
reward_normalizer = nn.Identity()
|
||||
|
||||
@ -165,12 +190,37 @@ def main():
|
||||
"device": device,
|
||||
}
|
||||
|
||||
if env_type == "mtbench":
|
||||
actor_kwargs["n_obs"] = n_obs - envs.num_tasks + args.task_embedding_dim
|
||||
critic_kwargs["n_obs"] = n_critic_obs - envs.num_tasks + args.task_embedding_dim
|
||||
actor_kwargs["num_tasks"] = envs.num_tasks
|
||||
actor_kwargs["task_embedding_dim"] = args.task_embedding_dim
|
||||
critic_kwargs["num_tasks"] = envs.num_tasks
|
||||
critic_kwargs["task_embedding_dim"] = args.task_embedding_dim
|
||||
|
||||
if args.agent == "fasttd3":
|
||||
from fast_td3 import Actor, Critic
|
||||
if env_type in ["mtbench"]:
|
||||
from fast_td3 import MultiTaskActor, MultiTaskCritic
|
||||
|
||||
actor_cls = MultiTaskActor
|
||||
critic_cls = MultiTaskCritic
|
||||
else:
|
||||
from fast_td3 import Actor, Critic
|
||||
|
||||
actor_cls = Actor
|
||||
critic_cls = Critic
|
||||
|
||||
print("Using FastTD3")
|
||||
elif args.agent == "fasttd3_simbav2":
|
||||
from fast_td3_simbav2 import Actor, Critic
|
||||
if env_type in ["mtbench"]:
|
||||
from fast_td3_simbav2 import MultiTaskActor, MultiTaskCritic
|
||||
|
||||
actor_cls = MultiTaskActor
|
||||
critic_cls = MultiTaskCritic
|
||||
else:
|
||||
from fast_td3_simbav2 import Actor, Critic
|
||||
|
||||
actor_cls = Actor
|
||||
|
||||
print("Using FastTD3 + SimbaV2")
|
||||
actor_kwargs.pop("init_scale")
|
||||
@ -199,25 +249,31 @@ def main():
|
||||
else:
|
||||
raise ValueError(f"Agent {args.agent} not supported")
|
||||
|
||||
actor = Actor(**actor_kwargs)
|
||||
actor_detach = Actor(**actor_kwargs)
|
||||
actor = actor_cls(**actor_kwargs)
|
||||
|
||||
# Copy params to actor_detach without grad
|
||||
from_module(actor).data.to_module(actor_detach)
|
||||
policy = actor_detach.explore
|
||||
if env_type in ["mtbench"]:
|
||||
# Python 3.8 doesn't support 'from_module' in tensordict
|
||||
policy = actor.explore
|
||||
else:
|
||||
from tensordict import from_module
|
||||
|
||||
qnet = Critic(**critic_kwargs)
|
||||
qnet_target = Critic(**critic_kwargs)
|
||||
actor_detach = actor_cls(**actor_kwargs)
|
||||
# Copy params to actor_detach without grad
|
||||
from_module(actor).data.to_module(actor_detach)
|
||||
policy = actor_detach.explore
|
||||
|
||||
qnet = critic_cls(**critic_kwargs)
|
||||
qnet_target = critic_cls(**critic_kwargs)
|
||||
qnet_target.load_state_dict(qnet.state_dict())
|
||||
|
||||
q_optimizer = optim.AdamW(
|
||||
list(qnet.parameters()),
|
||||
lr=args.critic_learning_rate,
|
||||
lr=torch.tensor(args.critic_learning_rate, device=device),
|
||||
weight_decay=args.weight_decay,
|
||||
)
|
||||
actor_optimizer = optim.AdamW(
|
||||
list(actor.parameters()),
|
||||
lr=args.actor_learning_rate,
|
||||
lr=torch.tensor(args.actor_learning_rate, device=device),
|
||||
weight_decay=args.weight_decay,
|
||||
)
|
||||
|
||||
@ -225,12 +281,12 @@ def main():
|
||||
q_scheduler = optim.lr_scheduler.CosineAnnealingLR(
|
||||
q_optimizer,
|
||||
T_max=args.total_timesteps,
|
||||
eta_min=args.critic_learning_rate_end, # Decay to 10% of initial lr
|
||||
eta_min=torch.tensor(args.critic_learning_rate_end, device=device),
|
||||
)
|
||||
actor_scheduler = optim.lr_scheduler.CosineAnnealingLR(
|
||||
actor_optimizer,
|
||||
T_max=args.total_timesteps,
|
||||
eta_min=args.actor_learning_rate_end, # Decay to 10% of initial lr
|
||||
eta_min=torch.tensor(args.actor_learning_rate_end, device=device),
|
||||
)
|
||||
|
||||
rb = SimpleReplayBuffer(
|
||||
@ -262,20 +318,28 @@ def main():
|
||||
obs = eval_envs.reset()
|
||||
|
||||
# Run for a fixed number of steps
|
||||
for _ in range(eval_envs.max_episode_steps):
|
||||
for i in range(eval_envs.max_episode_steps):
|
||||
with torch.no_grad(), autocast(
|
||||
device_type=amp_device_type, dtype=amp_dtype, enabled=amp_enabled
|
||||
):
|
||||
obs = normalize_obs(obs)
|
||||
actions = actor(obs)
|
||||
|
||||
next_obs, rewards, dones, _ = eval_envs.step(actions.float())
|
||||
next_obs, rewards, dones, infos = eval_envs.step(actions.float())
|
||||
|
||||
if env_type == "mtbench":
|
||||
# We only report success rate in MTBench evaluation
|
||||
rewards = (
|
||||
infos["episode"]["success"].float() if "episode" in infos else 0.0
|
||||
)
|
||||
episode_returns = torch.where(
|
||||
~done_masks, episode_returns + rewards, episode_returns
|
||||
)
|
||||
episode_lengths = torch.where(
|
||||
~done_masks, episode_lengths + 1, episode_lengths
|
||||
)
|
||||
if env_type == "mtbench" and "episode" in infos:
|
||||
dones = dones | infos["episode"]["success"]
|
||||
done_masks = torch.logical_or(done_masks, dones)
|
||||
if done_masks.all():
|
||||
break
|
||||
@ -291,9 +355,9 @@ def main():
|
||||
if env_type == "humanoid_bench":
|
||||
obs = render_env.reset()
|
||||
renders = [render_env.render()]
|
||||
elif env_type == "isaaclab":
|
||||
elif env_type in ["isaaclab", "mtbench"]:
|
||||
raise NotImplementedError(
|
||||
"We don't support rendering for IsaacLab environments"
|
||||
"We don't support rendering for IsaacLab and MTBench environments"
|
||||
)
|
||||
else:
|
||||
obs = render_env.reset()
|
||||
@ -399,6 +463,7 @@ def main():
|
||||
)
|
||||
scaler.step(q_optimizer)
|
||||
scaler.update()
|
||||
q_scheduler.step()
|
||||
|
||||
logs_dict["critic_grad_norm"] = critic_grad_norm.detach()
|
||||
logs_dict["qf_loss"] = qf_loss.detach()
|
||||
@ -434,6 +499,7 @@ def main():
|
||||
)
|
||||
scaler.step(actor_optimizer)
|
||||
scaler.update()
|
||||
actor_scheduler.step()
|
||||
logs_dict["actor_grad_norm"] = actor_grad_norm.detach()
|
||||
logs_dict["actor_loss"] = actor_loss.detach()
|
||||
return logs_dict
|
||||
@ -500,7 +566,12 @@ def main():
|
||||
truncations = infos["time_outs"]
|
||||
|
||||
if args.reward_normalization:
|
||||
update_stats(rewards, dones.float())
|
||||
if env_type == "mtbench":
|
||||
task_ids_one_hot = obs[..., -envs.num_tasks :]
|
||||
task_indices = torch.argmax(task_ids_one_hot, dim=1)
|
||||
update_stats(rewards, dones.float(), task_ids=task_indices)
|
||||
else:
|
||||
update_stats(rewards, dones.float())
|
||||
|
||||
if envs.asymmetric_obs:
|
||||
next_critic_obs = infos["observations"]["critic"]
|
||||
@ -550,7 +621,15 @@ def main():
|
||||
data["next"]["observations"]
|
||||
)
|
||||
raw_rewards = data["next"]["rewards"]
|
||||
data["next"]["rewards"] = normalize_reward(raw_rewards)
|
||||
if env_type in ["mtbench"] and args.reward_normalization:
|
||||
# Multi-task reward normalization
|
||||
task_ids_one_hot = data["observations"][..., -envs.num_tasks :]
|
||||
task_indices = torch.argmax(task_ids_one_hot, dim=1)
|
||||
data["next"]["rewards"] = normalize_reward(
|
||||
raw_rewards, task_ids=task_indices
|
||||
)
|
||||
else:
|
||||
data["next"]["rewards"] = normalize_reward(raw_rewards)
|
||||
if envs.asymmetric_obs:
|
||||
data["critic_observations"] = normalize_critic_obs(
|
||||
data["critic_observations"]
|
||||
@ -591,7 +670,7 @@ def main():
|
||||
if args.eval_interval > 0 and global_step % args.eval_interval == 0:
|
||||
print(f"Evaluating at global step {global_step}")
|
||||
eval_avg_return, eval_avg_length = evaluate()
|
||||
if env_type in ["humanoid_bench", "isaaclab"]:
|
||||
if env_type in ["humanoid_bench", "isaaclab", "mtbench"]:
|
||||
# NOTE: Hacky way of evaluating performance, but just works
|
||||
obs = envs.reset()
|
||||
logs["eval_avg_return"] = eval_avg_return
|
||||
@ -644,10 +723,6 @@ def main():
|
||||
f"models/{run_name}_{global_step}.pt",
|
||||
)
|
||||
|
||||
# Update learning rates
|
||||
q_scheduler.step()
|
||||
actor_scheduler.step()
|
||||
|
||||
global_step += 1
|
||||
pbar.update(1)
|
||||
|
||||
|
15
requirements/requirements_isaacgym.txt
Normal file
15
requirements/requirements_isaacgym.txt
Normal file
@ -0,0 +1,15 @@
|
||||
gymnasium<1.0.0
|
||||
jax-jumpy==1.0.0 ; python_version >= "3.8" and python_version < "3.11"
|
||||
matplotlib
|
||||
moviepy
|
||||
numpy<2.0
|
||||
pandas
|
||||
protobuf
|
||||
pygame
|
||||
stable-baselines3
|
||||
tqdm
|
||||
wandb
|
||||
torchrl==0.5.0
|
||||
tensordict==0.5.0
|
||||
tyro
|
||||
loguru
|
Loading…
Reference in New Issue
Block a user