Support MTBench (#15)
This PR incorporates MTBench into the current codebase, as a good demonstration that shows how to use FastTD3 for multi-task setup. - Add support for MTBench along with its wrapper - Add support for per-task reward normalizer useful for multi-task RL, motivated by BRC paper (https://arxiv.org/abs/2505.23150v1)
This commit is contained in:
parent
3facede77d
commit
cef44108d8
67
README.md
67
README.md
@ -10,11 +10,13 @@ For more information, please see our [project webpage](https://younggyo.me/fast_
|
|||||||
|
|
||||||
|
|
||||||
## ❗ Updates
|
## ❗ Updates
|
||||||
- **[June/15/2025]** Added support for FastTD3 + SimbaV2! It's faster to train, and often achieves better asymptotic performance.
|
- **[Jun/20/2025]** Added support for [MTBench](https://github.com/Viraj-Joshi/MTBench)
|
||||||
|
|
||||||
- **[Jun/6/2025]** Thanks to [Antonin Raffin](https://araffin.github.io/) ([@araffin](https://github.com/araffin)), we fixed the issues when using `n_steps` > 1, which stabilizes training with n-step return quite a lot!
|
- **[Jun/15/2025]** Added support for FastTD3 + [SimbaV2](https://dojeon-ai.github.io/SimbaV2/)! It's faster to train, and often achieves better asymptotic performance. We recommend using FastTD3 + SimbaV2 for most cases.
|
||||||
|
|
||||||
- **[Jun/1/2025]** Updated the figures in the technical report to report deterministic evaluation for IsaacLab tasks.
|
- **[Jun/06/2025]** Thanks to [Antonin Raffin](https://araffin.github.io/) ([@araffin](https://github.com/araffin)), we fixed the issues when using `n_steps` > 1, which stabilizes training with n-step return quite a lot!
|
||||||
|
|
||||||
|
- **[Jun/01/2025]** Updated the figures in the technical report to report deterministic evaluation for IsaacLab tasks.
|
||||||
|
|
||||||
|
|
||||||
## ✨ Features
|
## ✨ Features
|
||||||
@ -81,6 +83,27 @@ cd ..
|
|||||||
pip install -r requirements/requirements.txt
|
pip install -r requirements/requirements.txt
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Environment for MTBench
|
||||||
|
MTBench does not support humanoid experiments, but is a useful multi-task benchmark with massive parallel simulation. This could be useful for users who want to use FastTD3 for their multi-task experiments.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
conda create -n fasttd3_mtbench -y python=3.8 # Note python version
|
||||||
|
conda activate fasttd3_mtbench
|
||||||
|
|
||||||
|
# Install IsaacGym -- recommend to follow instructions in https://github.com/BoosterRobotics/booster_gym
|
||||||
|
...
|
||||||
|
|
||||||
|
# Install MTBench
|
||||||
|
git clone https://github.com/Viraj-Joshi/MTBench.git
|
||||||
|
cd MTbench
|
||||||
|
pip install -e .
|
||||||
|
pip install skrl
|
||||||
|
cd ..
|
||||||
|
|
||||||
|
# Install project-specific requirements
|
||||||
|
pip install -r requirements/requirements_isaacgym.txt
|
||||||
|
```
|
||||||
|
|
||||||
### (Optional) Accelerate headless GPU rendering in cloud instances
|
### (Optional) Accelerate headless GPU rendering in cloud instances
|
||||||
|
|
||||||
In some cloud VM images the NVIDIA kernel driver is present but the user-space OpenGL/EGL/Vulkan libraries aren't, so MuJoCo falls back to CPU renderer. You can install just the NVIDIA user-space libraries (and skip rebuilding the kernel module) with:
|
In some cloud VM images the NVIDIA kernel driver is present but the user-space OpenGL/EGL/Vulkan libraries aren't, so MuJoCo falls back to CPU renderer. You can install just the NVIDIA user-space libraries (and skip rebuilding the kernel module) with:
|
||||||
@ -176,6 +199,32 @@ python fast_td3/train.py \
|
|||||||
--seed 1
|
--seed 1
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### MTBench Experiments
|
||||||
|
```bash
|
||||||
|
conda activate fasttd3_mtbench
|
||||||
|
# FastTD3
|
||||||
|
python fast_td3/train.py \
|
||||||
|
--env_name MTBench-meta-world-v2-mt10 \
|
||||||
|
--exp_name FastTD3 \
|
||||||
|
--render_interval 0 \
|
||||||
|
--seed 1
|
||||||
|
# FastTD3 + SimbaV2
|
||||||
|
python fast_td3/train.py \
|
||||||
|
--env_name MTBench-meta-world-v2-mt10 \
|
||||||
|
--exp_name FastTD3 \
|
||||||
|
--render_interval 0 \
|
||||||
|
--agent fasttd3_simbav2 \
|
||||||
|
--batch_size 8192 \
|
||||||
|
--critic_learning_rate_end 3e-5 \
|
||||||
|
--actor_learning_rate_end 3e-5 \
|
||||||
|
--weight_decay 0.0 \
|
||||||
|
--critic_hidden_dim 512 \
|
||||||
|
--critic_num_blocks 2 \
|
||||||
|
--actor_hidden_dim 256 \
|
||||||
|
--actor_num_blocks 1 \
|
||||||
|
--seed 1
|
||||||
|
```
|
||||||
|
|
||||||
**Quick note:** For boolean-based arguments, you can set them to False by adding `no_` in front each argument, for instance, if you want to disable Clipped Q Learning, you can specify `--no_use_cdq` in your command.
|
**Quick note:** For boolean-based arguments, you can set them to False by adding `no_` in front each argument, for instance, if you want to disable Clipped Q Learning, you can specify `--no_use_cdq` in your command.
|
||||||
|
|
||||||
## 💡 Performance-Related Tips
|
## 💡 Performance-Related Tips
|
||||||
@ -315,6 +364,18 @@ Following the [LeanRL](https://github.com/pytorch-labs/LeanRL)'s recommendation,
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### MTBench
|
||||||
|
```bibtex
|
||||||
|
@inproceedings{
|
||||||
|
joshi2025benchmarking,
|
||||||
|
title={Benchmarking Massively Parallelized Multi-Task Reinforcement Learning for Robotics Tasks},
|
||||||
|
author={Viraj Joshi and Zifan Xu and Bo Liu and Peter Stone and Amy Zhang},
|
||||||
|
booktitle={Reinforcement Learning Conference},
|
||||||
|
year={2025},
|
||||||
|
url={https://openreview.net/forum?id=z0MM0y20I2}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
### Getting SAC to Work on a Massive Parallel Simulator
|
### Getting SAC to Work on a Massive Parallel Simulator
|
||||||
```bibtex
|
```bibtex
|
||||||
@article{raffin2025isaacsim,
|
@article{raffin2025isaacsim,
|
||||||
|
148
fast_td3/environments/mtbench_env.py
Normal file
148
fast_td3/environments/mtbench_env.py
Normal file
@ -0,0 +1,148 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from omegaconf import OmegaConf
|
||||||
|
|
||||||
|
import isaacgym
|
||||||
|
import isaacgymenvs
|
||||||
|
|
||||||
|
|
||||||
|
class MTBenchEnv:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
task_name: str,
|
||||||
|
device_id: int,
|
||||||
|
num_envs: int,
|
||||||
|
seed: int,
|
||||||
|
):
|
||||||
|
# NOTE: Currently, we only support Meta-World-v2 MT-10/MT-50 in MTBench
|
||||||
|
task_config = MTBENCH_MW2_CONFIG.copy()
|
||||||
|
if task_name == "meta-world-v2-mt10":
|
||||||
|
# MT-10 Setup
|
||||||
|
assert num_envs == 4096, "MT-10 only supports 4096 environments (for now)"
|
||||||
|
self.num_tasks = 10
|
||||||
|
task_config["env"]["tasks"] = [4, 16, 17, 18, 28, 31, 38, 40, 48, 49]
|
||||||
|
task_config["env"]["taskEnvCount"] = [410] * 6 + [409] * 4
|
||||||
|
elif task_name == "meta-world-v2-mt50":
|
||||||
|
# MT-50 Setup
|
||||||
|
self.num_tasks = 50
|
||||||
|
assert num_envs == 8192, "MT-50 only supports 8192 environments (for now)"
|
||||||
|
task_config["env"]["tasks"] = list(range(50))
|
||||||
|
task_config["env"]["taskEnvCount"] = [164] * 42 + [163] * 8 # 6888 + 1304
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported task name: {task_name}")
|
||||||
|
task_config["env"]["numEnvs"] = num_envs
|
||||||
|
task_config["env"]["numObservations"] = 39 + self.num_tasks
|
||||||
|
task_config["env"]["seed"] = seed
|
||||||
|
|
||||||
|
# Convert dictionary to OmegaConf object
|
||||||
|
env_cfg = {"task": task_config}
|
||||||
|
env_cfg = OmegaConf.create(env_cfg)
|
||||||
|
|
||||||
|
self.env = isaacgymenvs.make(
|
||||||
|
task=env_cfg.task.name,
|
||||||
|
num_envs=num_envs,
|
||||||
|
sim_device=f"cuda:{device_id}",
|
||||||
|
rl_device=f"cuda:{device_id}",
|
||||||
|
seed=seed,
|
||||||
|
headless=True,
|
||||||
|
cfg=env_cfg,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.num_envs = num_envs
|
||||||
|
self.asymmetric_obs = False
|
||||||
|
self.num_obs = self.env.observation_space.shape[0]
|
||||||
|
assert (
|
||||||
|
self.num_obs == 39 + self.num_tasks
|
||||||
|
), "MTBench observation space is 39 + num_tasks (one-hot vector)"
|
||||||
|
self.num_privileged_obs = 0
|
||||||
|
self.num_actions = self.env.action_space.shape[0]
|
||||||
|
self.max_episode_steps = self.env.max_episode_length
|
||||||
|
|
||||||
|
def reset(self) -> torch.Tensor:
|
||||||
|
"""Reset the environment."""
|
||||||
|
# TODO: Check if we need no_grad and detach here
|
||||||
|
with torch.no_grad(): # do we need this?
|
||||||
|
self.env.reset_idx(torch.arange(self.num_envs, device=self.env.device))
|
||||||
|
obs_dict = self.env.reset()
|
||||||
|
return obs_dict["obs"].detach()
|
||||||
|
|
||||||
|
def step(
|
||||||
|
self, actions: torch.Tensor
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, dict]:
|
||||||
|
"""Step the environment."""
|
||||||
|
assert isinstance(actions, torch.Tensor)
|
||||||
|
|
||||||
|
# TODO: Check if we need no_grad and detach here
|
||||||
|
with torch.no_grad():
|
||||||
|
obs_dict, rew, dones, infos = self.env.step(actions.detach())
|
||||||
|
truncations = infos["time_outs"]
|
||||||
|
info_ret = {"time_outs": truncations.detach()}
|
||||||
|
if "episode" in infos:
|
||||||
|
info_ret["episode"] = infos["episode"]
|
||||||
|
# NOTE: There's really no way to get the raw observations from IsaacGym
|
||||||
|
# We just use the 'reset_obs' as next_obs, unfortunately.
|
||||||
|
info_ret["observations"] = {"raw": {"obs": obs_dict["obs"].detach()}}
|
||||||
|
return obs_dict["obs"].detach(), rew.detach(), dones.detach(), info_ret
|
||||||
|
|
||||||
|
def render(self):
|
||||||
|
raise NotImplementedError(
|
||||||
|
"We don't support rendering for IsaacLab environments"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
MTBENCH_MW2_CONFIG = {
|
||||||
|
"name": "meta-world-v2",
|
||||||
|
"physics_engine": "physx",
|
||||||
|
"env": {
|
||||||
|
"numEnvs": 1,
|
||||||
|
"envSpacing": 1.5,
|
||||||
|
"episodeLength": 150,
|
||||||
|
"enableDebugVis": False,
|
||||||
|
"clipObservations": 5.0,
|
||||||
|
"clipActions": 1.0,
|
||||||
|
"aggregateMode": 3,
|
||||||
|
"actionScale": 0.01,
|
||||||
|
"resetNoise": 0.15,
|
||||||
|
"tasks": [0],
|
||||||
|
"taskEnvCount": [4096],
|
||||||
|
"init_at_random_progress": True,
|
||||||
|
"exemptedInitAtRandomProgressTasks": [],
|
||||||
|
"taskEmbedding": True,
|
||||||
|
"taskEmbeddingType": "one_hot",
|
||||||
|
"seed": 42,
|
||||||
|
"cameraRenderingInterval": 5000,
|
||||||
|
"cameraWidth": 1024,
|
||||||
|
"cameraHeight": 1024,
|
||||||
|
"sparse_reward": False,
|
||||||
|
"termination_on_success": False,
|
||||||
|
"reward_scale": 1.0,
|
||||||
|
"fixed": False,
|
||||||
|
"numObservations": None,
|
||||||
|
"numActions": 4,
|
||||||
|
},
|
||||||
|
"enableCameraSensors": False,
|
||||||
|
"sim": {
|
||||||
|
"dt": 0.01667,
|
||||||
|
"substeps": 2,
|
||||||
|
"up_axis": "z",
|
||||||
|
"use_gpu_pipeline": True,
|
||||||
|
"gravity": [0.0, 0.0, -9.81],
|
||||||
|
"physx": {
|
||||||
|
"num_threads": 4,
|
||||||
|
"solver_type": 1,
|
||||||
|
"use_gpu": True,
|
||||||
|
"num_position_iterations": 8,
|
||||||
|
"num_velocity_iterations": 1,
|
||||||
|
"contact_offset": 0.005,
|
||||||
|
"rest_offset": 0.0,
|
||||||
|
"bounce_threshold_velocity": 0.2,
|
||||||
|
"max_depenetration_velocity": 1000.0,
|
||||||
|
"default_buffer_size_multiplier": 10.0,
|
||||||
|
"max_gpu_contact_pairs": 1048576,
|
||||||
|
"num_subscenes": 4,
|
||||||
|
"contact_collection": 0,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"task": {"randomize": False},
|
||||||
|
}
|
@ -114,6 +114,7 @@ class Critic(nn.Module):
|
|||||||
self.register_buffer(
|
self.register_buffer(
|
||||||
"q_support", torch.linspace(v_min, v_max, num_atoms, device=device)
|
"q_support", torch.linspace(v_min, v_max, num_atoms, device=device)
|
||||||
)
|
)
|
||||||
|
self.device = device
|
||||||
|
|
||||||
def forward(self, obs: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
|
def forward(self, obs: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
|
||||||
return self.qnet1(obs, actions), self.qnet2(obs, actions)
|
return self.qnet1(obs, actions), self.qnet2(obs, actions)
|
||||||
@ -189,6 +190,7 @@ class Actor(nn.Module):
|
|||||||
self.register_buffer("std_min", torch.as_tensor(std_min, device=device))
|
self.register_buffer("std_min", torch.as_tensor(std_min, device=device))
|
||||||
self.register_buffer("std_max", torch.as_tensor(std_max, device=device))
|
self.register_buffer("std_max", torch.as_tensor(std_max, device=device))
|
||||||
self.n_envs = num_envs
|
self.n_envs = num_envs
|
||||||
|
self.device = device
|
||||||
|
|
||||||
def forward(self, obs: torch.Tensor) -> torch.Tensor:
|
def forward(self, obs: torch.Tensor) -> torch.Tensor:
|
||||||
x = obs
|
x = obs
|
||||||
@ -218,3 +220,51 @@ class Actor(nn.Module):
|
|||||||
|
|
||||||
noise = torch.randn_like(act) * self.noise_scales
|
noise = torch.randn_like(act) * self.noise_scales
|
||||||
return act + noise
|
return act + noise
|
||||||
|
|
||||||
|
|
||||||
|
class MultiTaskActor(Actor):
|
||||||
|
def __init__(self, num_tasks: int, task_embedding_dim: int, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.num_tasks = num_tasks
|
||||||
|
self.task_embedding_dim = task_embedding_dim
|
||||||
|
self.task_embedding = nn.Embedding(
|
||||||
|
num_tasks, task_embedding_dim, max_norm=1.0, device=self.device
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, obs: torch.Tensor) -> torch.Tensor:
|
||||||
|
task_ids_one_hot = obs[..., -self.num_tasks :]
|
||||||
|
task_indices = torch.argmax(task_ids_one_hot, dim=1)
|
||||||
|
task_embeddings = self.task_embedding(task_indices)
|
||||||
|
obs = torch.cat([obs[..., : -self.num_tasks], task_embeddings], dim=-1)
|
||||||
|
return super().forward(obs)
|
||||||
|
|
||||||
|
|
||||||
|
class MultiTaskCritic(Critic):
|
||||||
|
def __init__(self, num_tasks: int, task_embedding_dim: int, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.num_tasks = num_tasks
|
||||||
|
self.task_embedding_dim = task_embedding_dim
|
||||||
|
self.task_embedding = nn.Embedding(
|
||||||
|
num_tasks, task_embedding_dim, max_norm=1.0, device=self.device
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, obs: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
|
||||||
|
task_ids_one_hot = obs[..., -self.num_tasks :]
|
||||||
|
task_indices = torch.argmax(task_ids_one_hot, dim=1)
|
||||||
|
task_embeddings = self.task_embedding(task_indices)
|
||||||
|
obs = torch.cat([obs[..., : -self.num_tasks], task_embeddings], dim=-1)
|
||||||
|
return super().forward(obs, actions)
|
||||||
|
|
||||||
|
def projection(
|
||||||
|
self,
|
||||||
|
obs: torch.Tensor,
|
||||||
|
actions: torch.Tensor,
|
||||||
|
rewards: torch.Tensor,
|
||||||
|
bootstrap: torch.Tensor,
|
||||||
|
discount: float,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
task_ids_one_hot = obs[..., -self.num_tasks :]
|
||||||
|
task_indices = torch.argmax(task_ids_one_hot, dim=1)
|
||||||
|
task_embeddings = self.task_embedding(task_indices)
|
||||||
|
obs = torch.cat([obs[..., : -self.num_tasks], task_embeddings], dim=-1)
|
||||||
|
return super().projection(obs, actions, rewards, bootstrap, discount)
|
||||||
|
@ -365,6 +365,7 @@ class Critic(nn.Module):
|
|||||||
self.register_buffer(
|
self.register_buffer(
|
||||||
"q_support", torch.linspace(v_min, v_max, num_atoms, device=device)
|
"q_support", torch.linspace(v_min, v_max, num_atoms, device=device)
|
||||||
)
|
)
|
||||||
|
self.device = device
|
||||||
|
|
||||||
def forward(self, obs: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
|
def forward(self, obs: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
|
||||||
return self.qnet1(obs, actions), self.qnet2(obs, actions)
|
return self.qnet1(obs, actions), self.qnet2(obs, actions)
|
||||||
@ -449,8 +450,8 @@ class Actor(nn.Module):
|
|||||||
self.predictor = HyperTanhPolicy(
|
self.predictor = HyperTanhPolicy(
|
||||||
hidden_dim=hidden_dim,
|
hidden_dim=hidden_dim,
|
||||||
action_dim=n_act,
|
action_dim=n_act,
|
||||||
scaler_init=scaler_init,
|
scaler_init=1.0,
|
||||||
scaler_scale=scaler_scale,
|
scaler_scale=1.0,
|
||||||
device=device,
|
device=device,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -462,6 +463,7 @@ class Actor(nn.Module):
|
|||||||
self.register_buffer("std_min", torch.as_tensor(std_min, device=device))
|
self.register_buffer("std_min", torch.as_tensor(std_min, device=device))
|
||||||
self.register_buffer("std_max", torch.as_tensor(std_max, device=device))
|
self.register_buffer("std_max", torch.as_tensor(std_max, device=device))
|
||||||
self.n_envs = num_envs
|
self.n_envs = num_envs
|
||||||
|
self.device = device
|
||||||
|
|
||||||
def forward(self, obs: torch.Tensor) -> torch.Tensor:
|
def forward(self, obs: torch.Tensor) -> torch.Tensor:
|
||||||
x = obs
|
x = obs
|
||||||
@ -492,3 +494,51 @@ class Actor(nn.Module):
|
|||||||
|
|
||||||
noise = torch.randn_like(act) * self.noise_scales
|
noise = torch.randn_like(act) * self.noise_scales
|
||||||
return act + noise
|
return act + noise
|
||||||
|
|
||||||
|
|
||||||
|
class MultiTaskActor(Actor):
|
||||||
|
def __init__(self, num_tasks: int, task_embedding_dim: int, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.num_tasks = num_tasks
|
||||||
|
self.task_embedding_dim = task_embedding_dim
|
||||||
|
self.task_embedding = nn.Embedding(
|
||||||
|
num_tasks, task_embedding_dim, max_norm=1.0, device=self.device
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, obs: torch.Tensor) -> torch.Tensor:
|
||||||
|
task_ids_one_hot = obs[..., -self.num_tasks :]
|
||||||
|
task_indices = torch.argmax(task_ids_one_hot, dim=1)
|
||||||
|
task_embeddings = self.task_embedding(task_indices)
|
||||||
|
obs = torch.cat([obs[..., : -self.num_tasks], task_embeddings], dim=-1)
|
||||||
|
return super().forward(obs)
|
||||||
|
|
||||||
|
|
||||||
|
class MultiTaskCritic(Critic):
|
||||||
|
def __init__(self, num_tasks: int, task_embedding_dim: int, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.num_tasks = num_tasks
|
||||||
|
self.task_embedding_dim = task_embedding_dim
|
||||||
|
self.task_embedding = nn.Embedding(
|
||||||
|
num_tasks, task_embedding_dim, max_norm=1.0, device=self.device
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, obs: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
|
||||||
|
task_ids_one_hot = obs[..., -self.num_tasks :]
|
||||||
|
task_indices = torch.argmax(task_ids_one_hot, dim=1)
|
||||||
|
task_embeddings = self.task_embedding(task_indices)
|
||||||
|
obs = torch.cat([obs[..., : -self.num_tasks], task_embeddings], dim=-1)
|
||||||
|
return super().forward(obs, actions)
|
||||||
|
|
||||||
|
def projection(
|
||||||
|
self,
|
||||||
|
obs: torch.Tensor,
|
||||||
|
actions: torch.Tensor,
|
||||||
|
rewards: torch.Tensor,
|
||||||
|
bootstrap: torch.Tensor,
|
||||||
|
discount: float,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
task_ids_one_hot = obs[..., -self.num_tasks :]
|
||||||
|
task_indices = torch.argmax(task_ids_one_hot, dim=1)
|
||||||
|
task_embeddings = self.task_embedding(task_indices)
|
||||||
|
obs = torch.cat([obs[..., : -self.num_tasks], task_embeddings], dim=-1)
|
||||||
|
return super().projection(obs, actions, rewards, bootstrap, discount)
|
||||||
|
@ -512,6 +512,212 @@ class RewardNormalizer(nn.Module):
|
|||||||
return self._scale_reward(rewards)
|
return self._scale_reward(rewards)
|
||||||
|
|
||||||
|
|
||||||
|
class PerTaskEmpiricalNormalization(nn.Module):
|
||||||
|
"""Normalize mean and variance of values based on empirical values for each task."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
num_tasks: int,
|
||||||
|
shape: tuple,
|
||||||
|
device: torch.device,
|
||||||
|
eps: float = 1e-2,
|
||||||
|
until: int = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize PerTaskEmpiricalNormalization module.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
num_tasks (int): The total number of tasks.
|
||||||
|
shape (int or tuple of int): Shape of input values except batch axis.
|
||||||
|
eps (float): Small value for stability.
|
||||||
|
until (int or None): If specified, learns until the sum of batch sizes
|
||||||
|
for a specific task exceeds this value.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
if not isinstance(shape, tuple):
|
||||||
|
shape = (shape,)
|
||||||
|
self.num_tasks = num_tasks
|
||||||
|
self.shape = shape
|
||||||
|
self.eps = eps
|
||||||
|
self.until = until
|
||||||
|
self.device = device
|
||||||
|
|
||||||
|
# Buffers now have a leading dimension for tasks
|
||||||
|
self.register_buffer("_mean", torch.zeros(num_tasks, *shape).to(device))
|
||||||
|
self.register_buffer("_var", torch.ones(num_tasks, *shape).to(device))
|
||||||
|
self.register_buffer("_std", torch.ones(num_tasks, *shape).to(device))
|
||||||
|
self.register_buffer(
|
||||||
|
"count", torch.zeros(num_tasks, dtype=torch.long).to(device)
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self, x: torch.Tensor, task_ids: torch.Tensor, center: bool = True
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Normalize the input tensor `x` using statistics for the given `task_ids`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (torch.Tensor): Input tensor of shape [num_envs, *shape].
|
||||||
|
task_ids (torch.Tensor): Tensor of task indices, shape [num_envs].
|
||||||
|
center (bool): If True, center the data by subtracting the mean.
|
||||||
|
"""
|
||||||
|
if x.shape[1:] != self.shape:
|
||||||
|
raise ValueError(f"Expected input shape (*, {self.shape}), got {x.shape}")
|
||||||
|
if x.shape[0] != task_ids.shape[0]:
|
||||||
|
raise ValueError("Batch size of x and task_ids must match.")
|
||||||
|
|
||||||
|
# Gather the stats for the tasks in the current batch
|
||||||
|
# Reshape task_ids for broadcasting: [num_envs] -> [num_envs, 1, ...]
|
||||||
|
view_shape = (task_ids.shape[0],) + (1,) * len(self.shape)
|
||||||
|
task_ids_expanded = task_ids.view(view_shape).expand_as(x)
|
||||||
|
|
||||||
|
mean = self._mean.gather(0, task_ids_expanded)
|
||||||
|
std = self._std.gather(0, task_ids_expanded)
|
||||||
|
|
||||||
|
if self.training:
|
||||||
|
self.update(x, task_ids)
|
||||||
|
|
||||||
|
if center:
|
||||||
|
return (x - mean) / (std + self.eps)
|
||||||
|
else:
|
||||||
|
return x / (std + self.eps)
|
||||||
|
|
||||||
|
@torch.jit.unused
|
||||||
|
def update(self, x: torch.Tensor, task_ids: torch.Tensor):
|
||||||
|
"""Update running statistics for the tasks present in the batch."""
|
||||||
|
unique_tasks = torch.unique(task_ids)
|
||||||
|
|
||||||
|
for task_id in unique_tasks:
|
||||||
|
if self.until is not None and self.count[task_id] >= self.until:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Create a mask to select data for the current task
|
||||||
|
mask = task_ids == task_id
|
||||||
|
x_task = x[mask]
|
||||||
|
batch_size = x_task.shape[0]
|
||||||
|
|
||||||
|
if batch_size == 0:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Update count for this task
|
||||||
|
old_count = self.count[task_id].clone()
|
||||||
|
new_count = old_count + batch_size
|
||||||
|
|
||||||
|
# Update mean
|
||||||
|
task_mean = self._mean[task_id]
|
||||||
|
batch_mean = torch.mean(x_task, dim=0)
|
||||||
|
delta = batch_mean - task_mean
|
||||||
|
self._mean[task_id] = task_mean + (batch_size / new_count) * delta
|
||||||
|
|
||||||
|
# Update variance using Chan's parallel algorithm
|
||||||
|
if old_count > 0:
|
||||||
|
batch_var = torch.var(x_task, dim=0, unbiased=False)
|
||||||
|
m_a = self._var[task_id] * old_count
|
||||||
|
m_b = batch_var * batch_size
|
||||||
|
M2 = m_a + m_b + (delta**2) * (old_count * batch_size / new_count)
|
||||||
|
self._var[task_id] = M2 / new_count
|
||||||
|
else:
|
||||||
|
# For the first batch of this task
|
||||||
|
self._var[task_id] = torch.var(x_task, dim=0, unbiased=False)
|
||||||
|
|
||||||
|
self._std[task_id] = torch.sqrt(self._var[task_id])
|
||||||
|
self.count[task_id] = new_count
|
||||||
|
|
||||||
|
|
||||||
|
class PerTaskRewardNormalizer(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
num_tasks: int,
|
||||||
|
gamma: float,
|
||||||
|
device: torch.device,
|
||||||
|
g_max: float = 10.0,
|
||||||
|
epsilon: float = 1e-8,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Per-task reward normalizer, motivation comes from BRC (https://arxiv.org/abs/2505.23150v1)
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self.num_tasks = num_tasks
|
||||||
|
self.gamma = gamma
|
||||||
|
self.g_max = g_max
|
||||||
|
self.epsilon = epsilon
|
||||||
|
self.device = device
|
||||||
|
|
||||||
|
# Per-task running estimate of the discounted return
|
||||||
|
self.register_buffer("G", torch.zeros(num_tasks, device=device))
|
||||||
|
# Per-task running-max of the discounted return
|
||||||
|
self.register_buffer("G_r_max", torch.zeros(num_tasks, device=device))
|
||||||
|
# Use the new per-task normalizer for the statistics of G
|
||||||
|
self.G_rms = PerTaskEmpiricalNormalization(
|
||||||
|
num_tasks=num_tasks, shape=(1,), device=device
|
||||||
|
)
|
||||||
|
|
||||||
|
def _scale_reward(
|
||||||
|
self, rewards: torch.Tensor, task_ids: torch.Tensor
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Scales rewards using per-task statistics.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
rewards (torch.Tensor): Reward tensor, shape [num_envs].
|
||||||
|
task_ids (torch.Tensor): Task indices, shape [num_envs].
|
||||||
|
"""
|
||||||
|
# Gather stats for the tasks in the batch
|
||||||
|
std_for_batch = self.G_rms._std.gather(0, task_ids.unsqueeze(-1)).squeeze(-1)
|
||||||
|
g_r_max_for_batch = self.G_r_max.gather(0, task_ids)
|
||||||
|
|
||||||
|
var_denominator = std_for_batch + self.epsilon
|
||||||
|
min_required_denominator = g_r_max_for_batch / self.g_max
|
||||||
|
denominator = torch.maximum(var_denominator, min_required_denominator)
|
||||||
|
|
||||||
|
# Add a small epsilon to the final denominator to prevent division by zero
|
||||||
|
# in case g_r_max is also zero.
|
||||||
|
return rewards / (denominator + self.epsilon)
|
||||||
|
|
||||||
|
def update_stats(
|
||||||
|
self, rewards: torch.Tensor, dones: torch.Tensor, task_ids: torch.Tensor
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Updates the running discounted return and its statistics for each task.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
rewards (torch.Tensor): Reward tensor, shape [num_envs].
|
||||||
|
dones (torch.Tensor): Done tensor, shape [num_envs].
|
||||||
|
task_ids (torch.Tensor): Task indices, shape [num_envs].
|
||||||
|
"""
|
||||||
|
if not (rewards.shape == dones.shape == task_ids.shape):
|
||||||
|
raise ValueError("rewards, dones, and task_ids must have the same shape.")
|
||||||
|
|
||||||
|
# === Update G (running discounted return) ===
|
||||||
|
# Gather the previous G values for the tasks in the batch
|
||||||
|
prev_G = self.G.gather(0, task_ids)
|
||||||
|
# Update G for each environment based on its own reward and done signal
|
||||||
|
new_G = self.gamma * (1 - dones.float()) * prev_G + rewards
|
||||||
|
# Scatter the updated G values back to the main buffer
|
||||||
|
self.G.scatter_(0, task_ids, new_G)
|
||||||
|
|
||||||
|
# === Update G_rms (statistics of G) ===
|
||||||
|
# The update function handles the per-task logic internally
|
||||||
|
self.G_rms.update(new_G.unsqueeze(-1), task_ids)
|
||||||
|
|
||||||
|
# === Update G_r_max (running max of |G|) ===
|
||||||
|
prev_G_r_max = self.G_r_max.gather(0, task_ids)
|
||||||
|
# Update the max for each environment
|
||||||
|
updated_G_r_max = torch.maximum(prev_G_r_max, torch.abs(new_G))
|
||||||
|
# Scatter the new maxes back to the main buffer
|
||||||
|
self.G_r_max.scatter_(0, task_ids, updated_G_r_max)
|
||||||
|
|
||||||
|
def forward(self, rewards: torch.Tensor, task_ids: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Normalizes rewards. During training, it also updates the running statistics.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
rewards (torch.Tensor): Reward tensor, shape [num_envs].
|
||||||
|
task_ids (torch.Tensor): Task indices, shape [num_envs].
|
||||||
|
"""
|
||||||
|
return self._scale_reward(rewards, task_ids)
|
||||||
|
|
||||||
|
|
||||||
def cpu_state(sd):
|
def cpu_state(sd):
|
||||||
# detach & move to host without locking the compute stream
|
# detach & move to host without locking the compute stream
|
||||||
return {k: v.detach().to("cpu", non_blocking=True) for k, v in sd.items()}
|
return {k: v.detach().to("cpu", non_blocking=True) for k, v in sd.items()}
|
||||||
|
@ -95,7 +95,7 @@ class BaseArgs:
|
|||||||
obs_normalization: bool = True
|
obs_normalization: bool = True
|
||||||
"""whether to enable observation normalization"""
|
"""whether to enable observation normalization"""
|
||||||
reward_normalization: bool = False
|
reward_normalization: bool = False
|
||||||
"""whether to enable reward normalization (Not recommended for now, it's unstable.)"""
|
"""whether to enable reward normalization"""
|
||||||
max_grad_norm: float = 0.0
|
max_grad_norm: float = 0.0
|
||||||
"""the maximum gradient norm"""
|
"""the maximum gradient norm"""
|
||||||
amp: bool = True
|
amp: bool = True
|
||||||
@ -113,6 +113,8 @@ class BaseArgs:
|
|||||||
"""(Playground only) Use tuned reward for G1"""
|
"""(Playground only) Use tuned reward for G1"""
|
||||||
action_bounds: float = 1.0
|
action_bounds: float = 1.0
|
||||||
"""(IsaacLab only) the bounds of the action space (-action_bounds, action_bounds)"""
|
"""(IsaacLab only) the bounds of the action space (-action_bounds, action_bounds)"""
|
||||||
|
task_embedding_dim: int = 32
|
||||||
|
"""the dimension of the task embedding"""
|
||||||
|
|
||||||
weight_decay: float = 0.1
|
weight_decay: float = 0.1
|
||||||
"""the weight decay of the optimizer"""
|
"""the weight decay of the optimizer"""
|
||||||
@ -169,6 +171,9 @@ def get_args():
|
|||||||
"Isaac-Velocity-Rough-G1-v0": IsaacVelocityRoughG1Args,
|
"Isaac-Velocity-Rough-G1-v0": IsaacVelocityRoughG1Args,
|
||||||
"Isaac-Repose-Cube-Allegro-Direct-v0": IsaacReposeCubeAllegroDirectArgs,
|
"Isaac-Repose-Cube-Allegro-Direct-v0": IsaacReposeCubeAllegroDirectArgs,
|
||||||
"Isaac-Repose-Cube-Shadow-Direct-v0": IsaacReposeCubeShadowDirectArgs,
|
"Isaac-Repose-Cube-Shadow-Direct-v0": IsaacReposeCubeShadowDirectArgs,
|
||||||
|
# MTBench
|
||||||
|
"MTBench-meta-world-v2-mt10": MetaWorldMT10Args,
|
||||||
|
"MTBench-meta-world-v2-mt50": MetaWorldMT50Args,
|
||||||
}
|
}
|
||||||
# If the provided env_name has a specific Args class, use it
|
# If the provided env_name has a specific Args class, use it
|
||||||
if base_args.env_name in env_to_args_class:
|
if base_args.env_name in env_to_args_class:
|
||||||
@ -183,6 +188,9 @@ def get_args():
|
|||||||
elif base_args.env_name.startswith("Isaac-"):
|
elif base_args.env_name.startswith("Isaac-"):
|
||||||
# IsaacLab
|
# IsaacLab
|
||||||
specific_args = tyro.cli(IsaacLabArgs)
|
specific_args = tyro.cli(IsaacLabArgs)
|
||||||
|
elif base_args.env_name.startswith("MTBench-"):
|
||||||
|
# MTBench
|
||||||
|
specific_args = tyro.cli(MTBenchArgs)
|
||||||
else:
|
else:
|
||||||
# MuJoCo Playground
|
# MuJoCo Playground
|
||||||
specific_args = tyro.cli(MuJoCoPlaygroundArgs)
|
specific_args = tyro.cli(MuJoCoPlaygroundArgs)
|
||||||
@ -280,6 +288,38 @@ class MuJoCoPlaygroundArgs(BaseArgs):
|
|||||||
gamma: float = 0.97
|
gamma: float = 0.97
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MTBenchArgs(BaseArgs):
|
||||||
|
# Default hyperparameters for MTBench
|
||||||
|
reward_normalization: bool = True
|
||||||
|
v_min: float = -10.0
|
||||||
|
v_max: float = 10.0
|
||||||
|
buffer_size: int = 2048 # 2K is usually enough for MTBench
|
||||||
|
num_envs: int = 4096
|
||||||
|
num_eval_envs: int = 4096
|
||||||
|
gamma: float = 0.99
|
||||||
|
num_steps: int = 8
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MetaWorldMT10Args(MTBenchArgs):
|
||||||
|
# This config achieves 97 ~ 98% success rate within 10k steps (15-20 mins on A100)
|
||||||
|
env_name: str = "MTBench-meta-world-v2-mt10"
|
||||||
|
num_envs: int = 4096
|
||||||
|
num_eval_envs: int = 4096
|
||||||
|
num_steps: int = 8
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MetaWorldMT50Args(MTBenchArgs):
|
||||||
|
# FastTD3 + SimbaV2 achieves >90% success rate within 20k steps (80 mins on A100)
|
||||||
|
# Performance further improves with more training steps, slowly.
|
||||||
|
env_name: str = "MTBench-meta-world-v2-mt50"
|
||||||
|
num_envs: int = 8192
|
||||||
|
num_eval_envs: int = 8192
|
||||||
|
num_steps: int = 8
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class G1JoystickFlatTerrainArgs(MuJoCoPlaygroundArgs):
|
class G1JoystickFlatTerrainArgs(MuJoCoPlaygroundArgs):
|
||||||
env_name: str = "G1JoystickFlatTerrain"
|
env_name: str = "G1JoystickFlatTerrain"
|
||||||
|
@ -18,17 +18,24 @@ import tqdm
|
|||||||
import wandb
|
import wandb
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Required for avoiding IsaacGym import error
|
||||||
|
import isaacgym
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import torch.optim as optim
|
import torch.optim as optim
|
||||||
from torch.amp import autocast, GradScaler
|
from torch.amp import autocast, GradScaler
|
||||||
|
|
||||||
from tensordict import TensorDict, from_module
|
from tensordict import TensorDict
|
||||||
|
|
||||||
from fast_td3_utils import (
|
from fast_td3_utils import (
|
||||||
EmpiricalNormalization,
|
EmpiricalNormalization,
|
||||||
RewardNormalizer,
|
RewardNormalizer,
|
||||||
|
PerTaskRewardNormalizer,
|
||||||
SimpleReplayBuffer,
|
SimpleReplayBuffer,
|
||||||
save_params,
|
save_params,
|
||||||
)
|
)
|
||||||
@ -103,6 +110,14 @@ def main():
|
|||||||
)
|
)
|
||||||
eval_envs = envs
|
eval_envs = envs
|
||||||
render_env = envs
|
render_env = envs
|
||||||
|
elif args.env_name.startswith("MTBench-"):
|
||||||
|
from environments.mtbench_env import MTBenchEnv
|
||||||
|
|
||||||
|
env_name = "-".join(args.env_name.split("-")[1:])
|
||||||
|
env_type = "mtbench"
|
||||||
|
envs = MTBenchEnv(env_name, args.device_rank, args.num_envs, args.seed)
|
||||||
|
eval_envs = envs
|
||||||
|
render_env = envs
|
||||||
else:
|
else:
|
||||||
from environments.mujoco_playground_env import make_env
|
from environments.mujoco_playground_env import make_env
|
||||||
|
|
||||||
@ -141,9 +156,19 @@ def main():
|
|||||||
critic_obs_normalizer = nn.Identity()
|
critic_obs_normalizer = nn.Identity()
|
||||||
|
|
||||||
if args.reward_normalization:
|
if args.reward_normalization:
|
||||||
reward_normalizer = RewardNormalizer(
|
if env_type in ["mtbench"]:
|
||||||
gamma=args.gamma, device=device, g_max=min(abs(args.v_min), abs(args.v_max))
|
reward_normalizer = PerTaskRewardNormalizer(
|
||||||
)
|
num_tasks=envs.num_tasks,
|
||||||
|
gamma=args.gamma,
|
||||||
|
device=device,
|
||||||
|
g_max=min(abs(args.v_min), abs(args.v_max)),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
reward_normalizer = RewardNormalizer(
|
||||||
|
gamma=args.gamma,
|
||||||
|
device=device,
|
||||||
|
g_max=min(abs(args.v_min), abs(args.v_max)),
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
reward_normalizer = nn.Identity()
|
reward_normalizer = nn.Identity()
|
||||||
|
|
||||||
@ -165,12 +190,37 @@ def main():
|
|||||||
"device": device,
|
"device": device,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if env_type == "mtbench":
|
||||||
|
actor_kwargs["n_obs"] = n_obs - envs.num_tasks + args.task_embedding_dim
|
||||||
|
critic_kwargs["n_obs"] = n_critic_obs - envs.num_tasks + args.task_embedding_dim
|
||||||
|
actor_kwargs["num_tasks"] = envs.num_tasks
|
||||||
|
actor_kwargs["task_embedding_dim"] = args.task_embedding_dim
|
||||||
|
critic_kwargs["num_tasks"] = envs.num_tasks
|
||||||
|
critic_kwargs["task_embedding_dim"] = args.task_embedding_dim
|
||||||
|
|
||||||
if args.agent == "fasttd3":
|
if args.agent == "fasttd3":
|
||||||
from fast_td3 import Actor, Critic
|
if env_type in ["mtbench"]:
|
||||||
|
from fast_td3 import MultiTaskActor, MultiTaskCritic
|
||||||
|
|
||||||
|
actor_cls = MultiTaskActor
|
||||||
|
critic_cls = MultiTaskCritic
|
||||||
|
else:
|
||||||
|
from fast_td3 import Actor, Critic
|
||||||
|
|
||||||
|
actor_cls = Actor
|
||||||
|
critic_cls = Critic
|
||||||
|
|
||||||
print("Using FastTD3")
|
print("Using FastTD3")
|
||||||
elif args.agent == "fasttd3_simbav2":
|
elif args.agent == "fasttd3_simbav2":
|
||||||
from fast_td3_simbav2 import Actor, Critic
|
if env_type in ["mtbench"]:
|
||||||
|
from fast_td3_simbav2 import MultiTaskActor, MultiTaskCritic
|
||||||
|
|
||||||
|
actor_cls = MultiTaskActor
|
||||||
|
critic_cls = MultiTaskCritic
|
||||||
|
else:
|
||||||
|
from fast_td3_simbav2 import Actor, Critic
|
||||||
|
|
||||||
|
actor_cls = Actor
|
||||||
|
|
||||||
print("Using FastTD3 + SimbaV2")
|
print("Using FastTD3 + SimbaV2")
|
||||||
actor_kwargs.pop("init_scale")
|
actor_kwargs.pop("init_scale")
|
||||||
@ -199,25 +249,31 @@ def main():
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"Agent {args.agent} not supported")
|
raise ValueError(f"Agent {args.agent} not supported")
|
||||||
|
|
||||||
actor = Actor(**actor_kwargs)
|
actor = actor_cls(**actor_kwargs)
|
||||||
actor_detach = Actor(**actor_kwargs)
|
|
||||||
|
|
||||||
# Copy params to actor_detach without grad
|
if env_type in ["mtbench"]:
|
||||||
from_module(actor).data.to_module(actor_detach)
|
# Python 3.8 doesn't support 'from_module' in tensordict
|
||||||
policy = actor_detach.explore
|
policy = actor.explore
|
||||||
|
else:
|
||||||
|
from tensordict import from_module
|
||||||
|
|
||||||
qnet = Critic(**critic_kwargs)
|
actor_detach = actor_cls(**actor_kwargs)
|
||||||
qnet_target = Critic(**critic_kwargs)
|
# Copy params to actor_detach without grad
|
||||||
|
from_module(actor).data.to_module(actor_detach)
|
||||||
|
policy = actor_detach.explore
|
||||||
|
|
||||||
|
qnet = critic_cls(**critic_kwargs)
|
||||||
|
qnet_target = critic_cls(**critic_kwargs)
|
||||||
qnet_target.load_state_dict(qnet.state_dict())
|
qnet_target.load_state_dict(qnet.state_dict())
|
||||||
|
|
||||||
q_optimizer = optim.AdamW(
|
q_optimizer = optim.AdamW(
|
||||||
list(qnet.parameters()),
|
list(qnet.parameters()),
|
||||||
lr=args.critic_learning_rate,
|
lr=torch.tensor(args.critic_learning_rate, device=device),
|
||||||
weight_decay=args.weight_decay,
|
weight_decay=args.weight_decay,
|
||||||
)
|
)
|
||||||
actor_optimizer = optim.AdamW(
|
actor_optimizer = optim.AdamW(
|
||||||
list(actor.parameters()),
|
list(actor.parameters()),
|
||||||
lr=args.actor_learning_rate,
|
lr=torch.tensor(args.actor_learning_rate, device=device),
|
||||||
weight_decay=args.weight_decay,
|
weight_decay=args.weight_decay,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -225,12 +281,12 @@ def main():
|
|||||||
q_scheduler = optim.lr_scheduler.CosineAnnealingLR(
|
q_scheduler = optim.lr_scheduler.CosineAnnealingLR(
|
||||||
q_optimizer,
|
q_optimizer,
|
||||||
T_max=args.total_timesteps,
|
T_max=args.total_timesteps,
|
||||||
eta_min=args.critic_learning_rate_end, # Decay to 10% of initial lr
|
eta_min=torch.tensor(args.critic_learning_rate_end, device=device),
|
||||||
)
|
)
|
||||||
actor_scheduler = optim.lr_scheduler.CosineAnnealingLR(
|
actor_scheduler = optim.lr_scheduler.CosineAnnealingLR(
|
||||||
actor_optimizer,
|
actor_optimizer,
|
||||||
T_max=args.total_timesteps,
|
T_max=args.total_timesteps,
|
||||||
eta_min=args.actor_learning_rate_end, # Decay to 10% of initial lr
|
eta_min=torch.tensor(args.actor_learning_rate_end, device=device),
|
||||||
)
|
)
|
||||||
|
|
||||||
rb = SimpleReplayBuffer(
|
rb = SimpleReplayBuffer(
|
||||||
@ -262,20 +318,28 @@ def main():
|
|||||||
obs = eval_envs.reset()
|
obs = eval_envs.reset()
|
||||||
|
|
||||||
# Run for a fixed number of steps
|
# Run for a fixed number of steps
|
||||||
for _ in range(eval_envs.max_episode_steps):
|
for i in range(eval_envs.max_episode_steps):
|
||||||
with torch.no_grad(), autocast(
|
with torch.no_grad(), autocast(
|
||||||
device_type=amp_device_type, dtype=amp_dtype, enabled=amp_enabled
|
device_type=amp_device_type, dtype=amp_dtype, enabled=amp_enabled
|
||||||
):
|
):
|
||||||
obs = normalize_obs(obs)
|
obs = normalize_obs(obs)
|
||||||
actions = actor(obs)
|
actions = actor(obs)
|
||||||
|
|
||||||
next_obs, rewards, dones, _ = eval_envs.step(actions.float())
|
next_obs, rewards, dones, infos = eval_envs.step(actions.float())
|
||||||
|
|
||||||
|
if env_type == "mtbench":
|
||||||
|
# We only report success rate in MTBench evaluation
|
||||||
|
rewards = (
|
||||||
|
infos["episode"]["success"].float() if "episode" in infos else 0.0
|
||||||
|
)
|
||||||
episode_returns = torch.where(
|
episode_returns = torch.where(
|
||||||
~done_masks, episode_returns + rewards, episode_returns
|
~done_masks, episode_returns + rewards, episode_returns
|
||||||
)
|
)
|
||||||
episode_lengths = torch.where(
|
episode_lengths = torch.where(
|
||||||
~done_masks, episode_lengths + 1, episode_lengths
|
~done_masks, episode_lengths + 1, episode_lengths
|
||||||
)
|
)
|
||||||
|
if env_type == "mtbench" and "episode" in infos:
|
||||||
|
dones = dones | infos["episode"]["success"]
|
||||||
done_masks = torch.logical_or(done_masks, dones)
|
done_masks = torch.logical_or(done_masks, dones)
|
||||||
if done_masks.all():
|
if done_masks.all():
|
||||||
break
|
break
|
||||||
@ -291,9 +355,9 @@ def main():
|
|||||||
if env_type == "humanoid_bench":
|
if env_type == "humanoid_bench":
|
||||||
obs = render_env.reset()
|
obs = render_env.reset()
|
||||||
renders = [render_env.render()]
|
renders = [render_env.render()]
|
||||||
elif env_type == "isaaclab":
|
elif env_type in ["isaaclab", "mtbench"]:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"We don't support rendering for IsaacLab environments"
|
"We don't support rendering for IsaacLab and MTBench environments"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
obs = render_env.reset()
|
obs = render_env.reset()
|
||||||
@ -399,6 +463,7 @@ def main():
|
|||||||
)
|
)
|
||||||
scaler.step(q_optimizer)
|
scaler.step(q_optimizer)
|
||||||
scaler.update()
|
scaler.update()
|
||||||
|
q_scheduler.step()
|
||||||
|
|
||||||
logs_dict["critic_grad_norm"] = critic_grad_norm.detach()
|
logs_dict["critic_grad_norm"] = critic_grad_norm.detach()
|
||||||
logs_dict["qf_loss"] = qf_loss.detach()
|
logs_dict["qf_loss"] = qf_loss.detach()
|
||||||
@ -434,6 +499,7 @@ def main():
|
|||||||
)
|
)
|
||||||
scaler.step(actor_optimizer)
|
scaler.step(actor_optimizer)
|
||||||
scaler.update()
|
scaler.update()
|
||||||
|
actor_scheduler.step()
|
||||||
logs_dict["actor_grad_norm"] = actor_grad_norm.detach()
|
logs_dict["actor_grad_norm"] = actor_grad_norm.detach()
|
||||||
logs_dict["actor_loss"] = actor_loss.detach()
|
logs_dict["actor_loss"] = actor_loss.detach()
|
||||||
return logs_dict
|
return logs_dict
|
||||||
@ -500,7 +566,12 @@ def main():
|
|||||||
truncations = infos["time_outs"]
|
truncations = infos["time_outs"]
|
||||||
|
|
||||||
if args.reward_normalization:
|
if args.reward_normalization:
|
||||||
update_stats(rewards, dones.float())
|
if env_type == "mtbench":
|
||||||
|
task_ids_one_hot = obs[..., -envs.num_tasks :]
|
||||||
|
task_indices = torch.argmax(task_ids_one_hot, dim=1)
|
||||||
|
update_stats(rewards, dones.float(), task_ids=task_indices)
|
||||||
|
else:
|
||||||
|
update_stats(rewards, dones.float())
|
||||||
|
|
||||||
if envs.asymmetric_obs:
|
if envs.asymmetric_obs:
|
||||||
next_critic_obs = infos["observations"]["critic"]
|
next_critic_obs = infos["observations"]["critic"]
|
||||||
@ -550,7 +621,15 @@ def main():
|
|||||||
data["next"]["observations"]
|
data["next"]["observations"]
|
||||||
)
|
)
|
||||||
raw_rewards = data["next"]["rewards"]
|
raw_rewards = data["next"]["rewards"]
|
||||||
data["next"]["rewards"] = normalize_reward(raw_rewards)
|
if env_type in ["mtbench"] and args.reward_normalization:
|
||||||
|
# Multi-task reward normalization
|
||||||
|
task_ids_one_hot = data["observations"][..., -envs.num_tasks :]
|
||||||
|
task_indices = torch.argmax(task_ids_one_hot, dim=1)
|
||||||
|
data["next"]["rewards"] = normalize_reward(
|
||||||
|
raw_rewards, task_ids=task_indices
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
data["next"]["rewards"] = normalize_reward(raw_rewards)
|
||||||
if envs.asymmetric_obs:
|
if envs.asymmetric_obs:
|
||||||
data["critic_observations"] = normalize_critic_obs(
|
data["critic_observations"] = normalize_critic_obs(
|
||||||
data["critic_observations"]
|
data["critic_observations"]
|
||||||
@ -591,7 +670,7 @@ def main():
|
|||||||
if args.eval_interval > 0 and global_step % args.eval_interval == 0:
|
if args.eval_interval > 0 and global_step % args.eval_interval == 0:
|
||||||
print(f"Evaluating at global step {global_step}")
|
print(f"Evaluating at global step {global_step}")
|
||||||
eval_avg_return, eval_avg_length = evaluate()
|
eval_avg_return, eval_avg_length = evaluate()
|
||||||
if env_type in ["humanoid_bench", "isaaclab"]:
|
if env_type in ["humanoid_bench", "isaaclab", "mtbench"]:
|
||||||
# NOTE: Hacky way of evaluating performance, but just works
|
# NOTE: Hacky way of evaluating performance, but just works
|
||||||
obs = envs.reset()
|
obs = envs.reset()
|
||||||
logs["eval_avg_return"] = eval_avg_return
|
logs["eval_avg_return"] = eval_avg_return
|
||||||
@ -644,10 +723,6 @@ def main():
|
|||||||
f"models/{run_name}_{global_step}.pt",
|
f"models/{run_name}_{global_step}.pt",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Update learning rates
|
|
||||||
q_scheduler.step()
|
|
||||||
actor_scheduler.step()
|
|
||||||
|
|
||||||
global_step += 1
|
global_step += 1
|
||||||
pbar.update(1)
|
pbar.update(1)
|
||||||
|
|
||||||
|
15
requirements/requirements_isaacgym.txt
Normal file
15
requirements/requirements_isaacgym.txt
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
gymnasium<1.0.0
|
||||||
|
jax-jumpy==1.0.0 ; python_version >= "3.8" and python_version < "3.11"
|
||||||
|
matplotlib
|
||||||
|
moviepy
|
||||||
|
numpy<2.0
|
||||||
|
pandas
|
||||||
|
protobuf
|
||||||
|
pygame
|
||||||
|
stable-baselines3
|
||||||
|
tqdm
|
||||||
|
wandb
|
||||||
|
torchrl==0.5.0
|
||||||
|
tensordict==0.5.0
|
||||||
|
tyro
|
||||||
|
loguru
|
Loading…
Reference in New Issue
Block a user