Initial Public Release

This commit is contained in:
Younggyo Seo 2025-05-29 01:49:23 +00:00
commit 258bfe67dd
18 changed files with 3608 additions and 0 deletions

9
.gitignore vendored Normal file
View File

@ -0,0 +1,9 @@
media
models
wandb
figures
visualize.ipynb
record.ipynb
*.pyc
.ipynb_checkpoints
fast_td3.egg-info/

5
.pre-commit-config.yaml Normal file
View File

@ -0,0 +1,5 @@
repos:
- repo: https://github.com/psf/black
rev: stable
hooks:
- id: black

315
LICENSE Normal file
View File

@ -0,0 +1,315 @@
MIT License
Copyright (c) 2025 Younggyo Seo
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
MIT License
Copyright (c) 2024 LeanRL developers
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
Code in `cleanrl/ddpg_continuous_action.py` and `cleanrl/td3_continuous_action.py` are adapted from https://github.com/sfujim/TD3
MIT License
Copyright (c) 2020 Scott Fujimoto
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
Code in `cleanrl/sac_continuous_action.py` is inspired and adapted from [haarnoja/sac](https://github.com/haarnoja/sac), [openai/spinningup](https://github.com/openai/spinningup), [pranz24/pytorch-soft-actor-critic](https://github.com/pranz24/pytorch-soft-actor-critic), [DLR-RM/stable-baselines3](https://github.com/DLR-RM/stable-baselines3), and [denisyarats/pytorch_sac](https://github.com/denisyarats/pytorch_sac).
- [haarnoja/sac](https://github.com/haarnoja/sac/blob/8258e33633c7e37833cc39315891e77adfbe14b2/LICENSE.txt)
COPYRIGHT
All contributions by the University of California:
Copyright (c) 2017, 2018 The Regents of the University of California (Regents)
All rights reserved.
All other contributions:
Copyright (c) 2017, 2018, the respective contributors
All rights reserved.
SAC uses a shared copyright model: each contributor holds copyright over
their contributions to the SAC codebase. The project versioning records all such
contribution and copyright details. If a contributor wants to further mark
their specific copyright on a particular contribution, they should indicate
their copyright solely in the commit message of the change when it is
committed.
LICENSE
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
CONTRIBUTION AGREEMENT
By contributing to the SAC repository through pull-request, comment,
or otherwise, the contributor releases their content to the
license and copyright terms herein.
- [openai/spinningup](https://github.com/openai/spinningup/blob/038665d62d569055401d91856abb287263096178/LICENSE)
The MIT License
Copyright (c) 2018 OpenAI (http://openai.com)
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
- [DLR-RM/stable-baselines3](https://github.com/DLR-RM/stable-baselines3/blob/44e53ff8115e8f4bff1d5218f10c8c7d1a4cfc12/LICENSE)
The MIT License
Copyright (c) 2019 Antonin Raffin
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
- [denisyarats/pytorch_sac](https://github.com/denisyarats/pytorch_sac/blob/81c5b536d3a1c5616b2531e446450df412a064fb/LICENSE)
MIT License
Copyright (c) 2019 Denis Yarats
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
- [pranz24/pytorch-soft-actor-critic](https://github.com/pranz24/pytorch-soft-actor-critic/blob/master/LICENSE)
MIT License
Copyright (c) 2018 Pranjal Tandon
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
---------------------------------------------------------------------------------
The CONTRIBUTING.md is adopted from https://github.com/entity-neural-network/incubator/blob/2a0c38b30828df78c47b0318c76a4905020618dd/CONTRIBUTING.md
and https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/blob/master/CONTRIBUTING.md
MIT License
Copyright (c) 2021 Entity Neural Network developers
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
MIT License
Copyright (c) 2020 Stable-Baselines Team
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
---------------------------------------------------------------------------------
The cleanrl/ppo_continuous_action_isaacgym.py is contributed by Nvidia
SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: MIT
Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation
the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
--------------------------------------------------------------------------------
Code in `cleanrl/qdagger_dqn_atari_impalacnn.py` and `cleanrl/qdagger_dqn_atari_jax_impalacnn.py` are adapted from https://github.com/google-research/reincarnating_rl
**NOTE: the original repo did not fill out the copyright section in their license
so the following copyright notice is copied as is per the license requirement.
See https://github.com/google-research/reincarnating_rl/blob/a1d402f48a9f8658ca6aa0ddf416ab391745ff2c/LICENSE#L189
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

238
README.md Normal file
View File

@ -0,0 +1,238 @@
# FastTD3 - Simple and Fast RL for Humanoid Control
[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
[![arXiv](https://img.shields.io/badge/arXiv-2505.22642-b31b1b.svg)](https://arxiv.org/abs/2505.22642)
FastTD3 is a high-performance variant of the Twin Delayed Deep Deterministic Policy Gradient (TD3) algorithm, optimized for complex humanoid control tasks. FastTD3 can solve various humanoid control tasks with dexterous hands from HumanoidBench in just a few hours of training. Furthermore, FastTD3 achieves similar or better wall-time-efficiency to PPO in high-dimensional control tasks from popular simulations such as IsaacLab and MuJoCo Playground.
For more information, please see our [project webpage](https://younggyo.me/fast_td3)
## ✨ Features
FastTD3 offers researchers a significant speedup in training complex humanoid agents.
- Ready-to-go codebase with detailed instructions and pre-configured hyperparameters for each task
- Support popular benchmarks: HumanoidBench, MuJoCo Playground, and IsaacLab
- User-friendly features that can accelerate your research, such as rendering rollouts, torch optimizations (AMP and compile), and saving and loading checkpoints
## ⚙️ Prerequisites
Before you begin, ensure you have the following installed:
- Conda (for environment management)
- Git LFS (Large File Storage) -- For IsaacLab
- CMake -- For IsaacLab
And the following system packages:
```bash
sudo apt install libglfw3 libgl1-mesa-glx libosmesa6 git-lfs cmake
```
## 📖 Installation
This project requires different Conda environments for different sets of experiments.
### Common Setup
First, ensure the common dependencies are installed as mentioned in the [Prerequisites](#prerequisites) section.
### Environment for HumanoidBench
```bash
conda create -n fasttd3_hb -y python=3.10
conda activate fasttd3_hb
pip install --editable git+https://github.com/carlosferrazza/humanoid-bench.git#egg=humanoid-bench
pip install -r requirements/requirements.txt
```
### Environment for MuJoCo Playground
```bash
conda create -n fasttd3_playground -y python=3.10
conda activate fasttd3_playground
pip install -r requirements/requirements_playground.txt
```
**⚠️ Note:** Our `requirements_playground.txt` specifies `Jax==0.4.35`, which we found to be stable for latest GPUs in certain tasks such as `LeapCubeReorient` or `LeapCubeRotateZAxis`
**⚠️ Note:** Current FastTD3 codebase uses customized MuJoCo Playground that supports saving last observations into info dictionary. We will work on incorporating this change into official repository hopefully soon.
### Environment for IsaacLab
```bash
conda create -n fasttd3_isaaclab -y python=3.10
conda activate fasttd3_isaaclab
# Install IsaacLab (refer to official documentation for the latest steps)
# Official Quickstart: https://isaac-sim.github.io/IsaacLab/main/source/setup/quickstart.html
pip install 'isaacsim[all,extscache]==4.5.0' --extra-index-url https://pypi.nvidia.com
git clone https://github.com/isaac-sim/IsaacLab.git
cd IsaacLab
./isaaclab.sh --install
cd ..
# Install project-specific requirements
pip install -r requirements/requirements.txt
```
### (Optional) Accelerate headless GPU rendering in cloud instances
In some cloud VM images the NVIDIA kernel driver is present but the user-space OpenGL/EGL/Vulkan libraries aren't, so MuJoCo falls back to CPU renderer. You can install just the NVIDIA user-space libraries (and skip rebuilding the kernel module) with:
```bash
sudo apt install -y kmod
sudo sh NVIDIA-Linux-x86_64-<your_driver_version>.run -s --no-kernel-module --ui=none --no-questions
```
As a rule-of-thumb, if you're running experiments and rendering is taking longer than 5 seconds, it is very likely that GPU renderer is not used.
## 🚀 Running Experiments
Activate the appropriate Conda environment before running experiments.
Please see `fast_td3/hyperparams.py` for information regarding hyperparameters!
### HumanoidBench Experiments
```bash
conda activate fasttd3_hb
python fast_td3/train.py --env_name h1hand-hurdle-v0 --exp_name FastTD3 --render_interval 5000 --seed 1
```
### MuJoCo Playground Experiments
```bash
conda activate fasttd3_playground
python fast_td3/train.py --env_name G1JoystickRoughTerrain --exp_name FastTD3 --render_interval 5000 --seed 1
```
### IsaacLab Experiments
```bash
conda activate fasttd3_isaaclab
python fast_td3/train.py --env_name Isaac-Velocity-Flat-G1-v0 --exp_name FastTD3 --render_interval 0 --seed 1
python fast_td3/train.py --env_name Isaac-Repose-Cube-Allegro-Direct-v0 --exp_name FastTD3 --render_interval 0 --seed 1
```
**Quick note:** For boolean-based arguments, you can set them to False by adding `no_` in front each argument, for instance, if you want to disable Clipped Q Learning, you can specify `--no_use_cdq` in your command.
## 💡 Performance-Related Tips
We used a single Nvidia A100 80GB GPU for all experiments. Here are some remarks and tips for improving performances in your setup or troubleshooting in your machine configurations.
- *Sample-efficiency* tends to improve with larger `num_envs`, `num_updates`, and `batch_size`. But this comes at the cost of *Time-efficiency*. Our default settings are optimized for wall-time efficiency on a single A100 80GB GPU. If you're using a different setup, consider tuning hyperparameters accordingly.
- When FastTD3 performance is stuck at local minima at the early phase of training in your experiments
- First consider increasing the `num_updates`. This happens usually when the agent fails to exploit value functions. We also find higher `num_updates` tends to be helpful for relatively easier tasks or tasks with low-dimensional action spaces.
- If the agent is completely stuck or much worse than your expectation, try using `num_steps=3` or disabling `use_cdq`.
- For tasks that have penalty reward terms (e.g., torques, energy, action_rate, ..), consider lowering them for initial experiments, and tune the values. In some cases, curriculum learning with lower penalty terms followed by fine-tuning with stronger terms is effective.
- When you encounter out-of-memory error with your GPU, our recommendation for reducing GPU usage is (i) smaller `buffer_size`, (ii) smaller `batch_size`, and then (iii) smaller `num_envs`. Because our codebase is assigning the whole replay buffer in GPU to reduce CPU-GPU transfer bottleneck, it usually has the largest GPU consumption, but usually less harmful to reduce.
## 🛝 Playing with the FastTD3 training
A Jupyter notebook (`training_notebook.ipynb`) is available to help you get started with:
- Training FastTD3 agents.
- Loading pre-trained models.
- Visualizing agent behavior.
- Potentially, re-training or fine-tuning models.
## 🤖 Sim-to-Real RL with FastTD3
We provide the [walkthrough](sim2real.md) for training deployable policies with FastTD3.
## Contributing
We welcome contributions! Please feel free to submit issues and pull requests.
## License
This project is licensed under the MIT License -- see the [LICENSE](LICENSE) file for details. Note that the repository relies on third-party libraries subject to their respective licenses.
## Acknowledgements
This codebase builds upon [LeanRL](https://github.com/pytorch-labs/LeanRL) framework.
We would like to thank people who have helped throughout the project:
- We thank [Kevin Zakka](https://kzakka.com/) for the help in setting up MuJoCo Playground.
- We thank [Changyeon Kim](https://changyeon.site/) for testing the early version of this codebase
## Citations
### FastTD3
```bibtex
@article{seo2025fasttd3,
title = {Generative Image as Action Models},
author = {Seo, Younggyo and Sferrazza, Carmelo and Geng, Haoran and Nauman, Michal and Yin, Zhao-Heng and Abbeel, Pieter},
booktitle = {preprint},
year = {2025},
}
```
### TD3
```bibtex
@inproceedings{fujimoto2018addressing,
title={Addressing function approximation error in actor-critic methods},
author={Fujimoto, Scott and Hoof, Herke and Meger, David},
booktitle={International conference on machine learning},
pages={1587--1596},
year={2018},
organization={PMLR}
}
```
### LeanRL
Following the [LeanRL](https://github.com/pytorch-labs/LeanRL)'s recommendation, we put CleanRL's bibtex here:
```bibtex
@article{huang2022cleanrl,
author = {Shengyi Huang and Rousslan Fernand Julien Dossa and Chang Ye and Jeff Braga and Dipam Chakraborty and Kinal Mehta and João G.M. Araújo},
title = {CleanRL: High-quality Single-file Implementations of Deep Reinforcement Learning Algorithms},
journal = {Journal of Machine Learning Research},
year = {2022},
volume = {23},
number = {274},
pages = {1--18},
url = {http://jmlr.org/papers/v23/21-1342.html}
}
```
### Parallel Q-Learning (PQL)
```bibtex
@inproceedings{li2023parallel,
title={Parallel $ Q $-Learning: Scaling Off-policy Reinforcement Learning under Massively Parallel Simulation},
author={Li, Zechu and Chen, Tao and Hong, Zhang-Wei and Ajay, Anurag and Agrawal, Pulkit},
booktitle={International Conference on Machine Learning},
pages={19440--19459},
year={2023},
organization={PMLR}
}
```
### HumanoidBench
```bibtex
@inproceedings{sferrazza2024humanoidbench,
title={Humanoidbench: Simulated humanoid benchmark for whole-body locomotion and manipulation},
author={Sferrazza, Carmelo and Huang, Dun-Ming and Lin, Xingyu and Lee, Youngwoon and Abbeel, Pieter},
booktitle={Robotics: Science and Systems},
year={2024}
}
```
### MuJoCo Playground
```bibtex
@article{zakka2025mujoco,
title={MuJoCo Playground},
author={Zakka, Kevin and Tabanpour, Baruch and Liao, Qiayuan and Haiderbhai, Mustafa and Holt, Samuel and Luo, Jing Yuan and Allshire, Arthur and Frey, Erik and Sreenath, Koushil and Kahrs, Lueder A and others},
journal={arXiv preprint arXiv:2502.08844},
year={2025}
}
```
### IsaacLab
```bibtex
@article{mittal2023orbit,
author={Mittal, Mayank and Yu, Calvin and Yu, Qinxi and Liu, Jingzhou and Rudin, Nikita and Hoeller, David and Yuan, Jia Lin and Singh, Ritvik and Guo, Yunrong and Mazhar, Hammad and Mandlekar, Ajay and Babich, Buck and State, Gavriel and Hutter, Marco and Garg, Animesh},
journal={IEEE Robotics and Automation Letters},
title={Orbit: A Unified Simulation Framework for Interactive Robot Learning Environments},
year={2023},
volume={8},
number={6},
pages={3740-3747},
doi={10.1109/LRA.2023.3270034}
}
```

20
fast_td3/__init__.py Normal file
View File

@ -0,0 +1,20 @@
"""
Fast TD3 is a high-performance implementation of Twin Delayed Deep Deterministic Policy Gradient (TD3)
with distributional critics for reinforcement learning.
"""
# Core model components
from fast_td3.fast_td3 import Actor, Critic, DistributionalQNetwork
from fast_td3.fast_td3_utils import EmpiricalNormalization, SimpleReplayBuffer
from fast_td3.fast_td3_deploy import Policy, load_policy
__all__ = [
# Core model components
"Actor",
"Critic",
"DistributionalQNetwork",
"EmpiricalNormalization",
"SimpleReplayBuffer",
"Policy",
"load_policy",
]

View File

@ -0,0 +1,128 @@
from __future__ import annotations
import gymnasium as gym
import humanoid_bench
from gymnasium.wrappers import TimeLimit
from stable_baselines3.common.vec_env import SubprocVecEnv
import numpy as np
import torch
from loguru import logger as log
# Disable all logging below CRITICAL level
log.remove()
log.add(lambda msg: False, level="CRITICAL")
def make_env(env_name, rank, render_mode=None, seed=0):
"""
Utility function for multiprocessed env.
:param rank: (int) index of the subprocess
:param seed: (int) the inital seed for RNG
"""
if env_name in [
"h1hand-push-v0",
"h1-push-v0",
"h1hand-cube-v0",
"h1cube-v0",
"h1hand-basketball-v0",
"h1-basketball-v0",
"h1hand-kitchen-v0",
"h1-kitchen-v0",
]:
max_episode_steps = 500
else:
max_episode_steps = 1000
def _init():
import humanoid_bench
env = gym.make(env_name, render_mode=render_mode)
env = TimeLimit(env, max_episode_steps=max_episode_steps)
env.unwrapped.seed(seed + rank)
return env
return _init
class HumanoidBenchEnv:
"""Wraps HumanoidBench environment to support parallel environments."""
def __init__(self, env_name, num_envs=1, render_mode=None, device=None):
# NOTE: HumanoidBench action space is already normalized to [-1, 1]
device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.sim_device = device
self.num_envs = num_envs
# Create the base environment
self.envs = SubprocVecEnv(
[make_env(env_name, i, render_mode=render_mode) for i in range(num_envs)]
)
if env_name in [
"h1hand-push-v0",
"h1-push-v0",
"h1hand-cube-v0",
"h1cube-v0",
"h1hand-basketball-v0",
"h1-basketball-v0",
"h1hand-kitchen-v0",
"h1-kitchen-v0",
]:
self.max_episode_steps = 500
else:
self.max_episode_steps = 1000
# For compatibility with MuJoCo Playground
self.asymmetric_obs = False # For comptatibility with MuJoCo Playground
self.num_obs = self.envs.observation_space.shape[-1]
self.num_actions = self.envs.action_space.shape[-1]
def reset(self):
"""Reset the environment."""
observations = self.envs.reset()
observations = torch.from_numpy(observations).to(
device=self.sim_device, dtype=torch.float
)
return observations
def render(self):
assert (
self.num_envs == 1
), "Currently only supports single environment rendering"
return self.envs.render()
def step(self, actions):
assert isinstance(actions, torch.Tensor)
actions = actions.cpu().numpy()
observations, rewards, dones, raw_infos = self.envs.step(actions)
# This will be used for getting 'true' next observations
infos = dict()
infos["observations"] = {"raw": {"obs": observations.copy()}}
truncateds = np.zeros_like(dones)
for i in range(self.num_envs):
if raw_infos[i].get("TimeLimit.truncated", False):
truncateds[i] = True
infos["observations"]["raw"]["obs"][i] = raw_infos[i][
"terminal_observation"
]
observations = torch.from_numpy(observations).to(
device=self.sim_device, dtype=torch.float
)
rewards = torch.from_numpy(rewards).to(
device=self.sim_device, dtype=torch.float
)
dones = torch.from_numpy(dones).to(device=self.sim_device)
truncateds = torch.from_numpy(truncateds).to(device=self.sim_device)
infos["observations"]["raw"]["obs"] = torch.from_numpy(
infos["observations"]["raw"]["obs"]
).to(device=self.sim_device, dtype=torch.float)
infos["time_outs"] = truncateds
return observations, rewards, dones, infos

View File

@ -0,0 +1,82 @@
from typing import Optional
import gymnasium as gym
import torch
from isaaclab.app import AppLauncher
app_launcher = AppLauncher(headless=True)
simulation_app = app_launcher.app
import isaaclab_tasks
from isaaclab_tasks.utils.parse_cfg import parse_env_cfg
class IsaacLabEnv:
"""Wrapper for IsaacLab environments to be compatible with MuJoCo Playground"""
def __init__(
self,
task_name: str,
device: str,
num_envs: int,
seed: int,
action_bounds: Optional[float] = None,
):
env_cfg = parse_env_cfg(
task_name,
device=device,
num_envs=num_envs,
)
env_cfg.seed = seed
self.seed = seed
self.envs = gym.make(task_name, cfg=env_cfg, render_mode=None)
self.num_envs = self.envs.unwrapped.num_envs
self.max_episode_steps = self.envs.unwrapped.max_episode_length
self.action_bounds = action_bounds
self.num_obs = self.envs.unwrapped.single_observation_space["policy"].shape[0]
self.asymmetric_obs = "critic" in self.envs.unwrapped.single_observation_space
if self.asymmetric_obs:
self.num_privileged_obs = self.envs.unwrapped.single_observation_space[
"critic"
].shape[0]
else:
self.num_privileged_obs = 0
self.num_actions = self.envs.unwrapped.single_action_space.shape[0]
def reset(self, random_start_init: bool = True) -> torch.Tensor:
obs_dict, _ = self.envs.reset()
# NOTE: decorrelate episode horizons like RSLRL
if random_start_init:
self.envs.unwrapped.episode_length_buf = torch.randint_like(
self.envs.unwrapped.episode_length_buf, high=int(self.max_episode_steps)
)
return obs_dict["policy"]
def reset_with_critic_obs(self) -> tuple[torch.Tensor, torch.Tensor]:
obs_dict, _ = self.envs.reset()
return obs_dict["policy"], obs_dict["critic"]
def step(
self, actions: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, dict]:
if self.action_bounds is not None:
actions = torch.clamp(actions, -1.0, 1.0) * self.action_bounds
obs_dict, rew, terminations, truncations, infos = self.envs.step(actions)
dones = (terminations | truncations).to(dtype=torch.long)
obs = obs_dict["policy"]
critic_obs = obs_dict["critic"] if self.asymmetric_obs else None
info_ret = {"time_outs": truncations, "observations": {"critic": critic_obs}}
# NOTE: There's really no way to get the raw observations from IsaacLab
# We just use the 'reset_obs' as next_obs, unfortunately.
# See https://github.com/isaac-sim/IsaacLab/issues/1362
info_ret["observations"]["raw"] = {
"obs": obs,
"critic_obs": critic_obs,
}
return obs, rew, dones, info_ret
def render(self):
raise NotImplementedError(
"We don't support rendering for IsaacLab environments"
)

View File

@ -0,0 +1,136 @@
from mujoco_playground import registry
from mujoco_playground import wrapper_torch
import jax
import mujoco
jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)
class PlaygroundEvalEnvWrapper:
def __init__(self, eval_env, max_episode_steps, env_name, num_eval_envs, seed):
"""
Wrapper used for evaluation / rendering environments.
Note that this is different from training environments that are
wrapped with RSLRLBraxWrapper.
"""
self.env = eval_env
self.env_name = env_name
self.num_envs = num_eval_envs
self.jit_reset = jax.jit(jax.vmap(self.env.reset))
self.jit_step = jax.jit(jax.vmap(self.env.step))
if isinstance(self.env.unwrapped.observation_size, dict):
self.asymmetric_obs = True
else:
self.asymmetric_obs = False
self.key = jax.random.PRNGKey(seed)
self.key_reset = jax.random.split(self.key, num_eval_envs)
self.max_episode_steps = max_episode_steps
def reset(self):
self.state = self.jit_reset(self.key_reset)
if self.asymmetric_obs:
obs = wrapper_torch._jax_to_torch(self.state.obs["state"])
else:
obs = wrapper_torch._jax_to_torch(self.state.obs)
return obs
def step(self, actions):
self.state = self.jit_step(self.state, wrapper_torch._torch_to_jax(actions))
if self.asymmetric_obs:
next_obs = wrapper_torch._jax_to_torch(self.state.obs["state"])
else:
next_obs = wrapper_torch._jax_to_torch(self.state.obs)
rewards = wrapper_torch._jax_to_torch(self.state.reward)
dones = wrapper_torch._jax_to_torch(self.state.done)
return next_obs, rewards, dones, None
def render_trajectory(self, trajectory):
scene_option = mujoco.MjvOption()
scene_option.flags[mujoco.mjtVisFlag.mjVIS_TRANSPARENT] = False
scene_option.flags[mujoco.mjtVisFlag.mjVIS_PERTFORCE] = False
scene_option.flags[mujoco.mjtVisFlag.mjVIS_CONTACTFORCE] = False
frames = self.env.render(
trajectory,
camera="track" if "Joystick" in self.env_name else None,
height=480,
width=640,
scene_option=scene_option,
)
return frames
def make_env(
env_name,
seed,
num_envs,
num_eval_envs,
device_rank,
use_tuned_reward=False,
use_domain_randomization=False,
use_push_randomization=False,
):
# Make training environment
train_env_cfg = registry.get_default_config(env_name)
if use_tuned_reward:
# NOTE: Tuned reward for G1. Used for producing Figure 7 in the paper.
assert env_name in ["G1JoystickRoughTerrain", "G1JoystickFlatTerrain"]
train_env_cfg.reward_config.scales.energy = -5e-5
train_env_cfg.reward_config.scales.action_rate = -1e-1
train_env_cfg.reward_config.scales.torques = -1e-3
train_env_cfg.reward_config.scales.pose = -1.0
train_env_cfg.reward_config.scales.tracking_ang_vel = 1.25
train_env_cfg.reward_config.scales.tracking_lin_vel = 1.25
train_env_cfg.reward_config.scales.feet_phase = 1.0
train_env_cfg.reward_config.scales.ang_vel_xy = -0.3
train_env_cfg.reward_config.scales.orientation = -5.0
is_humanoid_task = env_name in [
"G1JoystickRoughTerrain",
"G1JoystickFlatTerrain",
"T1JoystickRoughTerrain",
"T1JoystickFlatTerrain",
]
if is_humanoid_task and not use_push_randomization:
train_env_cfg.push_config.enable = False
train_env_cfg.push_config.magnitude_range = [0.0, 0.0]
randomizer = (
registry.get_domain_randomizer(env_name) if use_domain_randomization else None
)
raw_env = registry.load(env_name, config=train_env_cfg)
train_env = wrapper_torch.RSLRLBraxWrapper(
raw_env,
num_envs,
seed,
train_env_cfg.episode_length,
train_env_cfg.action_repeat,
randomization_fn=randomizer,
device_rank=device_rank,
)
# Make evaluation environment
eval_env_cfg = registry.get_default_config(env_name)
if is_humanoid_task and not use_push_randomization:
eval_env_cfg.push_config.enable = False
eval_env_cfg.push_config.magnitude_range = [0.0, 0.0]
eval_env = registry.load(env_name, config=eval_env_cfg)
eval_env = PlaygroundEvalEnvWrapper(
eval_env, eval_env_cfg.episode_length, env_name, num_eval_envs, seed
)
render_env_cfg = registry.get_default_config(env_name)
if is_humanoid_task and not use_push_randomization:
render_env_cfg.push_config.enable = False
render_env_cfg.push_config.magnitude_range = [0.0, 0.0]
render_env = registry.load(env_name, config=render_env_cfg)
render_env = PlaygroundEvalEnvWrapper(
render_env, render_env_cfg.episode_length, env_name, 1, seed
)
return train_env, eval_env, render_env

217
fast_td3/fast_td3.py Normal file
View File

@ -0,0 +1,217 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
class DistributionalQNetwork(nn.Module):
def __init__(
self,
n_obs: int,
n_act: int,
num_atoms: int,
v_min: float,
v_max: float,
hidden_dim: int,
device: torch.device = None,
):
super().__init__()
self.net = nn.Sequential(
nn.Linear(n_obs + n_act, hidden_dim, device=device),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim // 2, device=device),
nn.ReLU(),
nn.Linear(hidden_dim // 2, hidden_dim // 4, device=device),
nn.ReLU(),
nn.Linear(hidden_dim // 4, num_atoms, device=device),
)
self.v_min = v_min
self.v_max = v_max
self.num_atoms = num_atoms
def forward(self, obs: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
x = torch.cat([obs, actions], 1)
x = self.net(x)
return x
def projection(
self,
obs: torch.Tensor,
actions: torch.Tensor,
rewards: torch.Tensor,
bootstrap: torch.Tensor,
gamma: float,
q_support: torch.Tensor,
device: torch.device,
) -> torch.Tensor:
delta_z = (self.v_max - self.v_min) / (self.num_atoms - 1)
batch_size = rewards.shape[0]
target_z = rewards.unsqueeze(1) + bootstrap.unsqueeze(1) * gamma * q_support
target_z = target_z.clamp(self.v_min, self.v_max)
b = (target_z - self.v_min) / delta_z
l = torch.floor(b).long()
u = torch.ceil(b).long()
l_mask = torch.logical_and((u > 0), (l == u))
u_mask = torch.logical_and((l < (self.num_atoms - 1)), (l == u))
l = torch.where(l_mask, l - 1, l)
u = torch.where(u_mask, u + 1, u)
next_dist = F.softmax(self.forward(obs, actions), dim=1)
proj_dist = torch.zeros_like(next_dist)
offset = (
torch.linspace(
0, (batch_size - 1) * self.num_atoms, batch_size, device=device
)
.unsqueeze(1)
.expand(batch_size, self.num_atoms)
.long()
)
proj_dist.view(-1).index_add_(
0, (l + offset).view(-1), (next_dist * (u.float() - b)).view(-1)
)
proj_dist.view(-1).index_add_(
0, (u + offset).view(-1), (next_dist * (b - l.float())).view(-1)
)
return proj_dist
class Critic(nn.Module):
def __init__(
self,
n_obs: int,
n_act: int,
num_atoms: int,
v_min: float,
v_max: float,
hidden_dim: int,
device: torch.device = None,
):
super().__init__()
self.qnet1 = DistributionalQNetwork(
n_obs=n_obs,
n_act=n_act,
num_atoms=num_atoms,
v_min=v_min,
v_max=v_max,
hidden_dim=hidden_dim,
device=device,
)
self.qnet2 = DistributionalQNetwork(
n_obs=n_obs,
n_act=n_act,
num_atoms=num_atoms,
v_min=v_min,
v_max=v_max,
hidden_dim=hidden_dim,
device=device,
)
self.register_buffer(
"q_support", torch.linspace(v_min, v_max, num_atoms, device=device)
)
def forward(self, obs: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
return self.qnet1(obs, actions), self.qnet2(obs, actions)
def projection(
self,
obs: torch.Tensor,
actions: torch.Tensor,
rewards: torch.Tensor,
bootstrap: torch.Tensor,
gamma: float,
) -> torch.Tensor:
"""Projection operation that includes q_support directly"""
q1_proj = self.qnet1.projection(
obs,
actions,
rewards,
bootstrap,
gamma,
self.q_support,
self.q_support.device,
)
q2_proj = self.qnet2.projection(
obs,
actions,
rewards,
bootstrap,
gamma,
self.q_support,
self.q_support.device,
)
return q1_proj, q2_proj
def get_value(self, probs: torch.Tensor) -> torch.Tensor:
"""Calculate value from logits using support"""
return torch.sum(probs * self.q_support, dim=1)
class Actor(nn.Module):
def __init__(
self,
n_obs: int,
n_act: int,
num_envs: int,
init_scale: float,
hidden_dim: int,
std_min: float = 0.05,
std_max: float = 0.8,
device: torch.device = None,
):
super().__init__()
self.n_act = n_act
self.net = nn.Sequential(
nn.Linear(n_obs, hidden_dim, device=device),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim // 2, device=device),
nn.ReLU(),
nn.Linear(hidden_dim // 2, hidden_dim // 4, device=device),
nn.ReLU(),
)
self.fc_mu = nn.Sequential(
nn.Linear(hidden_dim // 4, n_act, device=device),
nn.Tanh(),
)
nn.init.normal_(self.fc_mu[0].weight, 0.0, init_scale)
nn.init.constant_(self.fc_mu[0].bias, 0.0)
noise_scales = (
torch.rand(num_envs, 1, device=device) * (std_max - std_min) + std_min
)
self.register_buffer("noise_scales", noise_scales)
self.register_buffer("std_min", torch.as_tensor(std_min, device=device))
self.register_buffer("std_max", torch.as_tensor(std_max, device=device))
self.n_envs = num_envs
def forward(self, obs: torch.Tensor) -> torch.Tensor:
x = obs
x = self.net(x)
action = self.fc_mu(x)
return action
def explore(
self, obs: torch.Tensor, dones: torch.Tensor = None, deterministic: bool = False
) -> torch.Tensor:
# If dones is provided, resample noise for environments that are done
if dones is not None and dones.sum() > 0:
# Generate new noise scales for done environments (one per environment)
new_scales = (
torch.rand(self.n_envs, 1, device=obs.device)
* (self.std_max - self.std_min)
+ self.std_min
)
# Update only the noise scales for environments that are done
dones_view = dones.view(-1, 1) > 0
self.noise_scales = torch.where(dones_view, new_scales, self.noise_scales)
act = self(obs)
if deterministic:
return act
noise = torch.randn_like(act) * self.noise_scales
return act + noise

View File

@ -0,0 +1,66 @@
import torch
import torch.nn as nn
from .fast_td3_utils import EmpiricalNormalization
from .fast_td3 import Actor
class Policy(nn.Module):
def __init__(
self,
n_obs: int,
n_act: int,
num_envs: int,
init_scale: float,
actor_hidden_dim: int,
):
super().__init__()
self.actor = Actor(
n_obs=n_obs,
n_act=n_act,
num_envs=num_envs,
device="cpu",
init_scale=init_scale,
hidden_dim=actor_hidden_dim,
)
self.obs_normalizer = EmpiricalNormalization(shape=n_obs, device="cpu")
self.actor.eval()
self.obs_normalizer.eval()
@torch.no_grad
def forward(self, obs: torch.Tensor) -> torch.Tensor:
norm_obs = self.obs_normalizer(obs)
actions = self.actor(norm_obs)
return actions
@torch.no_grad
def act(self, obs: torch.Tensor) -> torch.distributions.Normal:
actions = self.forward(obs)
return torch.distributions.Normal(actions, torch.ones_like(actions) * 1e-8)
def load_policy(checkpoint_path):
torch_checkpoint = torch.load(
f"{checkpoint_path}", map_location="cpu", weights_only=False
)
args = torch_checkpoint["args"]
n_obs = torch_checkpoint["actor_state_dict"]["net.0.weight"].shape[-1]
n_act = torch_checkpoint["actor_state_dict"]["fc_mu.0.weight"].shape[0]
policy = Policy(
n_obs=n_obs,
n_act=n_act,
num_envs=args["num_envs"],
init_scale=args["init_scale"],
actor_hidden_dim=args["actor_hidden_dim"],
)
policy.actor.load_state_dict(torch_checkpoint["actor_state_dict"])
if len(torch_checkpoint["obs_normalizer_state"]) == 0:
policy.obs_normalizer = nn.Identity()
else:
policy.obs_normalizer.load_state_dict(torch_checkpoint["obs_normalizer_state"])
return policy

387
fast_td3/fast_td3_utils.py Normal file
View File

@ -0,0 +1,387 @@
import os
import torch
import torch.nn as nn
from tensordict import TensorDict
class SimpleReplayBuffer(nn.Module):
def __init__(
self,
n_env: int,
buffer_size: int,
n_obs: int,
n_act: int,
n_critic_obs: int,
asymmetric_obs: bool = False,
n_steps: int = 1,
gamma: float = 0.99,
device=None,
):
"""
A simple replay buffer that stores transitions in a circular buffer.
Supports n-step returns and asymmetric observations.
"""
super().__init__()
self.n_env = n_env
self.buffer_size = buffer_size
self.n_obs = n_obs
self.n_act = n_act
self.n_critic_obs = n_critic_obs
self.asymmetric_obs = asymmetric_obs
self.gamma = gamma
self.n_steps = n_steps
self.device = device
self.observations = torch.zeros(
(n_env, buffer_size, n_obs), device=device, dtype=torch.float
)
self.actions = torch.zeros(
(n_env, buffer_size, n_act), device=device, dtype=torch.float
)
self.rewards = torch.zeros(
(n_env, buffer_size), device=device, dtype=torch.float
)
self.dones = torch.zeros((n_env, buffer_size), device=device, dtype=torch.long)
self.truncations = torch.zeros(
(n_env, buffer_size), device=device, dtype=torch.long
)
self.next_observations = torch.zeros(
(n_env, buffer_size, n_obs), device=device, dtype=torch.float
)
if asymmetric_obs:
self.critic_observations = torch.zeros(
(n_env, buffer_size, n_critic_obs), device=device, dtype=torch.float
)
self.next_critic_observations = torch.zeros(
(n_env, buffer_size, n_critic_obs), device=device, dtype=torch.float
)
self.ptr = 0
def extend(
self,
tensor_dict: TensorDict,
):
observations = tensor_dict["observations"]
actions = tensor_dict["actions"]
rewards = tensor_dict["next"]["rewards"]
dones = tensor_dict["next"]["dones"]
truncations = tensor_dict["next"]["truncations"]
next_observations = tensor_dict["next"]["observations"]
ptr = self.ptr % self.buffer_size
self.observations[:, ptr] = observations
self.actions[:, ptr] = actions
self.rewards[:, ptr] = rewards
self.dones[:, ptr] = dones
self.truncations[:, ptr] = truncations
self.next_observations[:, ptr] = next_observations
if self.asymmetric_obs:
critic_observations = tensor_dict["critic_observations"]
self.critic_observations[:, ptr] = critic_observations
next_critic_observations = tensor_dict["next"]["critic_observations"]
self.next_critic_observations[:, ptr] = next_critic_observations
self.ptr += 1
def sample(self, batch_size: int):
# we will sample n_env * batch_size transitions
if self.n_steps == 1:
indices = torch.randint(
0,
min(self.buffer_size, self.ptr),
(self.n_env, batch_size),
device=self.device,
)
obs_indices = indices.unsqueeze(-1).expand(-1, -1, self.n_obs)
act_indices = indices.unsqueeze(-1).expand(-1, -1, self.n_act)
observations = torch.gather(self.observations, 1, obs_indices).reshape(
self.n_env * batch_size, self.n_obs
)
next_observations = torch.gather(
self.next_observations, 1, obs_indices
).reshape(self.n_env * batch_size, self.n_obs)
actions = torch.gather(self.actions, 1, act_indices).reshape(
self.n_env * batch_size, self.n_act
)
rewards = torch.gather(self.rewards, 1, indices).reshape(
self.n_env * batch_size
)
dones = torch.gather(self.dones, 1, indices).reshape(
self.n_env * batch_size
)
truncations = torch.gather(self.truncations, 1, indices).reshape(
self.n_env * batch_size
)
if self.asymmetric_obs:
critic_obs_indices = indices.unsqueeze(-1).expand(
-1, -1, self.n_critic_obs
)
critic_observations = torch.gather(
self.critic_observations, 1, critic_obs_indices
).reshape(self.n_env * batch_size, self.n_critic_obs)
next_critic_observations = torch.gather(
self.next_critic_observations, 1, critic_obs_indices
).reshape(self.n_env * batch_size, self.n_critic_obs)
else:
# Sample base indices
indices = torch.randint(
0,
min(self.buffer_size, self.ptr),
(self.n_env, batch_size),
device=self.device,
)
obs_indices = indices.unsqueeze(-1).expand(-1, -1, self.n_obs)
act_indices = indices.unsqueeze(-1).expand(-1, -1, self.n_act)
# Get base transitions
observations = torch.gather(self.observations, 1, obs_indices).reshape(
self.n_env * batch_size, self.n_obs
)
actions = torch.gather(self.actions, 1, act_indices).reshape(
self.n_env * batch_size, self.n_act
)
if self.asymmetric_obs:
critic_obs_indices = indices.unsqueeze(-1).expand(
-1, -1, self.n_critic_obs
)
critic_observations = torch.gather(
self.critic_observations, 1, critic_obs_indices
).reshape(self.n_env * batch_size, self.n_critic_obs)
# Create sequential indices for each sample
# This creates a [n_env, batch_size, n_step] tensor of indices
seq_offsets = torch.arange(self.n_steps, device=self.device).view(1, 1, -1)
all_indices = (
indices.unsqueeze(-1) + seq_offsets
) % self.buffer_size # [n_env, batch_size, n_step]
# Gather all rewards and terminal flags
# Using advanced indexing - result shapes: [n_env, batch_size, n_step]
all_rewards = torch.gather(
self.rewards.unsqueeze(-1).expand(-1, -1, self.n_steps), 1, all_indices
)
all_dones = torch.gather(
self.dones.unsqueeze(-1).expand(-1, -1, self.n_steps), 1, all_indices
)
all_truncations = torch.gather(
self.truncations.unsqueeze(-1).expand(-1, -1, self.n_steps),
1,
all_indices,
)
# Create masks for rewards after first done
# This creates a cumulative product that zeroes out rewards after the first done
done_masks = torch.cumprod(
1.0 - all_dones, dim=2
) # [n_env, batch_size, n_step]
# Create discount factors
discounts = torch.pow(
self.gamma, torch.arange(self.n_steps, device=self.device)
) # [n_steps]
# Apply masks and discounts to rewards
masked_rewards = all_rewards * done_masks # [n_env, batch_size, n_step]
discounted_rewards = masked_rewards * discounts.view(
1, 1, -1
) # [n_env, batch_size, n_step]
# Sum rewards along the n_step dimension
n_step_rewards = discounted_rewards.sum(dim=2) # [n_env, batch_size]
# Find index of first done or truncation or last step for each sequence
first_done = torch.argmax(
(all_dones > 0).float(), dim=2
) # [n_env, batch_size]
first_trunc = torch.argmax(
(all_truncations > 0).float(), dim=2
) # [n_env, batch_size]
# Handle case where there are no dones or truncations
no_dones = all_dones.sum(dim=2) == 0
no_truncs = all_truncations.sum(dim=2) == 0
# When no dones or truncs, use the last index
first_done = torch.where(no_dones, self.n_steps - 1, first_done)
first_trunc = torch.where(no_truncs, self.n_steps - 1, first_trunc)
# Take the minimum (first) of done or truncation
final_indices = torch.minimum(
first_done, first_trunc
) # [n_env, batch_size]
# Create indices to gather the final next observations
final_next_obs_indices = torch.gather(
all_indices, 2, final_indices.unsqueeze(-1)
).squeeze(
-1
) # [n_env, batch_size]
# Gather final values
final_next_observations = self.next_observations.gather(
1, final_next_obs_indices.unsqueeze(-1).expand(-1, -1, self.n_obs)
)
final_dones = self.dones.gather(1, final_next_obs_indices)
final_truncations = self.truncations.gather(1, final_next_obs_indices)
if self.asymmetric_obs:
final_next_critic_observations = self.next_critic_observations.gather(
1,
final_next_obs_indices.unsqueeze(-1).expand(
-1, -1, self.n_critic_obs
),
)
# Reshape everything to batch dimension
rewards = n_step_rewards.reshape(self.n_env * batch_size)
dones = final_dones.reshape(self.n_env * batch_size)
truncations = final_truncations.reshape(self.n_env * batch_size)
next_observations = final_next_observations.reshape(
self.n_env * batch_size, self.n_obs
)
if self.asymmetric_obs:
next_critic_observations = final_next_critic_observations.reshape(
self.n_env * batch_size, self.n_critic_obs
)
out = TensorDict(
{
"observations": observations,
"actions": actions,
"next": {
"rewards": rewards,
"dones": dones,
"truncations": truncations,
"observations": next_observations,
},
},
batch_size=self.n_env * batch_size,
)
if self.asymmetric_obs:
out["critic_observations"] = critic_observations
out["next"]["critic_observations"] = next_critic_observations
return out
class EmpiricalNormalization(nn.Module):
"""Normalize mean and variance of values based on empirical values."""
def __init__(self, shape, device, eps=1e-2, until=None):
"""Initialize EmpiricalNormalization module.
Args:
shape (int or tuple of int): Shape of input values except batch axis.
eps (float): Small value for stability.
until (int or None): If this arg is specified, the link learns input values until the sum of batch sizes
exceeds it.
"""
super().__init__()
self.eps = eps
self.until = until
self.device = device
self.register_buffer("_mean", torch.zeros(shape).unsqueeze(0).to(device))
self.register_buffer("_var", torch.ones(shape).unsqueeze(0).to(device))
self.register_buffer("_std", torch.ones(shape).unsqueeze(0).to(device))
self.register_buffer("count", torch.tensor(0, dtype=torch.long).to(device))
@property
def mean(self):
return self._mean.squeeze(0).clone()
@property
def std(self):
return self._std.squeeze(0).clone()
def forward(self, x: torch.Tensor, center: bool = True) -> torch.Tensor:
if x.shape[1:] != self._mean.shape[1:]:
raise ValueError(
f"Expected input of shape (*,{self._mean.shape[1:]}), got {x.shape}"
)
if self.training:
self.update(x)
if center:
return (x - self._mean) / (self._std + self.eps)
else:
return x / (self._std + self.eps)
@torch.jit.unused
def update(self, x):
"""Learn input values using Welford's online algorithm"""
if self.until is not None and self.count >= self.until:
return
batch_size = x.shape[0]
batch_mean = torch.mean(x, dim=0, keepdim=True)
# Update count
new_count = self.count + batch_size
# Update mean
delta = batch_mean - self._mean
self._mean += (batch_size / new_count) * delta
# Update variance using Welford's parallel algorithm
if self.count > 0: # Ensure we're not dividing by zero
# Compute batch variance
batch_var = torch.mean((x - batch_mean) ** 2, dim=0, keepdim=True)
# Combine variances using parallel algorithm
delta2 = batch_mean - self._mean
m_a = self._var * self.count
m_b = batch_var * batch_size
M2 = m_a + m_b + (delta2**2) * (self.count * batch_size / new_count)
self._var = M2 / new_count
else:
# For first batch, just use batch variance
self._var = torch.mean((x - self._mean) ** 2, dim=0, keepdim=True)
self._std = torch.sqrt(self._var)
self.count = new_count
@torch.jit.unused
def inverse(self, y):
return y * (self._std + self.eps) + self._mean
def cpu_state(sd):
# detach & move to host without locking the compute stream
return {k: v.detach().to("cpu", non_blocking=True) for k, v in sd.items()}
def save_params(
global_step,
actor,
qnet,
qnet_target,
obs_normalizer,
critic_obs_normalizer,
args,
save_path,
):
"""Save model parameters and training configuration to disk."""
os.makedirs(os.path.dirname(save_path), exist_ok=True)
save_dict = {
"actor_state_dict": cpu_state(actor.state_dict()),
"qnet_state_dict": cpu_state(qnet.state_dict()),
"qnet_target_state_dict": cpu_state(qnet_target.state_dict()),
"obs_normalizer_state": (
cpu_state(obs_normalizer.state_dict())
if hasattr(obs_normalizer, "state_dict")
else None
),
"critic_obs_normalizer_state": (
cpu_state(critic_obs_normalizer.state_dict())
if hasattr(critic_obs_normalizer, "state_dict")
else None
),
"args": vars(args), # Save all arguments
"global_step": global_step,
}
torch.save(save_dict, save_path, _use_new_zipfile_serialization=True)
print(f"Saved parameters and configuration to {save_path}")

454
fast_td3/hyperparams.py Normal file
View File

@ -0,0 +1,454 @@
import os
from dataclasses import dataclass
import tyro
@dataclass
class BaseArgs:
# Default hyperparameters -- specifically for HumanoidBench
# See MuJoCoPlaygroundArgs for default hyperparameters for MuJoCo Playground
# See IsaacLabArgs for default hyperparameters for IsaacLab
env_name: str = "h1hand-stand-v0"
"""the id of the environment"""
seed: int = 1
"""seed of the experiment"""
torch_deterministic: bool = True
"""if toggled, `torch.backends.cudnn.deterministic=False`"""
cuda: bool = True
"""if toggled, cuda will be enabled by default"""
device_rank: int = 0
"""the rank of the device"""
exp_name: str = os.path.basename(__file__)[: -len(".py")]
"""the name of this experiment"""
project: str = "FastTD3"
"""the project name"""
use_wandb: bool = True
"""whether to use wandb"""
checkpoint_path: str = None
"""the path to the checkpoint file"""
num_envs: int = 128
"""the number of environments to run in parallel"""
num_eval_envs: int = 128
"""the number of evaluation environments to run in parallel (only valid for MuJoCo Playground)"""
total_timesteps: int = 150000
"""total timesteps of the experiments"""
critic_learning_rate: float = 3e-4
"""the learning rate of the critic"""
actor_learning_rate: float = 3e-4
"""the learning rate for the actor"""
buffer_size: int = 1024 * 50
"""the replay memory buffer size"""
num_steps: int = 1
"""the number of steps to use for the multi-step return"""
gamma: float = 0.99
"""the discount factor gamma"""
tau: float = 0.1
"""target smoothing coefficient (default: 0.005)"""
batch_size: int = 32768
"""the batch size of sample from the replay memory"""
policy_noise: float = 0.001
"""the scale of policy noise"""
std_min: float = 0.001
"""the minimum scale of noise"""
std_max: float = 0.4
"""the maximum scale of noise"""
learning_starts: int = 10
"""timestep to start learning"""
policy_frequency: int = 2
"""the frequency of training policy (delayed)"""
noise_clip: float = 0.5
"""noise clip parameter of the Target Policy Smoothing Regularization"""
num_updates: int = 2
"""the number of updates to perform per step"""
init_scale: float = 0.01
"""the scale of the initial parameters"""
num_atoms: int = 101
"""the number of atoms"""
v_min: float = -250.0
"""the minimum value of the support"""
v_max: float = 250.0
"""the maximum value of the support"""
critic_hidden_dim: int = 1024
"""the hidden dimension of the critic network"""
actor_hidden_dim: int = 512
"""the hidden dimension of the actor network"""
use_cdq: bool = True
"""whether to use Clipped Double Q-learning"""
measure_burnin: int = 3
"""Number of burn-in iterations for speed measure."""
eval_interval: int = 5000
"""the interval to evaluate the model"""
render_interval: int = 5000
"""the interval to render the model"""
compile: bool = True
"""whether to use torch.compile."""
obs_normalization: bool = True
"""whether to enable observation normalization"""
max_grad_norm: float = 0.0
"""the maximum gradient norm"""
amp: bool = True
"""whether to use amp"""
amp_dtype: str = "bf16"
"""the dtype of the amp"""
disable_bootstrap: bool = False
"""Whether to disable bootstrap in the critic learning"""
use_domain_randomization: bool = False
"""(Playground only) whether to use domain randomization"""
use_push_randomization: bool = False
"""(Playground only) whether to use push randomization"""
use_tuned_reward: bool = False
"""(Playground only) Use tuned reward for G1"""
action_bounds: float = 1.0
"""(IsaacLab only) the bounds of the action space (-action_bounds, action_bounds)"""
weight_decay: float = 0.1
"""the weight decay of the optimizer"""
save_interval: int = 5000
"""the interval to save the model"""
def get_args():
"""
Parse command-line arguments and return the appropriate Args instance based on env_name.
"""
# First, parse all arguments using the base Args class
base_args = tyro.cli(BaseArgs)
# Map environment names to their specific Args classes
# For tasks not here, default hyperparameters are used
# See below links for available task list
# - HumanoidBench (https://arxiv.org/abs/2403.10506)
# - IsaacLab (https://isaac-sim.github.io/IsaacLab/main/source/overview/environments.html)
# - MuJoCo Playground (https://arxiv.org/abs/2502.08844)
env_to_args_class = {
# HumanoidBench
# NOTE: These tasks are not full list of HumanoidBench tasks
"h1hand-reach-v0": H1HandReachArgs,
"h1hand-balance-simple-v0": H1HandBalanceSimpleArgs,
"h1hand-balance-hard-v0": H1HandBalanceHardArgs,
"h1hand-pole-v0": H1HandPoleArgs,
"h1hand-truck-v0": H1HandTruckArgs,
"h1hand-maze-v0": H1HandMazeArgs,
"h1hand-push-v0": H1HandPushArgs,
"h1hand-basketball-v0": H1HandBasketballArgs,
"h1hand-window-v0": H1HandWindowArgs,
"h1hand-package-v0": H1HandPackageArgs,
"h1hand-truck-v0": H1HandTruckArgs,
# MuJoCo Playground
# NOTE: These tasks are not full list of MuJoCo Playground tasks
"G1JoystickFlatTerrain": G1JoystickFlatTerrainArgs,
"G1JoystickRoughTerrain": G1JoystickRoughTerrainArgs,
"T1JoystickFlatTerrain": T1JoystickFlatTerrainArgs,
"T1JoystickRoughTerrain": T1JoystickRoughTerrainArgs,
"LeapCubeReorient": LeapCubeReorientArgs,
"LeapCubeRotateZAxis": LeapCubeRotateZAxisArgs,
"Go1JoystickFlatTerrain": Go1JoystickFlatTerrainArgs,
"Go1JoystickRoughTerrain": Go1JoystickRoughTerrainArgs,
"Go1Getup": Go1GetupArgs,
"CheetahRun": CheetahRunArgs, # NOTE: Example config for DeepMind Control Suite
# IsaacLab
# NOTE: These tasks are not full list of IsaacLab tasks
"Isaac-Lift-Cube-Franka-v0": IsaacLiftCubeFrankaArgs,
"Isaac-Open-Drawer-Franka-v0": IsaacOpenDrawerFrankaArgs,
"Isaac-Velocity-Flat-H1-v0": IsaacVelocityFlatH1Args,
"Isaac-Velocity-Flat-G1-v0": IsaacVelocityFlatG1Args,
"Isaac-Velocity-Rough-H1-v0": IsaacVelocityRoughH1Args,
"Isaac-Velocity-Rough-G1-v0": IsaacVelocityRoughG1Args,
"Isaac-Repose-Cube-Allegro-Direct-v0": IsaacReposeCubeAllegroDirectArgs,
"Isaac-Repose-Cube-Shadow-Direct-v0": IsaacReposeCubeShadowDirectArgs,
}
# If the provided env_name has a specific Args class, use it
if base_args.env_name in env_to_args_class:
specific_args_class = env_to_args_class[base_args.env_name]
# Re-parse with the specific class, maintaining any user overrides
specific_args = tyro.cli(specific_args_class)
return specific_args
if base_args.env_name.startswith("h1hand-") or base_args.env_name.startswith("h1-"):
# HumanoidBench
specific_args = tyro.cli(HumanoidBenchArgs)
elif base_args.env_name.startswith("Isaac-"):
# IsaacLab
specific_args = tyro.cli(IsaacLabArgs)
else:
# MuJoCo Playground
specific_args = tyro.cli(MuJoCoPlaygroundArgs)
return specific_args
@dataclass
class HumanoidBenchArgs(BaseArgs):
# See HumanoidBench (https://arxiv.org/abs/2403.10506) for available task list
total_timesteps: int = 100000
@dataclass
class H1HandReachArgs(HumanoidBenchArgs):
env_name: str = "h1hand-reach-v0"
v_min: float = -2000.0
v_max: float = 2000.0
@dataclass
class H1HandBalanceSimpleArgs(HumanoidBenchArgs):
env_name: str = "h1hand-balance-simple-v0"
total_timesteps: int = 200000
@dataclass
class H1HandBalanceHardArgs(HumanoidBenchArgs):
env_name: str = "h1hand-balance-hard-v0"
total_timesteps: int = 1000000
@dataclass
class H1HandPoleArgs(HumanoidBenchArgs):
env_name: str = "h1hand-pole-v0"
total_timesteps: int = 150000
@dataclass
class H1HandTruckArgs(HumanoidBenchArgs):
env_name: str = "h1hand-truck-v0"
total_timesteps: int = 500000
@dataclass
class H1HandMazeArgs(HumanoidBenchArgs):
env_name: str = "h1hand-maze-v0"
v_min: float = -1000.0
v_max: float = 1000.0
@dataclass
class H1HandPushArgs(HumanoidBenchArgs):
env_name: str = "h1hand-push-v0"
v_min: float = -1000.0
v_max: float = 1000.0
total_timesteps: int = 1000000
@dataclass
class H1HandBasketballArgs(HumanoidBenchArgs):
env_name: str = "h1hand-basketball-v0"
v_min: float = -2000.0
v_max: float = 2000.0
total_timesteps: int = 250000
@dataclass
class H1HandWindowArgs(HumanoidBenchArgs):
env_name: str = "h1hand-window-v0"
total_timesteps: int = 250000
@dataclass
class H1HandPackageArgs(HumanoidBenchArgs):
env_name: str = "h1hand-package-v0"
v_min: float = -10000.0
v_max: float = 10000.0
@dataclass
class H1HandTruckArgs(HumanoidBenchArgs):
env_name: str = "h1hand-truck-v0"
v_min: float = -1000.0
v_max: float = 1000.0
@dataclass
class MuJoCoPlaygroundArgs(BaseArgs):
# Default hyperparameters for many of Playground environments
v_min: float = -10.0
v_max: float = 10.0
buffer_size: int = 1024 * 10
num_envs: int = 1024
num_eval_envs: int = 1024
gamma: float = 0.97
@dataclass
class G1JoystickFlatTerrainArgs(MuJoCoPlaygroundArgs):
env_name: str = "G1JoystickFlatTerrain"
total_timesteps: int = 100000
@dataclass
class G1JoystickRoughTerrainArgs(MuJoCoPlaygroundArgs):
env_name: str = "G1JoystickRoughTerrain"
total_timesteps: int = 100000
@dataclass
class T1JoystickFlatTerrainArgs(MuJoCoPlaygroundArgs):
env_name: str = "T1JoystickFlatTerrain"
total_timesteps: int = 100000
@dataclass
class T1JoystickRoughTerrainArgs(MuJoCoPlaygroundArgs):
env_name: str = "T1JoystickRoughTerrain"
total_timesteps: int = 100000
@dataclass
class T1LowDofJoystickFlatTerrainArgs(MuJoCoPlaygroundArgs):
env_name: str = "T1LowDofJoystickFlatTerrain"
total_timesteps: int = 1000000
@dataclass
class T1LowDofJoystickRoughTerrainArgs(MuJoCoPlaygroundArgs):
env_name: str = "T1LowDofJoystickRoughTerrain"
total_timesteps: int = 1000000
@dataclass
class CheetahRunArgs(MuJoCoPlaygroundArgs):
# NOTE: This config will work for most DMC tasks, though we haven't tested DMC extensively.
# Future research can consider using LayerNorm as we find it sometimes works better for DMC tasks.
env_name: str = "CheetahRun"
num_steps: int = 3
v_min: float = -500.0
v_max: float = 500.0
std_min: float = 0.1
policy_noise: float = 0.1
@dataclass
class Go1JoystickFlatTerrainArgs(MuJoCoPlaygroundArgs):
env_name: str = "Go1JoystickFlatTerrain"
total_timesteps: int = 50000
std_min: float = 0.2
std_max: float = 0.8
policy_noise: float = 0.2
num_updates: int = 8
@dataclass
class Go1JoystickRoughTerrainArgs(MuJoCoPlaygroundArgs):
env_name: str = "Go1JoystickRoughTerrain"
total_timesteps: int = 50000
std_min: float = 0.2
std_max: float = 0.8
policy_noise: float = 0.2
num_updates: int = 8
@dataclass
class Go1GetupArgs(MuJoCoPlaygroundArgs):
env_name: str = "Go1Getup"
total_timesteps: int = 50000
std_min: float = 0.2
std_max: float = 0.8
policy_noise: float = 0.2
num_updates: int = 8
@dataclass
class LeapCubeReorientArgs(MuJoCoPlaygroundArgs):
env_name: str = "LeapCubeReorient"
num_steps: int = 3
policy_noise: float = 0.2
v_min: float = -50.0
v_max: float = 50.0
use_cdq: bool = False
@dataclass
class LeapCubeRotateZAxisArgs(MuJoCoPlaygroundArgs):
env_name: str = "LeapCubeRotateZAxis"
num_steps: int = 1
policy_noise: float = 0.2
v_min: float = -10.0
v_max: float = 10.0
use_cdq: bool = False
@dataclass
class IsaacLabArgs(BaseArgs):
v_min: float = -10.0
v_max: float = 10.0
buffer_size: int = 1024 * 10
num_envs: int = 4096
num_eval_envs: int = 4096
action_bounds: float = 1.0
std_max: float = 0.4
num_atoms: int = 251
render_interval: int = 0 # IsaacLab does not support rendering in our codebase
total_timesteps: int = 100000
@dataclass
class IsaacLiftCubeFrankaArgs(IsaacLabArgs):
# Value learning is unstable for Lift Cube task Due to brittle reward shaping
# Therefore, we need to disable bootstrap from 'reset_obs' in IsaacLab
# Higher UTD works better for manipulation tasks
env_name: str = "Isaac-Lift-Cube-Franka-v0"
num_updates: int = 8
v_min: float = -50.0
v_max: float = 50.0
std_max: float = 0.8
num_envs: int = 1024
num_eval_envs: int = 1024
action_bounds: float = 3.0
disable_bootstrap: bool = True
total_timesteps: int = 20000
@dataclass
class IsaacOpenDrawerFrankaArgs(IsaacLabArgs):
# Higher UTD works better for manipulation tasks
env_name: str = "Isaac-Open-Drawer-Franka-v0"
v_min: float = -50.0
v_max: float = 50.0
num_updates: int = 8
action_bounds: float = 3.0
total_timesteps: int = 20000
@dataclass
class IsaacVelocityFlatH1Args(IsaacLabArgs):
env_name: str = "Isaac-Velocity-Flat-H1-v0"
num_steps: int = 3
total_timesteps: int = 75000
@dataclass
class IsaacVelocityFlatG1Args(IsaacLabArgs):
env_name: str = "Isaac-Velocity-Flat-G1-v0"
num_steps: int = 3
total_timesteps: int = 50000
@dataclass
class IsaacVelocityRoughH1Args(IsaacLabArgs):
env_name: str = "Isaac-Velocity-Rough-H1-v0"
num_steps: int = 3
buffer_size: int = 1024 * 5 # To reduce memory usage
total_timesteps: int = 50000
@dataclass
class IsaacVelocityRoughG1Args(IsaacLabArgs):
env_name: str = "Isaac-Velocity-Rough-G1-v0"
num_steps: int = 3
buffer_size: int = 1024 * 5 # To reduce memory usage
total_timesteps: int = 50000
@dataclass
class IsaacReposeCubeAllegroDirectArgs(IsaacLabArgs):
env_name: str = "Isaac-Repose-Cube-Allegro-Direct-v0"
total_timesteps: int = 100000
v_min: float = -500.0
v_max: float = 500.0
@dataclass
class IsaacReposeCubeShadowDirectArgs(IsaacLabArgs):
env_name: str = "Isaac-Repose-Cube-Shadow-Direct-v0"
total_timesteps: int = 100000
v_min: float = -500.0
v_max: float = 500.0

602
fast_td3/train.py Normal file
View File

@ -0,0 +1,602 @@
import os
import sys
os.environ["TORCHDYNAMO_INLINE_INBUILT_NN_MODULES"] = "1"
os.environ["OMP_NUM_THREADS"] = "1"
if sys.platform != "darwin":
os.environ["MUJOCO_GL"] = "egl"
else:
os.environ["MUJOCO_GL"] = "glfw"
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
os.environ["JAX_DEFAULT_MATMUL_PRECISION"] = "highest"
import random
import time
import tqdm
import wandb
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.amp import autocast, GradScaler
from tensordict import TensorDict, from_module
from fast_td3_utils import EmpiricalNormalization, SimpleReplayBuffer, save_params
from hyperparams import get_args
from fast_td3 import Actor, Critic
torch.set_float32_matmul_precision("high")
try:
import jax.numpy as jnp
except ImportError:
pass
def main():
args = get_args()
print(args)
run_name = f"{args.env_name}__{args.exp_name}__{args.seed}"
amp_enabled = args.amp and args.cuda and torch.cuda.is_available()
amp_device_type = (
"cuda"
if args.cuda and torch.cuda.is_available()
else "mps" if args.cuda and torch.backends.mps.is_available() else "cpu"
)
amp_dtype = torch.bfloat16 if args.amp_dtype == "bf16" else torch.float16
scaler = GradScaler(enabled=amp_enabled and amp_dtype == torch.float16)
if args.use_wandb:
wandb.init(
project=args.project,
name=run_name,
config=vars(args),
save_code=True,
)
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.backends.cudnn.deterministic = args.torch_deterministic
if not args.cuda:
device = torch.device("cpu")
else:
if torch.cuda.is_available():
device = torch.device(f"cuda:{args.device_rank}")
elif torch.backends.mps.is_available():
device = torch.device(f"mps:{args.device_rank}")
else:
raise ValueError("No GPU available")
print(f"Using device: {device}")
if args.env_name.startswith("h1hand-") or args.env_name.startswith("h1-"):
from environments.humanoid_bench_env import HumanoidBenchEnv
env_type = "humanoid_bench"
envs = HumanoidBenchEnv(args.env_name, args.num_envs, device=device)
eval_envs = envs
render_env = HumanoidBenchEnv(
args.env_name, 1, render_mode="rgb_array", device=device
)
elif args.env_name.startswith("Isaac-"):
from environments.isaaclab_env import IsaacLabEnv
env_type = "isaaclab"
envs = IsaacLabEnv(
args.env_name,
device.type,
args.num_envs,
args.seed,
action_bounds=args.action_bounds,
)
eval_envs = envs
render_env = envs
else:
from environments.mujoco_playground_env import make_env
# TODO: Check if re-using same envs for eval could reduce memory usage
env_type = "mujoco_playground"
envs, eval_envs, render_env = make_env(
args.env_name,
args.seed,
args.num_envs,
args.num_eval_envs,
args.device_rank,
use_tuned_reward=args.use_tuned_reward,
use_domain_randomization=args.use_domain_randomization,
use_push_randomization=args.use_push_randomization,
)
n_act = envs.num_actions
n_obs = envs.num_obs if type(envs.num_obs) == int else envs.num_obs[0]
if envs.asymmetric_obs:
n_critic_obs = (
envs.num_privileged_obs
if type(envs.num_privileged_obs) == int
else envs.num_privileged_obs[0]
)
else:
n_critic_obs = n_obs
action_low, action_high = -1.0, 1.0
if args.obs_normalization:
obs_normalizer = EmpiricalNormalization(shape=n_obs, device=device)
critic_obs_normalizer = EmpiricalNormalization(
shape=n_critic_obs, device=device
)
else:
obs_normalizer = nn.Identity()
critic_obs_normalizer = nn.Identity()
actor = Actor(
n_obs=n_obs,
n_act=n_act,
num_envs=args.num_envs,
device=device,
init_scale=args.init_scale,
hidden_dim=args.actor_hidden_dim,
)
actor_detach = Actor(
n_obs=n_obs,
n_act=n_act,
num_envs=args.num_envs,
device=device,
init_scale=args.init_scale,
hidden_dim=args.actor_hidden_dim,
)
# Copy params to actor_detach without grad
from_module(actor).data.to_module(actor_detach)
policy = actor_detach.explore
qnet = Critic(
n_obs=n_critic_obs,
n_act=n_act,
num_atoms=args.num_atoms,
v_min=args.v_min,
v_max=args.v_max,
hidden_dim=args.critic_hidden_dim,
device=device,
)
qnet_target = Critic(
n_obs=n_critic_obs,
n_act=n_act,
num_atoms=args.num_atoms,
v_min=args.v_min,
v_max=args.v_max,
hidden_dim=args.critic_hidden_dim,
device=device,
)
qnet_target.load_state_dict(qnet.state_dict())
q_optimizer = optim.AdamW(
list(qnet.parameters()),
lr=args.critic_learning_rate,
weight_decay=args.weight_decay,
)
actor_optimizer = optim.AdamW(
list(actor.parameters()),
lr=args.actor_learning_rate,
weight_decay=args.weight_decay,
)
rb = SimpleReplayBuffer(
n_env=args.num_envs,
buffer_size=args.buffer_size,
n_obs=n_obs,
n_act=n_act,
n_critic_obs=n_critic_obs,
asymmetric_obs=envs.asymmetric_obs,
n_steps=args.num_steps,
gamma=args.gamma,
device=device,
)
policy_noise = args.policy_noise
noise_clip = args.noise_clip
def evaluate():
obs_normalizer.eval()
num_eval_envs = eval_envs.num_envs
episode_returns = torch.zeros(num_eval_envs, device=device)
episode_lengths = torch.zeros(num_eval_envs, device=device)
done_masks = torch.zeros(num_eval_envs, dtype=torch.bool, device=device)
if env_type == "isaaclab":
obs = eval_envs.reset(random_start_init=False)
else:
obs = eval_envs.reset()
# Run for a fixed number of steps
for _ in range(eval_envs.max_episode_steps):
with torch.no_grad(), autocast(
device_type=amp_device_type, dtype=amp_dtype, enabled=amp_enabled
):
obs = normalize_obs(obs)
actions = actor(obs)
next_obs, rewards, dones, _ = eval_envs.step(actions.float())
episode_returns = torch.where(
~done_masks, episode_returns + rewards, episode_returns
)
episode_lengths = torch.where(
~done_masks, episode_lengths + 1, episode_lengths
)
done_masks = torch.logical_or(done_masks, dones)
if done_masks.all():
break
obs = next_obs
obs_normalizer.train()
return episode_returns.mean().item(), episode_lengths.mean().item()
def render_with_rollout():
obs_normalizer.eval()
# Quick rollout for rendering
if env_type == "humanoid_bench":
obs = render_env.reset()
renders = [render_env.render()]
elif env_type == "isaaclab":
raise NotImplementedError(
"We don't support rendering for IsaacLab environments"
)
else:
obs = render_env.reset()
render_env.state.info["command"] = jnp.array([[1.0, 0.0, 0.0]])
renders = [render_env.state]
for i in range(render_env.max_episode_steps):
with torch.no_grad(), autocast(
device_type=amp_device_type, dtype=amp_dtype, enabled=amp_enabled
):
obs = normalize_obs(obs)
actions = actor(obs)
next_obs, _, done, _ = render_env.step(actions.float())
if env_type == "mujoco_playground":
render_env.state.info["command"] = jnp.array([[1.0, 0.0, 0.0]])
if i % 2 == 0:
if env_type == "humanoid_bench":
renders.append(render_env.render())
else:
renders.append(render_env.state)
if done.any():
break
obs = next_obs
if env_type == "mujoco_playground":
renders = render_env.render_trajectory(renders)
obs_normalizer.train()
return renders
def update_main(data, logs_dict):
with autocast(
device_type=amp_device_type, dtype=amp_dtype, enabled=amp_enabled
):
observations = data["observations"]
next_observations = data["next"]["observations"]
if envs.asymmetric_obs:
critic_observations = data["critic_observations"]
next_critic_observations = data["next"]["critic_observations"]
else:
critic_observations = observations
next_critic_observations = next_observations
actions = data["actions"]
rewards = data["next"]["rewards"]
dones = data["next"]["dones"].bool()
truncations = data["next"]["truncations"].bool()
if args.disable_bootstrap:
bootstrap = (~dones).float()
else:
bootstrap = (truncations | ~dones).float()
clipped_noise = torch.randn_like(actions)
clipped_noise = clipped_noise.mul(policy_noise).clamp(
-noise_clip, noise_clip
)
next_state_actions = (actor(next_observations) + clipped_noise).clamp(
action_low, action_high
)
with torch.no_grad():
qf1_next_target_projected, qf2_next_target_projected = (
qnet_target.projection(
next_critic_observations,
next_state_actions,
rewards,
bootstrap,
args.gamma,
)
)
qf1_next_target_value = qnet_target.get_value(qf1_next_target_projected)
qf2_next_target_value = qnet_target.get_value(qf2_next_target_projected)
if args.use_cdq:
qf_next_target_dist = torch.where(
qf1_next_target_value.unsqueeze(1)
< qf2_next_target_value.unsqueeze(1),
qf1_next_target_projected,
qf2_next_target_projected,
)
qf1_next_target_dist = qf2_next_target_dist = qf_next_target_dist
else:
qf1_next_target_dist, qf2_next_target_dist = (
qf1_next_target_projected,
qf2_next_target_projected,
)
qf1, qf2 = qnet(critic_observations, actions)
qf1_loss = -torch.sum(
qf1_next_target_dist * F.log_softmax(qf1, dim=1), dim=1
).mean()
qf2_loss = -torch.sum(
qf2_next_target_dist * F.log_softmax(qf2, dim=1), dim=1
).mean()
qf_loss = qf1_loss + qf2_loss
q_optimizer.zero_grad(set_to_none=True)
scaler.scale(qf_loss).backward()
scaler.unscale_(q_optimizer)
critic_grad_norm = torch.nn.utils.clip_grad_norm_(
qnet.parameters(),
max_norm=args.max_grad_norm if args.max_grad_norm > 0 else float("inf"),
)
scaler.step(q_optimizer)
scaler.update()
logs_dict["buffer_rewards"] = rewards.mean()
logs_dict["critic_grad_norm"] = critic_grad_norm.detach()
logs_dict["qf_loss"] = qf_loss.detach()
logs_dict["qf_max"] = qf1_next_target_value.max().detach()
logs_dict["qf_min"] = qf1_next_target_value.min().detach()
return logs_dict
def update_pol(data, logs_dict):
with autocast(
device_type=amp_device_type, dtype=amp_dtype, enabled=amp_enabled
):
critic_observations = (
data["critic_observations"]
if envs.asymmetric_obs
else data["observations"]
)
qf1, qf2 = qnet(critic_observations, actor(data["observations"]))
qf1_value = qnet.get_value(F.softmax(qf1, dim=1))
qf2_value = qnet.get_value(F.softmax(qf2, dim=1))
if args.use_cdq:
qf_value = torch.minimum(qf1_value, qf2_value)
else:
qf_value = (qf1_value + qf2_value) / 2.0
actor_loss = -qf_value.mean()
actor_optimizer.zero_grad(set_to_none=True)
scaler.scale(actor_loss).backward()
scaler.unscale_(actor_optimizer)
actor_grad_norm = torch.nn.utils.clip_grad_norm_(
actor.parameters(),
max_norm=args.max_grad_norm if args.max_grad_norm > 0 else float("inf"),
)
scaler.step(actor_optimizer)
scaler.update()
logs_dict["actor_grad_norm"] = actor_grad_norm.detach()
logs_dict["actor_loss"] = actor_loss.detach()
return logs_dict
if args.compile:
mode = None
update_main = torch.compile(update_main, mode=mode)
update_pol = torch.compile(update_pol, mode=mode)
policy = torch.compile(policy, mode=mode)
normalize_obs = torch.compile(obs_normalizer.forward, mode=mode)
normalize_critic_obs = torch.compile(critic_obs_normalizer.forward, mode=mode)
else:
normalize_obs = obs_normalizer.forward
normalize_critic_obs = critic_obs_normalizer.forward
if envs.asymmetric_obs:
obs, critic_obs = envs.reset_with_critic_obs()
critic_obs = torch.as_tensor(critic_obs, device=device, dtype=torch.float)
else:
obs = envs.reset()
if args.checkpoint_path:
# Load checkpoint if specified
torch_checkpoint = torch.load(
f"{args.checkpoint_path}", map_location=device, weights_only=False
)
actor.load_state_dict(torch_checkpoint["actor_state_dict"])
obs_normalizer.load_state_dict(torch_checkpoint["obs_normalizer_state"])
critic_obs_normalizer.load_state_dict(
torch_checkpoint["critic_obs_normalizer_state"]
)
qnet.load_state_dict(torch_checkpoint["qnet_state_dict"])
qnet_target.load_state_dict(torch_checkpoint["qnet_target_state_dict"])
global_step = torch_checkpoint["global_step"]
else:
global_step = 0
dones = None
pbar = tqdm.tqdm(total=args.total_timesteps, initial=global_step)
start_time = None
desc = ""
while global_step < args.total_timesteps:
logs_dict = TensorDict()
if (
start_time is None
and global_step >= args.measure_burnin + args.learning_starts
):
start_time = time.time()
measure_burnin = global_step
with torch.no_grad(), autocast(
device_type=amp_device_type, dtype=amp_dtype, enabled=amp_enabled
):
norm_obs = normalize_obs(obs)
actions = policy(obs=norm_obs, dones=dones)
next_obs, rewards, dones, infos = envs.step(actions.float())
truncations = infos["time_outs"]
if envs.asymmetric_obs:
next_critic_obs = infos["observations"]["critic"]
# Compute 'true' next_obs and next_critic_obs for saving
true_next_obs = torch.where(
dones[:, None] > 0, infos["observations"]["raw"]["obs"], next_obs
)
if envs.asymmetric_obs:
true_next_critic_obs = torch.where(
dones[:, None] > 0,
infos["observations"]["raw"]["critic_obs"],
next_critic_obs,
)
transition = TensorDict(
{
"observations": obs,
"actions": torch.as_tensor(actions, device=device, dtype=torch.float),
"next": {
"observations": true_next_obs,
"rewards": torch.as_tensor(
rewards, device=device, dtype=torch.float
),
"truncations": truncations.long(),
"dones": dones.long(),
},
},
batch_size=(envs.num_envs,),
device=device,
)
if envs.asymmetric_obs:
transition["critic_observations"] = critic_obs
transition["next"]["critic_observations"] = true_next_critic_obs
obs = next_obs
if envs.asymmetric_obs:
critic_obs = next_critic_obs
rb.extend(transition)
batch_size = args.batch_size // args.num_envs
if global_step > args.learning_starts:
for i in range(args.num_updates):
data = rb.sample(batch_size)
data["observations"] = normalize_obs(data["observations"])
data["next"]["observations"] = normalize_obs(
data["next"]["observations"]
)
if envs.asymmetric_obs:
data["critic_observations"] = normalize_critic_obs(
data["critic_observations"]
)
data["next"]["critic_observations"] = normalize_critic_obs(
data["next"]["critic_observations"]
)
logs_dict = update_main(data, logs_dict)
if args.num_updates > 1:
if i % args.policy_frequency == 1:
logs_dict = update_pol(data, logs_dict)
else:
if global_step % args.policy_frequency == 0:
logs_dict = update_pol(data, logs_dict)
for param, target_param in zip(
qnet.parameters(), qnet_target.parameters()
):
target_param.data.copy_(
args.tau * param.data + (1 - args.tau) * target_param.data
)
if global_step % 100 == 0 and start_time is not None:
speed = (global_step - measure_burnin) / (time.time() - start_time)
pbar.set_description(f"{speed: 4.4f} sps, " + desc)
with torch.no_grad():
logs = {
"actor_loss": logs_dict["actor_loss"].mean(),
"qf_loss": logs_dict["qf_loss"].mean(),
"qf_max": logs_dict["qf_max"].mean(),
"qf_min": logs_dict["qf_min"].mean(),
"actor_grad_norm": logs_dict["actor_grad_norm"].mean(),
"critic_grad_norm": logs_dict["critic_grad_norm"].mean(),
"buffer_rewards": logs_dict["buffer_rewards"].mean(),
"env_rewards": rewards.mean(),
}
if args.eval_interval > 0 and global_step % args.eval_interval == 0:
print(f"Evaluating at global step {global_step}")
eval_avg_return, eval_avg_length = evaluate()
if env_type in ["humanoid_bench", "isaaclab"]:
# NOTE: Hacky way of evaluating performance, but just works
obs = envs.reset()
logs["eval_avg_return"] = eval_avg_return
logs["eval_avg_length"] = eval_avg_length
if (
args.render_interval > 0
and global_step % args.render_interval == 0
):
renders = render_with_rollout()
if args.use_wandb:
wandb.log(
{
"render_video": wandb.Video(
np.array(renders).transpose(
0, 3, 1, 2
), # Convert to (T, C, H, W) format
fps=30,
format="gif",
)
},
step=global_step,
)
if args.use_wandb:
wandb.log(
{
"speed": speed,
"frame": global_step * args.num_envs,
**logs,
},
step=global_step,
)
if (
args.save_interval > 0
and global_step > 0
and global_step % args.save_interval == 0
):
print(f"Saving model at global step {global_step}")
save_params(
global_step,
actor,
qnet,
qnet_target,
obs_normalizer,
critic_obs_normalizer,
args,
f"models/{run_name}_{global_step}.pt",
)
global_step += 1
pbar.update(1)
save_params(
global_step,
actor,
qnet,
qnet_target,
obs_normalizer,
critic_obs_normalizer,
args,
f"models/{run_name}_final.pt",
)
if __name__ == "__main__":
main()

View File

@ -0,0 +1,785 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# FastTD3 Training Notebook\n",
"\n",
"Welcome! This notebook will let you execute a series of code blocks that enables you to experience how FastTD3 works -- each block will import packages, define arguments, create environments, create FastTD3 agent, and train the agent.\n",
"\n",
"This notebook also provide the same functionalities as `train.py` -- you can use this notebook to train your own agents, upload logs to wandb, render rollouts, and fine-tune pre-trained agents with more environment steps!"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# NOTE: Set environment variables and import packages\n",
"\n",
"import os\n",
"\n",
"os.environ[\"TORCHDYNAMO_INLINE_INBUILT_NN_MODULES\"] = \"1\"\n",
"os.environ[\"OMP_NUM_THREADS\"] = \"1\"\n",
"if sys.platform != \"darwin\":\n",
" os.environ[\"MUJOCO_GL\"] = \"egl\"\n",
"else:\n",
" os.environ[\"MUJOCO_GL\"] = \"glfw\"\n",
"os.environ[\"XLA_PYTHON_CLIENT_PREALLOCATE\"] = \"false\"\n",
"os.environ[\"JAX_DEFAULT_MATMUL_PRECISION\"] = \"highest\"\n",
"\n",
"import random\n",
"import time\n",
"\n",
"import tqdm\n",
"import wandb\n",
"import numpy as np\n",
"\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"import torch.optim as optim\n",
"from torch.amp import autocast, GradScaler\n",
"from tensordict import TensorDict, from_module\n",
"\n",
"torch.set_float32_matmul_precision(\"high\")\n",
"\n",
"from fast_td3_utils import (\n",
" EmpiricalNormalization,\n",
" SimpleReplayBuffer,\n",
" save_params,\n",
")\n",
"\n",
"from fast_td3 import Critic, Actor"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# NOTE: Set checkpoint if you want to fine-tune from existing checkpoint\n",
"# e.g., set checkpoint to \"models/h1-walk-v0_notebook_experiment_30000.pt\"\n",
"checkpoint_path = None"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# NOTE: Customize arguments as needed\n",
"# However, IsaacLab may not work in Notebook Setup.\n",
"# We recommend using HumanoidBench or MuJoCo Playground for notebook experiments.\n",
"\n",
"# For quick experiments, let's use a task without dexterous hands\n",
"# But for your research, we recommend using `h1hand` tasks in HumanoidBench!\n",
"from hyperparams import HumanoidBenchArgs\n",
"\n",
"args = HumanoidBenchArgs(\n",
" env_name=\"h1-walk-v0\",\n",
" total_timesteps=20000,\n",
" render_interval=5000,\n",
" eval_interval=5000,\n",
")\n",
"run_name = f\"{args.env_name}_notebook_experiment\""
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# NOTE: GPU-Related Configurations\n",
"\n",
"amp_enabled = args.amp and args.cuda and torch.cuda.is_available()\n",
"amp_device_type = (\n",
" \"cuda\"\n",
" if args.cuda and torch.cuda.is_available()\n",
" else \"mps\" if args.cuda and torch.backends.mps.is_available() else \"cpu\"\n",
")\n",
"amp_dtype = torch.bfloat16 if args.amp_dtype == \"bf16\" else torch.float16\n",
"\n",
"scaler = GradScaler(enabled=amp_enabled and amp_dtype == torch.float16)\n",
"\n",
"if not args.cuda:\n",
" device = torch.device(\"cpu\")\n",
"else:\n",
" if torch.cuda.is_available():\n",
" device = torch.device(f\"cuda:{args.device_rank}\")\n",
" elif torch.backends.mps.is_available():\n",
" device = torch.device(f\"mps:{args.device_rank}\")\n",
" else:\n",
" raise ValueError(\"No GPU available\")\n",
"print(f\"Using device: {device}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# NOTE: Define Wandb if needed\n",
"\n",
"# Set use_wandb to True if you want to use Wandb\n",
"use_wandb = True\n",
"\n",
"if use_wandb:\n",
" wandb.init(\n",
" project=\"FastTD3\",\n",
" name=run_name,\n",
" config=vars(args),\n",
" save_code=True,\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# NOTE: Initialize Environment and Related Variables\n",
"\n",
"if args.env_name.startswith(\"h1hand-\") or args.env_name.startswith(\"h1-\"):\n",
" from environments.humanoid_bench_env import HumanoidBenchEnv\n",
"\n",
" env_type = \"humanoid_bench\"\n",
" envs = HumanoidBenchEnv(args.env_name, args.num_envs, device=device)\n",
" eval_envs = envs\n",
" render_env = HumanoidBenchEnv(args.env_name, 1, render_mode=\"rgb_array\", device=device)\n",
"elif args.env_name.startswith(\"Isaac-\"):\n",
" from environments.isaaclab_env import IsaacLabEnv\n",
"\n",
" env_type = \"isaaclab\"\n",
" envs = IsaacLabEnv(\n",
" args.env_name,\n",
" device.type,\n",
" args.num_envs,\n",
" args.seed,\n",
" action_bounds=args.action_bounds,\n",
" )\n",
" eval_envs = envs\n",
" render_envs = envs\n",
"else:\n",
" from environments.mujoco_playground_env import make_env\n",
" import jax.numpy as jnp\n",
"\n",
" env_type = \"mujoco_playground\"\n",
" envs, eval_envs, render_env = make_env(\n",
" args.env_name,\n",
" args.seed,\n",
" args.num_envs,\n",
" args.num_eval_envs,\n",
" args.device_rank,\n",
" use_tuned_reward=args.use_tuned_reward,\n",
" use_domain_randomization=args.use_domain_randomization,\n",
" )\n",
"\n",
"n_act = envs.num_actions\n",
"n_obs = envs.num_obs if type(envs.num_obs) == int else envs.num_obs[0]\n",
"if envs.asymmetric_obs:\n",
" n_critic_obs = (\n",
" envs.num_privileged_obs\n",
" if type(envs.num_privileged_obs) == int\n",
" else envs.num_privileged_obs[0]\n",
" )\n",
"else:\n",
" n_critic_obs = n_obs\n",
"action_low, action_high = -1.0, 1.0"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# NOTE: Initialize Normalizer, Actor, and Critic\n",
"\n",
"if args.obs_normalization:\n",
" obs_normalizer = EmpiricalNormalization(shape=n_obs, device=device)\n",
" critic_obs_normalizer = EmpiricalNormalization(shape=n_critic_obs, device=device)\n",
"else:\n",
" obs_normalizer = nn.Identity()\n",
" critic_obs_normalizer = nn.Identity()\n",
"\n",
"normalize_obs = obs_normalizer.forward\n",
"normalize_critic_obs = critic_obs_normalizer.forward\n",
"\n",
"# Actor setup\n",
"actor = Actor(\n",
" n_obs=n_obs,\n",
" n_act=n_act,\n",
" num_envs=args.num_envs,\n",
" device=device,\n",
" init_scale=args.init_scale,\n",
" hidden_dim=args.actor_hidden_dim,\n",
")\n",
"actor_detach = Actor(\n",
" n_obs=n_obs,\n",
" n_act=n_act,\n",
" num_envs=args.num_envs,\n",
" device=device,\n",
" init_scale=args.init_scale,\n",
" hidden_dim=args.actor_hidden_dim,\n",
")\n",
"# Copy params to actor_detach without grad\n",
"from_module(actor).data.to_module(actor_detach)\n",
"policy = actor_detach.explore\n",
"\n",
"qnet = Critic(\n",
" n_obs=n_critic_obs,\n",
" n_act=n_act,\n",
" num_atoms=args.num_atoms,\n",
" v_min=args.v_min,\n",
" v_max=args.v_max,\n",
" hidden_dim=args.critic_hidden_dim,\n",
" device=device,\n",
")\n",
"qnet_target = Critic(\n",
" n_obs=n_critic_obs,\n",
" n_act=n_act,\n",
" num_atoms=args.num_atoms,\n",
" v_min=args.v_min,\n",
" v_max=args.v_max,\n",
" hidden_dim=args.critic_hidden_dim,\n",
" device=device,\n",
")\n",
"qnet_target.load_state_dict(qnet.state_dict())\n",
"\n",
"q_optimizer = optim.AdamW(\n",
" list(qnet.parameters()),\n",
" lr=args.critic_learning_rate,\n",
" weight_decay=args.weight_decay,\n",
")\n",
"actor_optimizer = optim.AdamW(\n",
" list(actor.parameters()),\n",
" lr=args.actor_learning_rate,\n",
" weight_decay=args.weight_decay,\n",
")\n",
"\n",
"rb = SimpleReplayBuffer(\n",
" n_env=args.num_envs,\n",
" buffer_size=args.buffer_size,\n",
" n_obs=n_obs,\n",
" n_act=n_act,\n",
" n_critic_obs=n_critic_obs,\n",
" asymmetric_obs=envs.asymmetric_obs,\n",
" n_steps=args.num_steps,\n",
" gamma=args.gamma,\n",
" device=device,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"# NOTE: Define Evaluation & Rendering Functions\n",
"\n",
"\n",
"def evaluate():\n",
" obs_normalizer.eval()\n",
" num_eval_envs = eval_envs.num_envs\n",
" episode_returns = torch.zeros(num_eval_envs, device=device)\n",
" episode_lengths = torch.zeros(num_eval_envs, device=device)\n",
" done_masks = torch.zeros(num_eval_envs, dtype=torch.bool, device=device)\n",
"\n",
" if env_type == \"isaaclab\":\n",
" obs = eval_envs.reset(random_start_init=False)\n",
" else:\n",
" obs = eval_envs.reset()\n",
"\n",
" # Run for a fixed number of steps\n",
" for _ in range(eval_envs.max_episode_steps):\n",
" with torch.no_grad(), autocast(\n",
" device_type=amp_device_type, dtype=amp_dtype, enabled=amp_enabled\n",
" ):\n",
" obs = normalize_obs(obs)\n",
" actions = actor(obs)\n",
"\n",
" next_obs, rewards, dones, _ = eval_envs.step(actions.float())\n",
" episode_returns = torch.where(\n",
" ~done_masks, episode_returns + rewards, episode_returns\n",
" )\n",
" episode_lengths = torch.where(~done_masks, episode_lengths + 1, episode_lengths)\n",
" done_masks = torch.logical_or(done_masks, dones)\n",
" if done_masks.all():\n",
" break\n",
" obs = next_obs\n",
"\n",
" obs_normalizer.train()\n",
" return episode_returns.mean().item(), episode_lengths.mean().item()\n",
"\n",
"\n",
"def render_with_rollout():\n",
" obs_normalizer.eval()\n",
"\n",
" # Quick rollout for rendering\n",
" if env_type == \"humanoid_bench\":\n",
" obs = render_env.reset()\n",
" renders = [render_env.render()]\n",
" elif env_type == \"isaaclab\":\n",
" raise NotImplementedError(\n",
" \"We don't support rendering for IsaacLab environments\"\n",
" )\n",
" else:\n",
" obs = render_env.reset()\n",
" render_env.state.info[\"command\"] = jnp.array([[1.0, 0.0, 0.0]])\n",
" renders = [render_env.state]\n",
" for i in range(render_env.max_episode_steps):\n",
" with torch.no_grad(), autocast(\n",
" device_type=amp_device_type, dtype=amp_dtype, enabled=amp_enabled\n",
" ):\n",
" obs = normalize_obs(obs)\n",
" actions = actor(obs)\n",
" next_obs, _, done, _ = render_env.step(actions.float())\n",
" if env_type == \"mujoco_playground\":\n",
" render_env.state.info[\"command\"] = jnp.array([[1.0, 0.0, 0.0]])\n",
" if i % 2 == 0:\n",
" if env_type == \"humanoid_bench\":\n",
" renders.append(render_env.render())\n",
" else:\n",
" renders.append(render_env.state)\n",
" if done.any():\n",
" break\n",
" obs = next_obs\n",
"\n",
" if env_type == \"mujoco_playground\":\n",
" renders = render_env.render_trajectory(renders)\n",
"\n",
" obs_normalizer.train()\n",
" return renders"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"# NOTE: Define Update Functions\n",
"\n",
"policy_noise = args.policy_noise\n",
"noise_clip = args.noise_clip\n",
"\n",
"\n",
"def update_main(data, logs_dict):\n",
" with autocast(device_type=amp_device_type, dtype=amp_dtype, enabled=amp_enabled):\n",
" observations = data[\"observations\"]\n",
" next_observations = data[\"next\"][\"observations\"]\n",
" if envs.asymmetric_obs:\n",
" critic_observations = data[\"critic_observations\"]\n",
" next_critic_observations = data[\"next\"][\"critic_observations\"]\n",
" else:\n",
" critic_observations = observations\n",
" next_critic_observations = next_observations\n",
" actions = data[\"actions\"]\n",
" rewards = data[\"next\"][\"rewards\"]\n",
" dones = data[\"next\"][\"dones\"].bool()\n",
" truncations = data[\"next\"][\"truncations\"].bool()\n",
" if args.disable_bootstrap:\n",
" bootstrap = (~dones).float()\n",
" else:\n",
" bootstrap = (truncations | ~dones).float()\n",
"\n",
" clipped_noise = torch.randn_like(actions)\n",
" clipped_noise = clipped_noise.mul(policy_noise).clamp(-noise_clip, noise_clip)\n",
"\n",
" next_state_actions = (actor(next_observations) + clipped_noise).clamp(\n",
" action_low, action_high\n",
" )\n",
"\n",
" with torch.no_grad():\n",
" qf1_next_target_projected, qf2_next_target_projected = (\n",
" qnet_target.projection(\n",
" next_critic_observations,\n",
" next_state_actions,\n",
" rewards,\n",
" bootstrap,\n",
" args.gamma,\n",
" )\n",
" )\n",
" qf1_next_target_value = qnet_target.get_value(qf1_next_target_projected)\n",
" qf2_next_target_value = qnet_target.get_value(qf2_next_target_projected)\n",
" if args.use_cdq:\n",
" qf_next_target_dist = torch.where(\n",
" qf1_next_target_value.unsqueeze(1)\n",
" < qf2_next_target_value.unsqueeze(1),\n",
" qf1_next_target_projected,\n",
" qf2_next_target_projected,\n",
" )\n",
" qf1_next_target_dist = qf2_next_target_dist = qf_next_target_dist\n",
" else:\n",
" qf1_next_target_dist, qf2_next_target_dist = (\n",
" qf1_next_target_projected,\n",
" qf2_next_target_projected,\n",
" )\n",
"\n",
" qf1, qf2 = qnet(critic_observations, actions)\n",
" qf1_loss = -torch.sum(\n",
" qf1_next_target_dist * F.log_softmax(qf1, dim=1), dim=1\n",
" ).mean()\n",
" qf2_loss = -torch.sum(\n",
" qf2_next_target_dist * F.log_softmax(qf2, dim=1), dim=1\n",
" ).mean()\n",
" qf_loss = qf1_loss + qf2_loss\n",
"\n",
" q_optimizer.zero_grad(set_to_none=True)\n",
" scaler.scale(qf_loss).backward()\n",
" scaler.unscale_(q_optimizer)\n",
"\n",
" critic_grad_norm = torch.nn.utils.clip_grad_norm_(\n",
" qnet.parameters(),\n",
" max_norm=args.max_grad_norm if args.max_grad_norm > 0 else float(\"inf\"),\n",
" )\n",
" scaler.step(q_optimizer)\n",
" scaler.update()\n",
"\n",
" logs_dict[\"buffer_rewards\"] = rewards.mean()\n",
" logs_dict[\"critic_grad_norm\"] = critic_grad_norm.detach()\n",
" logs_dict[\"qf_loss\"] = qf_loss.detach()\n",
" logs_dict[\"qf_max\"] = qf1_next_target_value.max().detach()\n",
" logs_dict[\"qf_min\"] = qf1_next_target_value.min().detach()\n",
" return logs_dict\n",
"\n",
"\n",
"def update_pol(data, logs_dict):\n",
" with autocast(device_type=amp_device_type, dtype=amp_dtype, enabled=amp_enabled):\n",
" critic_observations = (\n",
" data[\"critic_observations\"] if envs.asymmetric_obs else data[\"observations\"]\n",
" )\n",
"\n",
" qf1, qf2 = qnet(critic_observations, actor(data[\"observations\"]))\n",
" qf1_value = qnet.get_value(F.softmax(qf1, dim=1))\n",
" qf2_value = qnet.get_value(F.softmax(qf2, dim=1))\n",
" if args.use_cdq:\n",
" qf_value = torch.minimum(qf1_value, qf2_value)\n",
" else:\n",
" qf_value = (qf1_value + qf2_value) / 2.0\n",
" actor_loss = -qf_value.mean()\n",
"\n",
" actor_optimizer.zero_grad(set_to_none=True)\n",
" scaler.scale(actor_loss).backward()\n",
" scaler.unscale_(actor_optimizer)\n",
" actor_grad_norm = torch.nn.utils.clip_grad_norm_(\n",
" actor.parameters(),\n",
" max_norm=args.max_grad_norm if args.max_grad_norm > 0 else float(\"inf\"),\n",
" )\n",
" scaler.step(actor_optimizer)\n",
" scaler.update()\n",
" logs_dict[\"actor_grad_norm\"] = actor_grad_norm.detach()\n",
" logs_dict[\"actor_loss\"] = actor_loss.detach()\n",
" return logs_dict"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"# NOTE: Compile Functions if Needed\n",
"\n",
"if args.compile:\n",
" mode = None\n",
" update_main = torch.compile(update_main, mode=mode)\n",
" update_pol = torch.compile(update_pol, mode=mode)\n",
" policy = torch.compile(policy, mode=mode)\n",
" normalize_obs = torch.compile(normalize_obs, mode=mode)\n",
" normalize_critic_obs = torch.compile(normalize_critic_obs, mode=mode)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"# NOTE: Load Checkpoint if Needed\n",
"if checkpoint_path is not None:\n",
" torch_checkpoint = torch.load(\n",
" f\"{checkpoint_path}\", map_location=device, weights_only=False\n",
" )\n",
"\n",
" actor.load_state_dict(torch_checkpoint[\"actor_state_dict\"])\n",
" obs_normalizer.load_state_dict(torch_checkpoint[\"obs_normalizer_state\"])\n",
" critic_obs_normalizer.load_state_dict(\n",
" torch_checkpoint[\"critic_obs_normalizer_state\"]\n",
" )\n",
" qnet.load_state_dict(torch_checkpoint[\"qnet_state_dict\"])\n",
" qnet_target.load_state_dict(torch_checkpoint[\"qnet_target_state_dict\"])\n",
" global_step = torch_checkpoint[\"global_step\"]\n",
"else:\n",
" global_step = 0"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"# NOTE: Utility functions for displaying videos in notebook\n",
"\n",
"from IPython.display import display, HTML\n",
"import base64\n",
"import imageio\n",
"import tempfile\n",
"import os\n",
"\n",
"\n",
"def frames_to_video_html(frames, fps=30):\n",
" \"\"\"\n",
" Convert a list of numpy arrays to an HTML5 video element.\n",
"\n",
" Args:\n",
" frames (list): List of numpy arrays representing video frames\n",
" fps (int): Frames per second for the video\n",
"\n",
" Returns:\n",
" HTML object containing the video element\n",
" \"\"\"\n",
" # Create a temporary file to store the video\n",
" with tempfile.NamedTemporaryFile(suffix=\".mp4\", delete=False) as temp_file:\n",
" temp_filename = temp_file.name\n",
"\n",
" # Save frames as video\n",
" imageio.mimsave(temp_filename, frames, fps=fps)\n",
"\n",
" # Read the video file and encode it to base64\n",
" with open(temp_filename, \"rb\") as f:\n",
" video_data = f.read()\n",
" video_b64 = base64.b64encode(video_data).decode(\"utf-8\")\n",
"\n",
" # Create HTML video element\n",
" video_html = f\"\"\"\n",
" <video width=\"640\" height=\"480\" controls>\n",
" <source src=\"data:video/mp4;base64,{video_b64}\" type=\"video/mp4\">\n",
" Your browser does not support the video tag.\n",
" </video>\n",
" \"\"\"\n",
"\n",
" # Clean up the temporary file\n",
" os.unlink(temp_filename)\n",
"\n",
" return HTML(video_html)\n",
"\n",
"\n",
"def update_video_display(frames, fps=30):\n",
" \"\"\"\n",
" Display video frames as an embedded HTML5 video element.\n",
"\n",
" Args:\n",
" frames (list): List of numpy arrays representing video frames\n",
" fps (int): Frames per second for the video\n",
" \"\"\"\n",
" video_html = frames_to_video_html(frames, fps=fps)\n",
" display(video_html)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# NOTE: Main Training Loop\n",
"\n",
"if envs.asymmetric_obs:\n",
" obs, critic_obs = envs.reset_with_critic_obs()\n",
" critic_obs = torch.as_tensor(critic_obs, device=device, dtype=torch.float)\n",
"else:\n",
" obs = envs.reset()\n",
"pbar = tqdm.tqdm(total=args.total_timesteps, initial=global_step)\n",
"\n",
"dones = None\n",
"while global_step < args.total_timesteps:\n",
" logs_dict = TensorDict()\n",
" with torch.no_grad(), autocast(\n",
" device_type=amp_device_type, dtype=amp_dtype, enabled=amp_enabled\n",
" ):\n",
" norm_obs = normalize_obs(obs)\n",
" actions = policy(obs=norm_obs, dones=dones)\n",
"\n",
" next_obs, rewards, dones, infos = envs.step(actions.float())\n",
" truncations = infos[\"time_outs\"]\n",
"\n",
" if envs.asymmetric_obs:\n",
" next_critic_obs = infos[\"observations\"][\"critic\"]\n",
"\n",
" # Compute 'true' next_obs and next_critic_obs for saving\n",
" true_next_obs = torch.where(\n",
" dones[:, None] > 0, infos[\"observations\"][\"raw\"][\"obs\"], next_obs\n",
" )\n",
" if envs.asymmetric_obs:\n",
" true_next_critic_obs = torch.where(\n",
" dones[:, None] > 0,\n",
" infos[\"observations\"][\"raw\"][\"critic_obs\"],\n",
" next_critic_obs,\n",
" )\n",
" transition = TensorDict(\n",
" {\n",
" \"observations\": obs,\n",
" \"actions\": torch.as_tensor(actions, device=device, dtype=torch.float),\n",
" \"next\": {\n",
" \"observations\": true_next_obs,\n",
" \"rewards\": torch.as_tensor(rewards, device=device, dtype=torch.float),\n",
" \"truncations\": truncations.long(),\n",
" \"dones\": dones.long(),\n",
" },\n",
" },\n",
" batch_size=(envs.num_envs,),\n",
" device=device,\n",
" )\n",
" if envs.asymmetric_obs:\n",
" transition[\"critic_observations\"] = critic_obs\n",
" transition[\"next\"][\"critic_observations\"] = true_next_critic_obs\n",
"\n",
" obs = next_obs\n",
" if envs.asymmetric_obs:\n",
" critic_obs = next_critic_obs\n",
"\n",
" rb.extend(transition)\n",
"\n",
" batch_size = args.batch_size // args.num_envs\n",
" if global_step > args.learning_starts:\n",
" for i in range(args.num_updates):\n",
" data = rb.sample(batch_size)\n",
" data[\"observations\"] = normalize_obs(data[\"observations\"])\n",
" data[\"next\"][\"observations\"] = normalize_obs(data[\"next\"][\"observations\"])\n",
" if envs.asymmetric_obs:\n",
" data[\"critic_observations\"] = normalize_critic_obs(\n",
" data[\"critic_observations\"]\n",
" )\n",
" data[\"next\"][\"critic_observations\"] = normalize_critic_obs(\n",
" data[\"next\"][\"critic_observations\"]\n",
" )\n",
" logs_dict = update_main(data, logs_dict)\n",
" if args.num_updates > 1:\n",
" if i % args.policy_frequency == 1:\n",
" logs_dict = update_pol(data, logs_dict)\n",
" else:\n",
" if global_step % args.policy_frequency == 0:\n",
" logs_dict = update_pol(data, logs_dict)\n",
"\n",
" for param, target_param in zip(qnet.parameters(), qnet_target.parameters()):\n",
" target_param.data.copy_(\n",
" args.tau * param.data + (1 - args.tau) * target_param.data\n",
" )\n",
"\n",
" if global_step > 0 and global_step % 100 == 0:\n",
" with torch.no_grad():\n",
" logs = {\n",
" \"actor_loss\": logs_dict[\"actor_loss\"].mean(),\n",
" \"qf_loss\": logs_dict[\"qf_loss\"].mean(),\n",
" \"qf_max\": logs_dict[\"qf_max\"].mean(),\n",
" \"qf_min\": logs_dict[\"qf_min\"].mean(),\n",
" \"actor_grad_norm\": logs_dict[\"actor_grad_norm\"].mean(),\n",
" \"critic_grad_norm\": logs_dict[\"critic_grad_norm\"].mean(),\n",
" \"buffer_rewards\": logs_dict[\"buffer_rewards\"].mean(),\n",
" \"env_rewards\": rewards.mean(),\n",
" }\n",
"\n",
" if args.eval_interval > 0 and global_step % args.eval_interval == 0:\n",
" eval_avg_return, eval_avg_length = evaluate()\n",
" if env_type in [\"humanoid_bench\", \"isaaclab\"]:\n",
" # NOTE: Hacky way of evaluating performance, but just works\n",
" obs = envs.reset()\n",
" logs[\"eval_avg_return\"] = eval_avg_return\n",
" logs[\"eval_avg_length\"] = eval_avg_length\n",
"\n",
" if args.render_interval > 0 and global_step % args.render_interval == 0:\n",
" renders = render_with_rollout()\n",
" print_logs = {\n",
" k: v.item() if isinstance(v, torch.Tensor) else v\n",
" for k, v in logs.items()\n",
" }\n",
" for k, v in print_logs.items():\n",
" print(f\"{k}: {v:.4f}\")\n",
" update_video_display(renders, fps=30)\n",
" if use_wandb:\n",
" wandb.log(\n",
" {\n",
" \"render_video\": wandb.Video(\n",
" np.array(renders).transpose(\n",
" 0, 3, 1, 2\n",
" ), # Convert to (T, C, H, W) format\n",
" fps=30,\n",
" format=\"gif\",\n",
" )\n",
" },\n",
" step=global_step,\n",
" )\n",
" if use_wandb:\n",
" wandb.log(\n",
" {\n",
" \"frame\": global_step * args.num_envs,\n",
" **logs,\n",
" },\n",
" step=global_step,\n",
" )\n",
"\n",
" if (\n",
" args.save_interval > 0\n",
" and global_step > 0\n",
" and global_step % args.save_interval == 0\n",
" ):\n",
" save_params(\n",
" global_step,\n",
" actor,\n",
" qnet,\n",
" qnet_target,\n",
" obs_normalizer,\n",
" critic_obs_normalizer,\n",
" args,\n",
" f\"models/{run_name}_{global_step}.pt\",\n",
" )\n",
"\n",
" global_step += 1\n",
" pbar.update(1)\n",
"\n",
"save_params(\n",
" global_step,\n",
" actor,\n",
" qnet,\n",
" qnet_target,\n",
" obs_normalizer,\n",
" critic_obs_normalizer,\n",
" args,\n",
" f\"models/{run_name}_final.pt\",\n",
")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "fasttd3_hb",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.17"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@ -0,0 +1,18 @@
gymnasium<1.0.0
jax-jumpy==1.0.0 ; python_version >= "3.8" and python_version < "3.11"
matplotlib
moviepy
numpy<2.0
pandas
protobuf
pygame
stable-baselines3
tqdm
wandb
torchrl==0.7.2
tensordict==0.7.2
tyro
loguru
torch==2.6.0 --index-url https://download.pytorch.org/whl/cu124
torchvision==0.21.0 --index-url https://download.pytorch.org/whl/cu124
torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cu124

View File

@ -0,0 +1,34 @@
gymnasium<1.0.0
jax-jumpy==1.0.0 ; python_version >= "3.8" and python_version < "3.11"
matplotlib
moviepy
numpy<2.0
pandas
protobuf
pygame
stable-baselines3
tqdm
wandb
torchrl==0.7.2
tensordict==0.7.2
tyro
loguru
git+https://github.com/younggyoseo/mujoco_playground.git
torch==2.6.0 --index-url https://download.pytorch.org/whl/cu124
torchvision==0.21.0 --index-url https://download.pytorch.org/whl/cu124
torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cu124
jax[cuda12]==0.4.35
nvidia-cublas-cu12==12.4.5.8
nvidia-cuda-cupti-cu12==12.4.127
nvidia-cuda-nvcc-cu12==12.8.93
nvidia-cuda-nvrtc-cu12==12.4.127
nvidia-cuda-runtime-cu12==12.4.127
nvidia-cudnn-cu12==9.1.0.70
nvidia-cufft-cu12==11.2.1.3
nvidia-curand-cu12==10.3.5.147
nvidia-cusolver-cu12==11.6.1.9
nvidia-cusparse-cu12==12.3.1.170
nvidia-cusparselt-cu12==0.6.2
nvidia-nccl-cu12==2.21.5
nvidia-nvjitlink-cu12==12.4.127
nvidia-nvtx-cu12==12.4.127

11
setup.py Normal file
View File

@ -0,0 +1,11 @@
from setuptools import setup, find_packages
setup(
name="fast_td3",
version="0.1.0",
description="FastTD3 implementation",
author="",
author_email="",
url="",
packages=find_packages(),
)

101
sim2real.md Normal file
View File

@ -0,0 +1,101 @@
# Guide for Sim2Real Training & Deployment
This guide provides guide to run sim-to-real experiments using FastTD3 and BoosterGym.
**⚠️ Warning:** Deploying RL policies to real hardware can be sometimes very dangerous. Please make sure that you understand everything, check your policies work well in simulation, set every robot configuration correct (e.g., damping, stiffness, torque limits, etc), and follow proper safety protocols. **Simply copy-pasting commands in this README is not safe**.
## ⚙️ Prerequisites
Install dependencies for Playground experiments (see `README.md`)
Then, install `fast_td3` package with `pip install -e .` so you can import its classes in BoosterGym (see `fast_td3_deploy.py`).
**⚠️ Note:** Our sim-to-real experiments depend on our customized MuJoCo Playground that supports `T1LowDimJoystick` tasks for 12-DOF T1 control instead of 23-DOF T1 control in `T1Joystick` tasks.
## 🚀 Training in simulation
Users can train deployable policies for Booster T1 with FastTD3 using the below script:
```bash
python fast_td3/train.py --env_name T1LowDimJoystickRoughTerrain --exp_name FastTD3 --use_domain_randomization --use_push_randomization --total_timesteps 1000000 --render_interval 0 --seed 2
```
**⚠️ Note:** There is no 'guaranteed' number of training steps that can ensure safe real-world deployment. Usually, the gait becomes more stable with longer training. Please check the quality of gaits via sim-to-sim transfer, and fine-tune the policy to fix the issues. Use the checkpoints in `models` directory for sim-to-sim or sim-to-real transfer.
**⚠️ Note:** We set `render_interval` to 0 to avoid dumping a lot of videos into wandb. Make sure to set it to non-zero values if you want to render videos during training.
### (Optional) 2-Stage Training
For faster convergence, users can consider introducing curriculum to the training -- so that the robot first learns to walk in a flat terrain without push perturbations. For this, train policies with the below script:
```bash
STAGE1_STEPS = 100000
STAGE2_STEPS = 300000 # Effective steps: 300000 - 200000 = 100000
SEED = 2
CHECKPOINT_PATH = T1LowDimJoystickFlatTerrain__FastTD3-Stage1__${SEED}_final.pt
conda activate fasttd3_playground
# Stage 1 training
python fast_td3/train.py --env_name T1LowDimJoystickFlatTerrain --exp_name FastTD3-Stage1 --use_domain_randomization --no_use_push_randomization --total_timesteps ${STAGE1_STEPS} --render_interval 0 --seed ${SEED}
# Stage 2 training
python fast_td3/train.py --env_name T1LowDimJoystickRoughTerrain --exp_name FastTD3-Stage2 --use_domain_randomization --use_push_randomization --total_timesteps ${STAGE2_STEPS} --render_interval 0 --checkpoint_path ${CHECKPOINT_PATH} --seed ${SEED}
```
Again, 100K and 200K steps do not guarantee safe real-world deployment. Please check the quality of gaits via sim-to-sim transfer, and fine-tune the policy to fix the issues. Use the final checkpoint (`models/T1LowDimJoystickRoughTerrain__FastTD3-Stage2__${SEED}_final.pt`) for sim-to-sim or sim-to-real transfer.
## 🛝 Deployment with BoosterGym
We use the customized version of [BoosterGym](https://github.com/BoosterRobotics/booster_gym) for deployment with FastTD3.
First, clone our fork of BoosterGym.
```bash
git clone https://github.com/carlosferrazza/booster_gym.git
```
Then, follow the [guide](https://github.com/carlosferrazza/booster_gym) to install dependencies for BoosterGym.
### Sim-to-Sim Transfer
You can check whether the trained policy transfers to non-MJX version of MuJoCo.
Use the following commands in a machine that supports rendering to test sim-to-sim transfer:
```bash
cd <YOUR_WORKSPACE>/booster_gym
# Activate your BoosterGym virtual environemnt
# Launch MuJoCo simulation
python play_mujoco.py --task=T1 --checkpoint=<CHECKPOINT_PATH>
mjpython play_mujoco.py --task=T1 --checkpoint=<CHECKPOINT_PATH> # for Mac
```
### Sim-to-Real Transfer
First, prepare a JIT-scripted checkpoint
```python
# Python snippets for JIT-scripting checkpoints
import torch
from fast_td3 import load_policy
policy = load_policy(<CHECKPOINT_PATH>)
scripted_policy = torch.jit.script(policy)
scripted_policy.save(<JIT_CHECKPOINT_PATH>)
```
Then, deploy this JIT-scripted checkpoint by following the guide on [Booster T1 Deployment](https://github.com/carlosferrazza/booster_gym/tree/main/deploy).
**⚠️ Warning:** Please double-check every value in robot configuration (`booster_gym/deploy/configs/T1.yaml`) is correctly set! If values for position control such as `damping` or `stiffness` are set differently, your robot may perform dangerous behaviors.
**⚠️ Warning:** You may want to use different configuration (e.g., `damping` and `stiffness`, etc) for your own experiments. Just make sure to thoroughly test it in simulation and make sure to set the values correctly.
---
🚀 That's it! Hope everything went smoothly, and be aware of your safety.