commit 258bfe67dd3446bb918c67f6fc250f9b55a98bb2 Author: Younggyo Seo Date: Thu May 29 01:49:23 2025 +0000 Initial Public Release diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..26c1979 --- /dev/null +++ b/.gitignore @@ -0,0 +1,9 @@ +media +models +wandb +figures +visualize.ipynb +record.ipynb +*.pyc +.ipynb_checkpoints +fast_td3.egg-info/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..868341e --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,5 @@ +repos: +- repo: https://github.com/psf/black + rev: stable + hooks: + - id: black \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..475b014 --- /dev/null +++ b/LICENSE @@ -0,0 +1,315 @@ +MIT License + +Copyright (c) 2025 Younggyo Seo + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
+ +-------------------------------------------------------------------------------- + +MIT License + +Copyright (c) 2024 LeanRL developers + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
+ +-------------------------------------------------------------------------------- +Code in `cleanrl/ddpg_continuous_action.py` and `cleanrl/td3_continuous_action.py` are adapted from https://github.com/sfujim/TD3 + +MIT License + +Copyright (c) 2020 Scott Fujimoto + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +-------------------------------------------------------------------------------- +Code in `cleanrl/sac_continuous_action.py` is inspired and adapted from [haarnoja/sac](https://github.com/haarnoja/sac), [openai/spinningup](https://github.com/openai/spinningup), [pranz24/pytorch-soft-actor-critic](https://github.com/pranz24/pytorch-soft-actor-critic), [DLR-RM/stable-baselines3](https://github.com/DLR-RM/stable-baselines3), and [denisyarats/pytorch_sac](https://github.com/denisyarats/pytorch_sac). 
+ +- [haarnoja/sac](https://github.com/haarnoja/sac/blob/8258e33633c7e37833cc39315891e77adfbe14b2/LICENSE.txt) + +COPYRIGHT + +All contributions by the University of California: +Copyright (c) 2017, 2018 The Regents of the University of California (Regents) +All rights reserved. + +All other contributions: +Copyright (c) 2017, 2018, the respective contributors +All rights reserved. + +SAC uses a shared copyright model: each contributor holds copyright over +their contributions to the SAC codebase. The project versioning records all such +contribution and copyright details. If a contributor wants to further mark +their specific copyright on a particular contribution, they should indicate +their copyright solely in the commit message of the change when it is +committed. + +LICENSE + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +CONTRIBUTION AGREEMENT + +By contributing to the SAC repository through pull-request, comment, +or otherwise, the contributor releases their content to the +license and copyright terms herein. + +- [openai/spinningup](https://github.com/openai/spinningup/blob/038665d62d569055401d91856abb287263096178/LICENSE) + +The MIT License + +Copyright (c) 2018 OpenAI (http://openai.com) + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. 
+ +- [DLR-RM/stable-baselines3](https://github.com/DLR-RM/stable-baselines3/blob/44e53ff8115e8f4bff1d5218f10c8c7d1a4cfc12/LICENSE) + +The MIT License + +Copyright (c) 2019 Antonin Raffin + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. + +- [denisyarats/pytorch_sac](https://github.com/denisyarats/pytorch_sac/blob/81c5b536d3a1c5616b2531e446450df412a064fb/LICENSE) + +MIT License + +Copyright (c) 2019 Denis Yarats + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. 
+ +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +- [pranz24/pytorch-soft-actor-critic](https://github.com/pranz24/pytorch-soft-actor-critic/blob/master/LICENSE) + +MIT License + +Copyright (c) 2018 Pranjal Tandon + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
+ + +--------------------------------------------------------------------------------- +The CONTRIBUTING.md is adopted from https://github.com/entity-neural-network/incubator/blob/2a0c38b30828df78c47b0318c76a4905020618dd/CONTRIBUTING.md +and https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/blob/master/CONTRIBUTING.md + +MIT License + +Copyright (c) 2021 Entity Neural Network developers + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
+ + + +MIT License + +Copyright (c) 2020 Stable-Baselines Team + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + + +--------------------------------------------------------------------------------- +The cleanrl/ppo_continuous_action_isaacgym.py is contributed by Nvidia + +SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +SPDX-License-Identifier: MIT + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. 
+ +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. + +-------------------------------------------------------------------------------- + +Code in `cleanrl/qdagger_dqn_atari_impalacnn.py` and `cleanrl/qdagger_dqn_atari_jax_impalacnn.py` are adapted from https://github.com/google-research/reincarnating_rl + +**NOTE: the original repo did not fill out the copyright section in their license +so the following copyright notice is copied as is per the license requirement. +See https://github.com/google-research/reincarnating_rl/blob/a1d402f48a9f8658ca6aa0ddf416ab391745ff2c/LICENSE#L189 + + +Copyright [yyyy] [name of copyright owner] + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
diff --git a/README.md b/README.md new file mode 100644 index 0000000..cce027f --- /dev/null +++ b/README.md @@ -0,0 +1,238 @@ +# FastTD3 - Simple and Fast RL for Humanoid Control + +[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) +[![arXiv](https://img.shields.io/badge/arXiv-2505.22642-b31b1b.svg)](https://arxiv.org/abs/2505.22642) + + +FastTD3 is a high-performance variant of the Twin Delayed Deep Deterministic Policy Gradient (TD3) algorithm, optimized for complex humanoid control tasks. FastTD3 can solve various humanoid control tasks with dexterous hands from HumanoidBench in just a few hours of training. Furthermore, FastTD3 achieves wall-time efficiency similar to or better than PPO on high-dimensional control tasks from popular simulators such as IsaacLab and MuJoCo Playground. + +For more information, please see our [project webpage](https://younggyo.me/fast_td3). + +## ✨ Features + +FastTD3 offers researchers a significant speedup in training complex humanoid agents. + +- Ready-to-go codebase with detailed instructions and pre-configured hyperparameters for each task +- Support for popular benchmarks: HumanoidBench, MuJoCo Playground, and IsaacLab +- User-friendly features that can accelerate your research, such as rendering rollouts, torch optimizations (AMP and compile), and saving and loading checkpoints + +## ⚙️ Prerequisites + +Before you begin, ensure you have the following installed: +- Conda (for environment management) +- Git LFS (Large File Storage) -- For IsaacLab +- CMake -- For IsaacLab + +And the following system packages: +```bash +sudo apt install libglfw3 libgl1-mesa-glx libosmesa6 git-lfs cmake +``` + +## 📖 Installation + +This project requires different Conda environments for different sets of experiments. + +### Common Setup +First, ensure the common dependencies are installed as mentioned in the [Prerequisites](#prerequisites) section. 
+ +### Environment for HumanoidBench + +```bash +conda create -n fasttd3_hb -y python=3.10 +conda activate fasttd3_hb +pip install --editable git+https://github.com/carlosferrazza/humanoid-bench.git#egg=humanoid-bench +pip install -r requirements/requirements.txt +``` + +### Environment for MuJoCo Playground +```bash +conda create -n fasttd3_playground -y python=3.10 +conda activate fasttd3_playground +pip install -r requirements/requirements_playground.txt +``` + +**⚠️ Note:** Our `requirements_playground.txt` specifies `Jax==0.4.35`, which we found to be stable on the latest GPUs for certain tasks such as `LeapCubeReorient` and `LeapCubeRotateZAxis`. + +**⚠️ Note:** The current FastTD3 codebase uses a customized MuJoCo Playground that supports saving the last observations into the info dictionary. We plan to incorporate this change into the official repository soon. + +### Environment for IsaacLab +```bash +conda create -n fasttd3_isaaclab -y python=3.10 +conda activate fasttd3_isaaclab + +# Install IsaacLab (refer to official documentation for the latest steps) +# Official Quickstart: https://isaac-sim.github.io/IsaacLab/main/source/setup/quickstart.html +pip install 'isaacsim[all,extscache]==4.5.0' --extra-index-url https://pypi.nvidia.com +git clone https://github.com/isaac-sim/IsaacLab.git +cd IsaacLab +./isaaclab.sh --install +cd .. + +# Install project-specific requirements +pip install -r requirements/requirements.txt +``` + +### (Optional) Accelerate headless GPU rendering in cloud instances + +In some cloud VM images the NVIDIA kernel driver is present but the user-space OpenGL/EGL/Vulkan libraries aren't, so MuJoCo falls back to the CPU renderer. 
You can install just the NVIDIA user-space libraries (and skip rebuilding the kernel module) with: + +```bash +sudo apt install -y kmod +sudo sh NVIDIA-Linux-x86_64-.run -s --no-kernel-module --ui=none --no-questions +``` + +As a rule of thumb, if rendering takes longer than 5 seconds during your experiments, the GPU renderer is very likely not being used. + +## 🚀 Running Experiments + +Activate the appropriate Conda environment before running experiments. + +Please see `fast_td3/hyperparams.py` for information regarding hyperparameters! + +### HumanoidBench Experiments +```bash +conda activate fasttd3_hb +python fast_td3/train.py --env_name h1hand-hurdle-v0 --exp_name FastTD3 --render_interval 5000 --seed 1 +``` + +### MuJoCo Playground Experiments +```bash +conda activate fasttd3_playground +python fast_td3/train.py --env_name G1JoystickRoughTerrain --exp_name FastTD3 --render_interval 5000 --seed 1 +``` + +### IsaacLab Experiments +```bash +conda activate fasttd3_isaaclab +python fast_td3/train.py --env_name Isaac-Velocity-Flat-G1-v0 --exp_name FastTD3 --render_interval 0 --seed 1 +python fast_td3/train.py --env_name Isaac-Repose-Cube-Allegro-Direct-v0 --exp_name FastTD3 --render_interval 0 --seed 1 +``` + +**Quick note:** For boolean arguments, you can set them to False by adding `no_` in front of the argument name. For instance, to disable Clipped Double Q-learning, specify `--no_use_cdq` in your command. + +## 💡 Performance-Related Tips + +We used a single NVIDIA A100 80GB GPU for all experiments. Here are some remarks and tips for improving performance in your setup or troubleshooting your machine configuration. + +- *Sample efficiency* tends to improve with larger `num_envs`, `num_updates`, and `batch_size`, but this comes at the cost of *time efficiency*. Our default settings are optimized for wall-time efficiency on a single A100 80GB GPU. If you're using a different setup, consider tuning hyperparameters accordingly. 
+- When FastTD3 performance is stuck in a local minimum in the early phase of training: + - First, consider increasing `num_updates`. This usually happens when the agent fails to exploit value functions. We also find that higher `num_updates` tends to help on relatively easy tasks or tasks with low-dimensional action spaces. + - If the agent is completely stuck or performs much worse than you expect, try using `num_steps=3` or disabling `use_cdq`. + - For tasks that have penalty reward terms (e.g., torques, energy, action_rate, ...), consider lowering them for initial experiments, then tune the values. In some cases, curriculum learning with lower penalty terms followed by fine-tuning with stronger terms is effective. +- When you encounter an out-of-memory error on your GPU, we recommend reducing GPU usage with (i) a smaller `buffer_size`, (ii) a smaller `batch_size`, and then (iii) a smaller `num_envs`. Because our codebase keeps the whole replay buffer on the GPU to reduce the CPU-GPU transfer bottleneck, the buffer usually accounts for the largest share of GPU memory, and reducing it is usually less harmful than reducing the other two. + +## 🛝 Playing with the FastTD3 training + +A Jupyter notebook (`training_notebook.ipynb`) is available to help you get started with: +- Training FastTD3 agents. +- Loading pre-trained models. +- Visualizing agent behavior. +- Potentially, re-training or fine-tuning models. + +## 🤖 Sim-to-Real RL with FastTD3 + +We provide a [walkthrough](sim2real.md) for training deployable policies with FastTD3. + +## Contributing + +We welcome contributions! Please feel free to submit issues and pull requests. + +## License + +This project is licensed under the MIT License -- see the [LICENSE](LICENSE) file for details. Note that the repository relies on third-party libraries subject to their respective licenses. + +## Acknowledgements + +This codebase builds upon the [LeanRL](https://github.com/pytorch-labs/LeanRL) framework. 
+ +We would like to thank the people who have helped throughout the project: + +- We thank [Kevin Zakka](https://kzakka.com/) for the help in setting up MuJoCo Playground. +- We thank [Changyeon Kim](https://changyeon.site/) for testing the early version of this codebase. + +## Citations + +### FastTD3 +```bibtex +@article{seo2025fasttd3, + title = {FastTD3: Simple, Fast, and Capable Reinforcement Learning for Humanoid Control}, + author = {Seo, Younggyo and Sferrazza, Carmelo and Geng, Haoran and Nauman, Michal and Yin, Zhao-Heng and Abbeel, Pieter}, + journal = {arXiv preprint arXiv:2505.22642}, + year = {2025}, +} +``` + +### TD3 +```bibtex +@inproceedings{fujimoto2018addressing, + title={Addressing function approximation error in actor-critic methods}, + author={Fujimoto, Scott and Hoof, Herke and Meger, David}, + booktitle={International conference on machine learning}, + pages={1587--1596}, + year={2018}, + organization={PMLR} +} +``` + +### LeanRL + +Following [LeanRL](https://github.com/pytorch-labs/LeanRL)'s recommendation, we put CleanRL's bibtex here: + +```bibtex +@article{huang2022cleanrl, + author = {Shengyi Huang and Rousslan Fernand Julien Dossa and Chang Ye and Jeff Braga and Dipam Chakraborty and Kinal Mehta and João G.M. 
Araújo}, + title = {CleanRL: High-quality Single-file Implementations of Deep Reinforcement Learning Algorithms}, + journal = {Journal of Machine Learning Research}, + year = {2022}, + volume = {23}, + number = {274}, + pages = {1--18}, + url = {http://jmlr.org/papers/v23/21-1342.html} +} +``` + +### Parallel Q-Learning (PQL) +```bibtex +@inproceedings{li2023parallel, + title={Parallel $ Q $-Learning: Scaling Off-policy Reinforcement Learning under Massively Parallel Simulation}, + author={Li, Zechu and Chen, Tao and Hong, Zhang-Wei and Ajay, Anurag and Agrawal, Pulkit}, + booktitle={International Conference on Machine Learning}, + pages={19440--19459}, + year={2023}, + organization={PMLR} +} +``` + +### HumanoidBench +```bibtex +@inproceedings{sferrazza2024humanoidbench, + title={Humanoidbench: Simulated humanoid benchmark for whole-body locomotion and manipulation}, + author={Sferrazza, Carmelo and Huang, Dun-Ming and Lin, Xingyu and Lee, Youngwoon and Abbeel, Pieter}, + booktitle={Robotics: Science and Systems}, + year={2024} +} +``` + +### MuJoCo Playground +```bibtex +@article{zakka2025mujoco, + title={MuJoCo Playground}, + author={Zakka, Kevin and Tabanpour, Baruch and Liao, Qiayuan and Haiderbhai, Mustafa and Holt, Samuel and Luo, Jing Yuan and Allshire, Arthur and Frey, Erik and Sreenath, Koushil and Kahrs, Lueder A and others}, + journal={arXiv preprint arXiv:2502.08844}, + year={2025} +} +``` + +### IsaacLab +```bibtex +@article{mittal2023orbit, + author={Mittal, Mayank and Yu, Calvin and Yu, Qinxi and Liu, Jingzhou and Rudin, Nikita and Hoeller, David and Yuan, Jia Lin and Singh, Ritvik and Guo, Yunrong and Mazhar, Hammad and Mandlekar, Ajay and Babich, Buck and State, Gavriel and Hutter, Marco and Garg, Animesh}, + journal={IEEE Robotics and Automation Letters}, + title={Orbit: A Unified Simulation Framework for Interactive Robot Learning Environments}, + year={2023}, + volume={8}, + number={6}, + pages={3740-3747}, + doi={10.1109/LRA.2023.3270034} +} 
+``` \ No newline at end of file diff --git a/fast_td3/__init__.py b/fast_td3/__init__.py new file mode 100644 index 0000000..9688bfc --- /dev/null +++ b/fast_td3/__init__.py @@ -0,0 +1,20 @@ +""" +Fast TD3 is a high-performance implementation of Twin Delayed Deep Deterministic Policy Gradient (TD3) +with distributional critics for reinforcement learning. +""" + +# Core model components +from fast_td3.fast_td3 import Actor, Critic, DistributionalQNetwork +from fast_td3.fast_td3_utils import EmpiricalNormalization, SimpleReplayBuffer +from fast_td3.fast_td3_deploy import Policy, load_policy + +__all__ = [ + # Core model components + "Actor", + "Critic", + "DistributionalQNetwork", + "EmpiricalNormalization", + "SimpleReplayBuffer", + "Policy", + "load_policy", +] diff --git a/fast_td3/environments/humanoid_bench_env.py b/fast_td3/environments/humanoid_bench_env.py new file mode 100644 index 0000000..c1382f6 --- /dev/null +++ b/fast_td3/environments/humanoid_bench_env.py @@ -0,0 +1,128 @@ +from __future__ import annotations + +import gymnasium as gym + +import humanoid_bench +from gymnasium.wrappers import TimeLimit +from stable_baselines3.common.vec_env import SubprocVecEnv +import numpy as np +import torch +from loguru import logger as log + +# Disable all logging below CRITICAL level +log.remove() +log.add(lambda msg: False, level="CRITICAL") + + +def make_env(env_name, rank, render_mode=None, seed=0): + """ + Utility function for multiprocessed env. 
+ + :param rank: (int) index of the subprocess + :param seed: (int) the initial seed for RNG + """ + + if env_name in [ + "h1hand-push-v0", + "h1-push-v0", + "h1hand-cube-v0", + "h1-cube-v0", + "h1hand-basketball-v0", + "h1-basketball-v0", + "h1hand-kitchen-v0", + "h1-kitchen-v0", + ]: + max_episode_steps = 500 + else: + max_episode_steps = 1000 + + def _init(): + import humanoid_bench + + env = gym.make(env_name, render_mode=render_mode) + env = TimeLimit(env, max_episode_steps=max_episode_steps) + env.unwrapped.seed(seed + rank) + + return env + + return _init + + +class HumanoidBenchEnv: + """Wraps HumanoidBench environment to support parallel environments.""" + + def __init__(self, env_name, num_envs=1, render_mode=None, device=None): + # NOTE: HumanoidBench action space is already normalized to [-1, 1] + device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.sim_device = device + self.num_envs = num_envs + + # Create the base environment + self.envs = SubprocVecEnv( + [make_env(env_name, i, render_mode=render_mode) for i in range(num_envs)] + ) + + if env_name in [ + "h1hand-push-v0", + "h1-push-v0", + "h1hand-cube-v0", + "h1-cube-v0", + "h1hand-basketball-v0", + "h1-basketball-v0", + "h1hand-kitchen-v0", + "h1-kitchen-v0", + ]: + self.max_episode_steps = 500 + else: + self.max_episode_steps = 1000 + + # For compatibility with MuJoCo Playground + self.asymmetric_obs = False # For compatibility with MuJoCo Playground + self.num_obs = self.envs.observation_space.shape[-1] + self.num_actions = self.envs.action_space.shape[-1] + + def reset(self): + """Reset the environment.""" + observations = self.envs.reset() + observations = torch.from_numpy(observations).to( + device=self.sim_device, dtype=torch.float + ) + return observations + + def render(self): + assert ( + self.num_envs == 1 + ), "Currently only supports single environment rendering" + return self.envs.render() + + def step(self, actions): + assert isinstance(actions, 
torch.Tensor) + actions = actions.cpu().numpy() + + observations, rewards, dones, raw_infos = self.envs.step(actions) + + # This will be used for getting 'true' next observations + infos = dict() + infos["observations"] = {"raw": {"obs": observations.copy()}} + truncateds = np.zeros_like(dones) + for i in range(self.num_envs): + if raw_infos[i].get("TimeLimit.truncated", False): + truncateds[i] = True + infos["observations"]["raw"]["obs"][i] = raw_infos[i][ + "terminal_observation" + ] + + observations = torch.from_numpy(observations).to( + device=self.sim_device, dtype=torch.float + ) + rewards = torch.from_numpy(rewards).to( + device=self.sim_device, dtype=torch.float + ) + dones = torch.from_numpy(dones).to(device=self.sim_device) + truncateds = torch.from_numpy(truncateds).to(device=self.sim_device) + infos["observations"]["raw"]["obs"] = torch.from_numpy( + infos["observations"]["raw"]["obs"] + ).to(device=self.sim_device, dtype=torch.float) + infos["time_outs"] = truncateds + + return observations, rewards, dones, infos diff --git a/fast_td3/environments/isaaclab_env.py b/fast_td3/environments/isaaclab_env.py new file mode 100644 index 0000000..876fc69 --- /dev/null +++ b/fast_td3/environments/isaaclab_env.py @@ -0,0 +1,82 @@ +from typing import Optional + +import gymnasium as gym +import torch +from isaaclab.app import AppLauncher + +app_launcher = AppLauncher(headless=True) +simulation_app = app_launcher.app + +import isaaclab_tasks +from isaaclab_tasks.utils.parse_cfg import parse_env_cfg + + +class IsaacLabEnv: + """Wrapper for IsaacLab environments to be compatible with MuJoCo Playground""" + + def __init__( + self, + task_name: str, + device: str, + num_envs: int, + seed: int, + action_bounds: Optional[float] = None, + ): + env_cfg = parse_env_cfg( + task_name, + device=device, + num_envs=num_envs, + ) + env_cfg.seed = seed + self.seed = seed + self.envs = gym.make(task_name, cfg=env_cfg, render_mode=None) + + self.num_envs = 
self.envs.unwrapped.num_envs + self.max_episode_steps = self.envs.unwrapped.max_episode_length + self.action_bounds = action_bounds + self.num_obs = self.envs.unwrapped.single_observation_space["policy"].shape[0] + self.asymmetric_obs = "critic" in self.envs.unwrapped.single_observation_space + if self.asymmetric_obs: + self.num_privileged_obs = self.envs.unwrapped.single_observation_space[ + "critic" + ].shape[0] + else: + self.num_privileged_obs = 0 + self.num_actions = self.envs.unwrapped.single_action_space.shape[0] + + def reset(self, random_start_init: bool = True) -> torch.Tensor: + obs_dict, _ = self.envs.reset() + # NOTE: decorrelate episode horizons like RSL‑RL + if random_start_init: + self.envs.unwrapped.episode_length_buf = torch.randint_like( + self.envs.unwrapped.episode_length_buf, high=int(self.max_episode_steps) + ) + return obs_dict["policy"] + + def reset_with_critic_obs(self) -> tuple[torch.Tensor, torch.Tensor]: + obs_dict, _ = self.envs.reset() + return obs_dict["policy"], obs_dict["critic"] + + def step( + self, actions: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, dict]: + if self.action_bounds is not None: + actions = torch.clamp(actions, -1.0, 1.0) * self.action_bounds + obs_dict, rew, terminations, truncations, infos = self.envs.step(actions) + dones = (terminations | truncations).to(dtype=torch.long) + obs = obs_dict["policy"] + critic_obs = obs_dict["critic"] if self.asymmetric_obs else None + info_ret = {"time_outs": truncations, "observations": {"critic": critic_obs}} + # NOTE: There's really no way to get the raw observations from IsaacLab + # We just use the 'reset_obs' as next_obs, unfortunately. 
+ # See https://github.com/isaac-sim/IsaacLab/issues/1362 + info_ret["observations"]["raw"] = { + "obs": obs, + "critic_obs": critic_obs, + } + return obs, rew, dones, info_ret + + def render(self): + raise NotImplementedError( + "We don't support rendering for IsaacLab environments" + ) diff --git a/fast_td3/environments/mujoco_playground_env.py b/fast_td3/environments/mujoco_playground_env.py new file mode 100644 index 0000000..5e21999 --- /dev/null +++ b/fast_td3/environments/mujoco_playground_env.py @@ -0,0 +1,136 @@ +from mujoco_playground import registry +from mujoco_playground import wrapper_torch + +import jax +import mujoco + +jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache") +jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1) +jax.config.update("jax_persistent_cache_min_compile_time_secs", 0) + + +class PlaygroundEvalEnvWrapper: + def __init__(self, eval_env, max_episode_steps, env_name, num_eval_envs, seed): + """ + Wrapper used for evaluation / rendering environments. + Note that this is different from training environments that are + wrapped with RSLRLBraxWrapper. 
+ """ + self.env = eval_env + self.env_name = env_name + self.num_envs = num_eval_envs + self.jit_reset = jax.jit(jax.vmap(self.env.reset)) + self.jit_step = jax.jit(jax.vmap(self.env.step)) + + if isinstance(self.env.unwrapped.observation_size, dict): + self.asymmetric_obs = True + else: + self.asymmetric_obs = False + + self.key = jax.random.PRNGKey(seed) + self.key_reset = jax.random.split(self.key, num_eval_envs) + self.max_episode_steps = max_episode_steps + + def reset(self): + self.state = self.jit_reset(self.key_reset) + if self.asymmetric_obs: + obs = wrapper_torch._jax_to_torch(self.state.obs["state"]) + else: + obs = wrapper_torch._jax_to_torch(self.state.obs) + return obs + + def step(self, actions): + self.state = self.jit_step(self.state, wrapper_torch._torch_to_jax(actions)) + if self.asymmetric_obs: + next_obs = wrapper_torch._jax_to_torch(self.state.obs["state"]) + else: + next_obs = wrapper_torch._jax_to_torch(self.state.obs) + rewards = wrapper_torch._jax_to_torch(self.state.reward) + dones = wrapper_torch._jax_to_torch(self.state.done) + return next_obs, rewards, dones, None + + def render_trajectory(self, trajectory): + scene_option = mujoco.MjvOption() + scene_option.flags[mujoco.mjtVisFlag.mjVIS_TRANSPARENT] = False + scene_option.flags[mujoco.mjtVisFlag.mjVIS_PERTFORCE] = False + scene_option.flags[mujoco.mjtVisFlag.mjVIS_CONTACTFORCE] = False + + frames = self.env.render( + trajectory, + camera="track" if "Joystick" in self.env_name else None, + height=480, + width=640, + scene_option=scene_option, + ) + return frames + + +def make_env( + env_name, + seed, + num_envs, + num_eval_envs, + device_rank, + use_tuned_reward=False, + use_domain_randomization=False, + use_push_randomization=False, +): + # Make training environment + train_env_cfg = registry.get_default_config(env_name) + if use_tuned_reward: + # NOTE: Tuned reward for G1. Used for producing Figure 7 in the paper. 
+ assert env_name in ["G1JoystickRoughTerrain", "G1JoystickFlatTerrain"] + train_env_cfg.reward_config.scales.energy = -5e-5 + train_env_cfg.reward_config.scales.action_rate = -1e-1 + train_env_cfg.reward_config.scales.torques = -1e-3 + train_env_cfg.reward_config.scales.pose = -1.0 + train_env_cfg.reward_config.scales.tracking_ang_vel = 1.25 + train_env_cfg.reward_config.scales.tracking_lin_vel = 1.25 + train_env_cfg.reward_config.scales.feet_phase = 1.0 + train_env_cfg.reward_config.scales.ang_vel_xy = -0.3 + train_env_cfg.reward_config.scales.orientation = -5.0 + + is_humanoid_task = env_name in [ + "G1JoystickRoughTerrain", + "G1JoystickFlatTerrain", + "T1JoystickRoughTerrain", + "T1JoystickFlatTerrain", + ] + + if is_humanoid_task and not use_push_randomization: + train_env_cfg.push_config.enable = False + train_env_cfg.push_config.magnitude_range = [0.0, 0.0] + randomizer = ( + registry.get_domain_randomizer(env_name) if use_domain_randomization else None + ) + raw_env = registry.load(env_name, config=train_env_cfg) + train_env = wrapper_torch.RSLRLBraxWrapper( + raw_env, + num_envs, + seed, + train_env_cfg.episode_length, + train_env_cfg.action_repeat, + randomization_fn=randomizer, + device_rank=device_rank, + ) + + # Make evaluation environment + eval_env_cfg = registry.get_default_config(env_name) + if is_humanoid_task and not use_push_randomization: + eval_env_cfg.push_config.enable = False + eval_env_cfg.push_config.magnitude_range = [0.0, 0.0] + eval_env = registry.load(env_name, config=eval_env_cfg) + eval_env = PlaygroundEvalEnvWrapper( + eval_env, eval_env_cfg.episode_length, env_name, num_eval_envs, seed + ) + + render_env_cfg = registry.get_default_config(env_name) + if is_humanoid_task and not use_push_randomization: + render_env_cfg.push_config.enable = False + render_env_cfg.push_config.magnitude_range = [0.0, 0.0] + render_env = registry.load(env_name, config=render_env_cfg) + render_env = PlaygroundEvalEnvWrapper( + render_env, 
render_env_cfg.episode_length, env_name, 1, seed + ) + + return train_env, eval_env, render_env diff --git a/fast_td3/fast_td3.py b/fast_td3/fast_td3.py new file mode 100644 index 0000000..2a149bb --- /dev/null +++ b/fast_td3/fast_td3.py @@ -0,0 +1,217 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class DistributionalQNetwork(nn.Module): + def __init__( + self, + n_obs: int, + n_act: int, + num_atoms: int, + v_min: float, + v_max: float, + hidden_dim: int, + device: torch.device = None, + ): + super().__init__() + self.net = nn.Sequential( + nn.Linear(n_obs + n_act, hidden_dim, device=device), + nn.ReLU(), + nn.Linear(hidden_dim, hidden_dim // 2, device=device), + nn.ReLU(), + nn.Linear(hidden_dim // 2, hidden_dim // 4, device=device), + nn.ReLU(), + nn.Linear(hidden_dim // 4, num_atoms, device=device), + ) + self.v_min = v_min + self.v_max = v_max + self.num_atoms = num_atoms + + def forward(self, obs: torch.Tensor, actions: torch.Tensor) -> torch.Tensor: + x = torch.cat([obs, actions], 1) + x = self.net(x) + return x + + def projection( + self, + obs: torch.Tensor, + actions: torch.Tensor, + rewards: torch.Tensor, + bootstrap: torch.Tensor, + gamma: float, + q_support: torch.Tensor, + device: torch.device, + ) -> torch.Tensor: + delta_z = (self.v_max - self.v_min) / (self.num_atoms - 1) + batch_size = rewards.shape[0] + + target_z = rewards.unsqueeze(1) + bootstrap.unsqueeze(1) * gamma * q_support + target_z = target_z.clamp(self.v_min, self.v_max) + b = (target_z - self.v_min) / delta_z + l = torch.floor(b).long() + u = torch.ceil(b).long() + + l_mask = torch.logical_and((u > 0), (l == u)) + u_mask = torch.logical_and((l < (self.num_atoms - 1)), (l == u)) + + l = torch.where(l_mask, l - 1, l) + u = torch.where(u_mask, u + 1, u) + + next_dist = F.softmax(self.forward(obs, actions), dim=1) + proj_dist = torch.zeros_like(next_dist) + offset = ( + torch.linspace( + 0, (batch_size - 1) * self.num_atoms, batch_size, device=device + ) + 
.unsqueeze(1) + .expand(batch_size, self.num_atoms) + .long() + ) + proj_dist.view(-1).index_add_( + 0, (l + offset).view(-1), (next_dist * (u.float() - b)).view(-1) + ) + proj_dist.view(-1).index_add_( + 0, (u + offset).view(-1), (next_dist * (b - l.float())).view(-1) + ) + return proj_dist + + +class Critic(nn.Module): + def __init__( + self, + n_obs: int, + n_act: int, + num_atoms: int, + v_min: float, + v_max: float, + hidden_dim: int, + device: torch.device = None, + ): + super().__init__() + self.qnet1 = DistributionalQNetwork( + n_obs=n_obs, + n_act=n_act, + num_atoms=num_atoms, + v_min=v_min, + v_max=v_max, + hidden_dim=hidden_dim, + device=device, + ) + self.qnet2 = DistributionalQNetwork( + n_obs=n_obs, + n_act=n_act, + num_atoms=num_atoms, + v_min=v_min, + v_max=v_max, + hidden_dim=hidden_dim, + device=device, + ) + + self.register_buffer( + "q_support", torch.linspace(v_min, v_max, num_atoms, device=device) + ) + + def forward(self, obs: torch.Tensor, actions: torch.Tensor) -> torch.Tensor: + return self.qnet1(obs, actions), self.qnet2(obs, actions) + + def projection( + self, + obs: torch.Tensor, + actions: torch.Tensor, + rewards: torch.Tensor, + bootstrap: torch.Tensor, + gamma: float, + ) -> torch.Tensor: + """Projection operation that includes q_support directly""" + q1_proj = self.qnet1.projection( + obs, + actions, + rewards, + bootstrap, + gamma, + self.q_support, + self.q_support.device, + ) + q2_proj = self.qnet2.projection( + obs, + actions, + rewards, + bootstrap, + gamma, + self.q_support, + self.q_support.device, + ) + return q1_proj, q2_proj + + def get_value(self, probs: torch.Tensor) -> torch.Tensor: + """Calculate the expected value from atom probabilities using the support""" + return torch.sum(probs * self.q_support, dim=1) + + +class Actor(nn.Module): + def __init__( + self, + n_obs: int, + n_act: int, + num_envs: int, + init_scale: float, + hidden_dim: int, + std_min: float = 0.05, + std_max: float = 0.8, + device: torch.device = None, + ): +
super().__init__() + self.n_act = n_act + self.net = nn.Sequential( + nn.Linear(n_obs, hidden_dim, device=device), + nn.ReLU(), + nn.Linear(hidden_dim, hidden_dim // 2, device=device), + nn.ReLU(), + nn.Linear(hidden_dim // 2, hidden_dim // 4, device=device), + nn.ReLU(), + ) + self.fc_mu = nn.Sequential( + nn.Linear(hidden_dim // 4, n_act, device=device), + nn.Tanh(), + ) + nn.init.normal_(self.fc_mu[0].weight, 0.0, init_scale) + nn.init.constant_(self.fc_mu[0].bias, 0.0) + + noise_scales = ( + torch.rand(num_envs, 1, device=device) * (std_max - std_min) + std_min + ) + self.register_buffer("noise_scales", noise_scales) + + self.register_buffer("std_min", torch.as_tensor(std_min, device=device)) + self.register_buffer("std_max", torch.as_tensor(std_max, device=device)) + self.n_envs = num_envs + + def forward(self, obs: torch.Tensor) -> torch.Tensor: + x = obs + x = self.net(x) + action = self.fc_mu(x) + return action + + def explore( + self, obs: torch.Tensor, dones: torch.Tensor = None, deterministic: bool = False + ) -> torch.Tensor: + # If dones is provided, resample noise for environments that are done + if dones is not None and dones.sum() > 0: + # Generate new noise scales for done environments (one per environment) + new_scales = ( + torch.rand(self.n_envs, 1, device=obs.device) + * (self.std_max - self.std_min) + + self.std_min + ) + + # Update only the noise scales for environments that are done + dones_view = dones.view(-1, 1) > 0 + self.noise_scales = torch.where(dones_view, new_scales, self.noise_scales) + + act = self(obs) + if deterministic: + return act + + noise = torch.randn_like(act) * self.noise_scales + return act + noise diff --git a/fast_td3/fast_td3_deploy.py b/fast_td3/fast_td3_deploy.py new file mode 100644 index 0000000..15fa73a --- /dev/null +++ b/fast_td3/fast_td3_deploy.py @@ -0,0 +1,66 @@ +import torch +import torch.nn as nn +from .fast_td3_utils import EmpiricalNormalization +from .fast_td3 import Actor + + +class Policy(nn.Module): 
+ def __init__( + self, + n_obs: int, + n_act: int, + num_envs: int, + init_scale: float, + actor_hidden_dim: int, + ): + super().__init__() + self.actor = Actor( + n_obs=n_obs, + n_act=n_act, + num_envs=num_envs, + device="cpu", + init_scale=init_scale, + hidden_dim=actor_hidden_dim, + ) + self.obs_normalizer = EmpiricalNormalization(shape=n_obs, device="cpu") + + self.actor.eval() + self.obs_normalizer.eval() + + @torch.no_grad + def forward(self, obs: torch.Tensor) -> torch.Tensor: + norm_obs = self.obs_normalizer(obs) + actions = self.actor(norm_obs) + return actions + + @torch.no_grad + def act(self, obs: torch.Tensor) -> torch.distributions.Normal: + actions = self.forward(obs) + return torch.distributions.Normal(actions, torch.ones_like(actions) * 1e-8) + + +def load_policy(checkpoint_path): + torch_checkpoint = torch.load( + f"{checkpoint_path}", map_location="cpu", weights_only=False + ) + args = torch_checkpoint["args"] + + n_obs = torch_checkpoint["actor_state_dict"]["net.0.weight"].shape[-1] + n_act = torch_checkpoint["actor_state_dict"]["fc_mu.0.weight"].shape[0] + + policy = Policy( + n_obs=n_obs, + n_act=n_act, + num_envs=args["num_envs"], + init_scale=args["init_scale"], + actor_hidden_dim=args["actor_hidden_dim"], + ) + + policy.actor.load_state_dict(torch_checkpoint["actor_state_dict"]) + + if len(torch_checkpoint["obs_normalizer_state"]) == 0: + policy.obs_normalizer = nn.Identity() + else: + policy.obs_normalizer.load_state_dict(torch_checkpoint["obs_normalizer_state"]) + + return policy diff --git a/fast_td3/fast_td3_utils.py b/fast_td3/fast_td3_utils.py new file mode 100644 index 0000000..ac0b44e --- /dev/null +++ b/fast_td3/fast_td3_utils.py @@ -0,0 +1,387 @@ +import os + +import torch +import torch.nn as nn + +from tensordict import TensorDict + + +class SimpleReplayBuffer(nn.Module): + def __init__( + self, + n_env: int, + buffer_size: int, + n_obs: int, + n_act: int, + n_critic_obs: int, + asymmetric_obs: bool = False, + n_steps: int = 1, 
+ gamma: float = 0.99, + device=None, + ): + """ + A simple replay buffer that stores transitions in a circular buffer. + Supports n-step returns and asymmetric observations. + """ + super().__init__() + + self.n_env = n_env + self.buffer_size = buffer_size + self.n_obs = n_obs + self.n_act = n_act + self.n_critic_obs = n_critic_obs + self.asymmetric_obs = asymmetric_obs + self.gamma = gamma + self.n_steps = n_steps + self.device = device + + self.observations = torch.zeros( + (n_env, buffer_size, n_obs), device=device, dtype=torch.float + ) + self.actions = torch.zeros( + (n_env, buffer_size, n_act), device=device, dtype=torch.float + ) + self.rewards = torch.zeros( + (n_env, buffer_size), device=device, dtype=torch.float + ) + self.dones = torch.zeros((n_env, buffer_size), device=device, dtype=torch.long) + self.truncations = torch.zeros( + (n_env, buffer_size), device=device, dtype=torch.long + ) + self.next_observations = torch.zeros( + (n_env, buffer_size, n_obs), device=device, dtype=torch.float + ) + if asymmetric_obs: + self.critic_observations = torch.zeros( + (n_env, buffer_size, n_critic_obs), device=device, dtype=torch.float + ) + self.next_critic_observations = torch.zeros( + (n_env, buffer_size, n_critic_obs), device=device, dtype=torch.float + ) + self.ptr = 0 + + def extend( + self, + tensor_dict: TensorDict, + ): + observations = tensor_dict["observations"] + actions = tensor_dict["actions"] + rewards = tensor_dict["next"]["rewards"] + dones = tensor_dict["next"]["dones"] + truncations = tensor_dict["next"]["truncations"] + next_observations = tensor_dict["next"]["observations"] + + ptr = self.ptr % self.buffer_size + self.observations[:, ptr] = observations + self.actions[:, ptr] = actions + self.rewards[:, ptr] = rewards + self.dones[:, ptr] = dones + self.truncations[:, ptr] = truncations + self.next_observations[:, ptr] = next_observations + if self.asymmetric_obs: + critic_observations = tensor_dict["critic_observations"] + 
self.critic_observations[:, ptr] = critic_observations + next_critic_observations = tensor_dict["next"]["critic_observations"] + self.next_critic_observations[:, ptr] = next_critic_observations + self.ptr += 1 + + def sample(self, batch_size: int): + # we will sample n_env * batch_size transitions + + if self.n_steps == 1: + indices = torch.randint( + 0, + min(self.buffer_size, self.ptr), + (self.n_env, batch_size), + device=self.device, + ) + obs_indices = indices.unsqueeze(-1).expand(-1, -1, self.n_obs) + act_indices = indices.unsqueeze(-1).expand(-1, -1, self.n_act) + observations = torch.gather(self.observations, 1, obs_indices).reshape( + self.n_env * batch_size, self.n_obs + ) + next_observations = torch.gather( + self.next_observations, 1, obs_indices + ).reshape(self.n_env * batch_size, self.n_obs) + actions = torch.gather(self.actions, 1, act_indices).reshape( + self.n_env * batch_size, self.n_act + ) + + rewards = torch.gather(self.rewards, 1, indices).reshape( + self.n_env * batch_size + ) + dones = torch.gather(self.dones, 1, indices).reshape( + self.n_env * batch_size + ) + truncations = torch.gather(self.truncations, 1, indices).reshape( + self.n_env * batch_size + ) + if self.asymmetric_obs: + critic_obs_indices = indices.unsqueeze(-1).expand( + -1, -1, self.n_critic_obs + ) + critic_observations = torch.gather( + self.critic_observations, 1, critic_obs_indices + ).reshape(self.n_env * batch_size, self.n_critic_obs) + next_critic_observations = torch.gather( + self.next_critic_observations, 1, critic_obs_indices + ).reshape(self.n_env * batch_size, self.n_critic_obs) + else: + # Sample base indices + indices = torch.randint( + 0, + min(self.buffer_size, self.ptr), + (self.n_env, batch_size), + device=self.device, + ) + obs_indices = indices.unsqueeze(-1).expand(-1, -1, self.n_obs) + act_indices = indices.unsqueeze(-1).expand(-1, -1, self.n_act) + + # Get base transitions + observations = torch.gather(self.observations, 1, obs_indices).reshape( + 
self.n_env * batch_size, self.n_obs + ) + actions = torch.gather(self.actions, 1, act_indices).reshape( + self.n_env * batch_size, self.n_act + ) + if self.asymmetric_obs: + critic_obs_indices = indices.unsqueeze(-1).expand( + -1, -1, self.n_critic_obs + ) + critic_observations = torch.gather( + self.critic_observations, 1, critic_obs_indices + ).reshape(self.n_env * batch_size, self.n_critic_obs) + + # Create sequential indices for each sample + # This creates a [n_env, batch_size, n_step] tensor of indices + seq_offsets = torch.arange(self.n_steps, device=self.device).view(1, 1, -1) + all_indices = ( + indices.unsqueeze(-1) + seq_offsets + ) % self.buffer_size # [n_env, batch_size, n_step] + + # Gather all rewards and terminal flags + # Using advanced indexing - result shapes: [n_env, batch_size, n_step] + all_rewards = torch.gather( + self.rewards.unsqueeze(-1).expand(-1, -1, self.n_steps), 1, all_indices + ) + all_dones = torch.gather( + self.dones.unsqueeze(-1).expand(-1, -1, self.n_steps), 1, all_indices + ) + all_truncations = torch.gather( + self.truncations.unsqueeze(-1).expand(-1, -1, self.n_steps), + 1, + all_indices, + ) + + # Create masks for rewards after first done + # This creates a cumulative product that zeroes out rewards after the first done + done_masks = torch.cumprod( + 1.0 - all_dones, dim=2 + ) # [n_env, batch_size, n_step] + + # Create discount factors + discounts = torch.pow( + self.gamma, torch.arange(self.n_steps, device=self.device) + ) # [n_steps] + + # Apply masks and discounts to rewards + masked_rewards = all_rewards * done_masks # [n_env, batch_size, n_step] + discounted_rewards = masked_rewards * discounts.view( + 1, 1, -1 + ) # [n_env, batch_size, n_step] + + # Sum rewards along the n_step dimension + n_step_rewards = discounted_rewards.sum(dim=2) # [n_env, batch_size] + + # Find index of first done or truncation or last step for each sequence + first_done = torch.argmax( + (all_dones > 0).float(), dim=2 + ) # [n_env, 
batch_size] + first_trunc = torch.argmax( + (all_truncations > 0).float(), dim=2 + ) # [n_env, batch_size] + + # Handle case where there are no dones or truncations + no_dones = all_dones.sum(dim=2) == 0 + no_truncs = all_truncations.sum(dim=2) == 0 + + # When no dones or truncs, use the last index + first_done = torch.where(no_dones, self.n_steps - 1, first_done) + first_trunc = torch.where(no_truncs, self.n_steps - 1, first_trunc) + + # Take the minimum (first) of done or truncation + final_indices = torch.minimum( + first_done, first_trunc + ) # [n_env, batch_size] + + # Create indices to gather the final next observations + final_next_obs_indices = torch.gather( + all_indices, 2, final_indices.unsqueeze(-1) + ).squeeze( + -1 + ) # [n_env, batch_size] + + # Gather final values + final_next_observations = self.next_observations.gather( + 1, final_next_obs_indices.unsqueeze(-1).expand(-1, -1, self.n_obs) + ) + final_dones = self.dones.gather(1, final_next_obs_indices) + final_truncations = self.truncations.gather(1, final_next_obs_indices) + + if self.asymmetric_obs: + final_next_critic_observations = self.next_critic_observations.gather( + 1, + final_next_obs_indices.unsqueeze(-1).expand( + -1, -1, self.n_critic_obs + ), + ) + # Reshape everything to batch dimension + + rewards = n_step_rewards.reshape(self.n_env * batch_size) + dones = final_dones.reshape(self.n_env * batch_size) + truncations = final_truncations.reshape(self.n_env * batch_size) + next_observations = final_next_observations.reshape( + self.n_env * batch_size, self.n_obs + ) + + if self.asymmetric_obs: + next_critic_observations = final_next_critic_observations.reshape( + self.n_env * batch_size, self.n_critic_obs + ) + + out = TensorDict( + { + "observations": observations, + "actions": actions, + "next": { + "rewards": rewards, + "dones": dones, + "truncations": truncations, + "observations": next_observations, + }, + }, + batch_size=self.n_env * batch_size, + ) + if self.asymmetric_obs: + 
out["critic_observations"] = critic_observations + out["next"]["critic_observations"] = next_critic_observations + return out + + +class EmpiricalNormalization(nn.Module): + """Normalize mean and variance of values based on empirical values.""" + + def __init__(self, shape, device, eps=1e-2, until=None): + """Initialize EmpiricalNormalization module. + + Args: + shape (int or tuple of int): Shape of input values except batch axis. + eps (float): Small value for stability. + until (int or None): If this arg is specified, the link learns input values until the sum of batch sizes + exceeds it. + """ + super().__init__() + self.eps = eps + self.until = until + self.device = device + self.register_buffer("_mean", torch.zeros(shape).unsqueeze(0).to(device)) + self.register_buffer("_var", torch.ones(shape).unsqueeze(0).to(device)) + self.register_buffer("_std", torch.ones(shape).unsqueeze(0).to(device)) + self.register_buffer("count", torch.tensor(0, dtype=torch.long).to(device)) + + @property + def mean(self): + return self._mean.squeeze(0).clone() + + @property + def std(self): + return self._std.squeeze(0).clone() + + def forward(self, x: torch.Tensor, center: bool = True) -> torch.Tensor: + if x.shape[1:] != self._mean.shape[1:]: + raise ValueError( + f"Expected input of shape (*,{self._mean.shape[1:]}), got {x.shape}" + ) + + if self.training: + self.update(x) + if center: + return (x - self._mean) / (self._std + self.eps) + else: + return x / (self._std + self.eps) + + @torch.jit.unused + def update(self, x): + """Learn input values using Welford's online algorithm""" + if self.until is not None and self.count >= self.until: + return + + batch_size = x.shape[0] + batch_mean = torch.mean(x, dim=0, keepdim=True) + + # Update count + new_count = self.count + batch_size + + # Update mean + delta = batch_mean - self._mean + self._mean += (batch_size / new_count) * delta + + # Update variance using Welford's parallel algorithm + if self.count > 0: # Ensure we're not 
dividing by zero + # Compute batch variance + batch_var = torch.mean((x - batch_mean) ** 2, dim=0, keepdim=True) + + # Combine variances using parallel algorithm + delta2 = batch_mean - self._mean + m_a = self._var * self.count + m_b = batch_var * batch_size + M2 = m_a + m_b + (delta2**2) * (self.count * batch_size / new_count) + self._var = M2 / new_count + else: + # For first batch, just use batch variance + self._var = torch.mean((x - self._mean) ** 2, dim=0, keepdim=True) + + self._std = torch.sqrt(self._var) + self.count = new_count + + @torch.jit.unused + def inverse(self, y): + return y * (self._std + self.eps) + self._mean + + +def cpu_state(sd): + # detach & move to host without locking the compute stream + return {k: v.detach().to("cpu", non_blocking=True) for k, v in sd.items()} + + +def save_params( + global_step, + actor, + qnet, + qnet_target, + obs_normalizer, + critic_obs_normalizer, + args, + save_path, +): + """Save model parameters and training configuration to disk.""" + os.makedirs(os.path.dirname(save_path), exist_ok=True) + save_dict = { + "actor_state_dict": cpu_state(actor.state_dict()), + "qnet_state_dict": cpu_state(qnet.state_dict()), + "qnet_target_state_dict": cpu_state(qnet_target.state_dict()), + "obs_normalizer_state": ( + cpu_state(obs_normalizer.state_dict()) + if hasattr(obs_normalizer, "state_dict") + else None + ), + "critic_obs_normalizer_state": ( + cpu_state(critic_obs_normalizer.state_dict()) + if hasattr(critic_obs_normalizer, "state_dict") + else None + ), + "args": vars(args), # Save all arguments + "global_step": global_step, + } + torch.save(save_dict, save_path, _use_new_zipfile_serialization=True) + print(f"Saved parameters and configuration to {save_path}") diff --git a/fast_td3/hyperparams.py b/fast_td3/hyperparams.py new file mode 100644 index 0000000..3f1a8ed --- /dev/null +++ b/fast_td3/hyperparams.py @@ -0,0 +1,454 @@ +import os +from dataclasses import dataclass +import tyro + + +@dataclass +class BaseArgs: + 
# Default hyperparameters -- specifically for HumanoidBench + # See MuJoCoPlaygroundArgs for default hyperparameters for MuJoCo Playground + # See IsaacLabArgs for default hyperparameters for IsaacLab + env_name: str = "h1hand-stand-v0" + """the id of the environment""" + seed: int = 1 + """seed of the experiment""" + torch_deterministic: bool = True + """if toggled, `torch.backends.cudnn.deterministic=False`""" + cuda: bool = True + """if toggled, cuda will be enabled by default""" + device_rank: int = 0 + """the rank of the device""" + exp_name: str = os.path.basename(__file__)[: -len(".py")] + """the name of this experiment""" + project: str = "FastTD3" + """the project name""" + use_wandb: bool = True + """whether to use wandb""" + checkpoint_path: str = None + """the path to the checkpoint file""" + num_envs: int = 128 + """the number of environments to run in parallel""" + num_eval_envs: int = 128 + """the number of evaluation environments to run in parallel (only valid for MuJoCo Playground)""" + total_timesteps: int = 150000 + """total timesteps of the experiments""" + critic_learning_rate: float = 3e-4 + """the learning rate of the critic""" + actor_learning_rate: float = 3e-4 + """the learning rate for the actor""" + buffer_size: int = 1024 * 50 + """the replay memory buffer size""" + num_steps: int = 1 + """the number of steps to use for the multi-step return""" + gamma: float = 0.99 + """the discount factor gamma""" + tau: float = 0.1 + """target smoothing coefficient""" + batch_size: int = 32768 + """the batch size sampled from the replay memory""" + policy_noise: float = 0.001 + """the scale of policy noise""" + std_min: float = 0.001 + """the minimum scale of noise""" + std_max: float = 0.4 + """the maximum scale of noise""" + learning_starts: int = 10 + """timestep to start learning""" + policy_frequency: int = 2 + """the frequency of training policy (delayed)""" + noise_clip: float = 0.5 + """noise clip parameter of the Target
Policy Smoothing Regularization""" + num_updates: int = 2 + """the number of updates to perform per step""" + init_scale: float = 0.01 + """the scale of the initial parameters""" + num_atoms: int = 101 + """the number of atoms""" + v_min: float = -250.0 + """the minimum value of the support""" + v_max: float = 250.0 + """the maximum value of the support""" + critic_hidden_dim: int = 1024 + """the hidden dimension of the critic network""" + actor_hidden_dim: int = 512 + """the hidden dimension of the actor network""" + use_cdq: bool = True + """whether to use Clipped Double Q-learning""" + measure_burnin: int = 3 + """Number of burn-in iterations for speed measure.""" + eval_interval: int = 5000 + """the interval to evaluate the model""" + render_interval: int = 5000 + """the interval to render the model""" + compile: bool = True + """whether to use torch.compile.""" + obs_normalization: bool = True + """whether to enable observation normalization""" + max_grad_norm: float = 0.0 + """the maximum gradient norm""" + amp: bool = True + """whether to use amp""" + amp_dtype: str = "bf16" + """the dtype of the amp""" + disable_bootstrap: bool = False + """Whether to disable bootstrap in the critic learning""" + + use_domain_randomization: bool = False + """(Playground only) whether to use domain randomization""" + use_push_randomization: bool = False + """(Playground only) whether to use push randomization""" + use_tuned_reward: bool = False + """(Playground only) Use tuned reward for G1""" + action_bounds: float = 1.0 + """(IsaacLab only) the bounds of the action space (-action_bounds, action_bounds)""" + + weight_decay: float = 0.1 + """the weight decay of the optimizer""" + save_interval: int = 5000 + """the interval to save the model""" + + +def get_args(): + """ + Parse command-line arguments and return the appropriate Args instance based on env_name. 
+ """ + # First, parse all arguments using the base Args class + base_args = tyro.cli(BaseArgs) + + # Map environment names to their specific Args classes + # For tasks not listed here, the default hyperparameters are used + # See the links below for the available tasks + # - HumanoidBench (https://arxiv.org/abs/2403.10506) + # - IsaacLab (https://isaac-sim.github.io/IsaacLab/main/source/overview/environments.html) + # - MuJoCo Playground (https://arxiv.org/abs/2502.08844) + env_to_args_class = { + # HumanoidBench + # NOTE: This is not the full list of HumanoidBench tasks + "h1hand-reach-v0": H1HandReachArgs, + "h1hand-balance-simple-v0": H1HandBalanceSimpleArgs, + "h1hand-balance-hard-v0": H1HandBalanceHardArgs, + "h1hand-pole-v0": H1HandPoleArgs, + "h1hand-truck-v0": H1HandTruckArgs, + "h1hand-maze-v0": H1HandMazeArgs, + "h1hand-push-v0": H1HandPushArgs, + "h1hand-basketball-v0": H1HandBasketballArgs, + "h1hand-window-v0": H1HandWindowArgs, + "h1hand-package-v0": H1HandPackageArgs, + "h1hand-truck-v0": H1HandTruckArgs, + # MuJoCo Playground + # NOTE: This is not the full list of MuJoCo Playground tasks + "G1JoystickFlatTerrain": G1JoystickFlatTerrainArgs, + "G1JoystickRoughTerrain": G1JoystickRoughTerrainArgs, + "T1JoystickFlatTerrain": T1JoystickFlatTerrainArgs, + "T1JoystickRoughTerrain": T1JoystickRoughTerrainArgs, + "LeapCubeReorient": LeapCubeReorientArgs, + "LeapCubeRotateZAxis": LeapCubeRotateZAxisArgs, + "Go1JoystickFlatTerrain": Go1JoystickFlatTerrainArgs, + "Go1JoystickRoughTerrain": Go1JoystickRoughTerrainArgs, + "Go1Getup": Go1GetupArgs, + "CheetahRun": CheetahRunArgs, # NOTE: Example config for DeepMind Control Suite + # IsaacLab + # NOTE: This is not the full list of IsaacLab tasks + "Isaac-Lift-Cube-Franka-v0": IsaacLiftCubeFrankaArgs, + "Isaac-Open-Drawer-Franka-v0": IsaacOpenDrawerFrankaArgs, + "Isaac-Velocity-Flat-H1-v0": IsaacVelocityFlatH1Args, + "Isaac-Velocity-Flat-G1-v0": IsaacVelocityFlatG1Args, + "Isaac-Velocity-Rough-H1-v0":
IsaacVelocityRoughH1Args, + "Isaac-Velocity-Rough-G1-v0": IsaacVelocityRoughG1Args, + "Isaac-Repose-Cube-Allegro-Direct-v0": IsaacReposeCubeAllegroDirectArgs, + "Isaac-Repose-Cube-Shadow-Direct-v0": IsaacReposeCubeShadowDirectArgs, + } + # If the provided env_name has a specific Args class, use it + if base_args.env_name in env_to_args_class: + specific_args_class = env_to_args_class[base_args.env_name] + # Re-parse with the specific class, maintaining any user overrides + specific_args = tyro.cli(specific_args_class) + return specific_args + + if base_args.env_name.startswith("h1hand-") or base_args.env_name.startswith("h1-"): + # HumanoidBench + specific_args = tyro.cli(HumanoidBenchArgs) + elif base_args.env_name.startswith("Isaac-"): + # IsaacLab + specific_args = tyro.cli(IsaacLabArgs) + else: + # MuJoCo Playground + specific_args = tyro.cli(MuJoCoPlaygroundArgs) + return specific_args + + +@dataclass +class HumanoidBenchArgs(BaseArgs): + # See HumanoidBench (https://arxiv.org/abs/2403.10506) for available task list + total_timesteps: int = 100000 + + +@dataclass +class H1HandReachArgs(HumanoidBenchArgs): + env_name: str = "h1hand-reach-v0" + v_min: float = -2000.0 + v_max: float = 2000.0 + + +@dataclass +class H1HandBalanceSimpleArgs(HumanoidBenchArgs): + env_name: str = "h1hand-balance-simple-v0" + total_timesteps: int = 200000 + + +@dataclass +class H1HandBalanceHardArgs(HumanoidBenchArgs): + env_name: str = "h1hand-balance-hard-v0" + total_timesteps: int = 1000000 + + +@dataclass +class H1HandPoleArgs(HumanoidBenchArgs): + env_name: str = "h1hand-pole-v0" + total_timesteps: int = 150000 + + +@dataclass +class H1HandTruckArgs(HumanoidBenchArgs): + env_name: str = "h1hand-truck-v0" + total_timesteps: int = 500000 + + +@dataclass +class H1HandMazeArgs(HumanoidBenchArgs): + env_name: str = "h1hand-maze-v0" + v_min: float = -1000.0 + v_max: float = 1000.0 + + +@dataclass +class H1HandPushArgs(HumanoidBenchArgs): + env_name: str = "h1hand-push-v0" + v_min: float 
= -1000.0 + v_max: float = 1000.0 + total_timesteps: int = 1000000 + + +@dataclass +class H1HandBasketballArgs(HumanoidBenchArgs): + env_name: str = "h1hand-basketball-v0" + v_min: float = -2000.0 + v_max: float = 2000.0 + total_timesteps: int = 250000 + + +@dataclass +class H1HandWindowArgs(HumanoidBenchArgs): + env_name: str = "h1hand-window-v0" + total_timesteps: int = 250000 + + +@dataclass +class H1HandPackageArgs(HumanoidBenchArgs): + env_name: str = "h1hand-package-v0" + v_min: float = -10000.0 + v_max: float = 10000.0 + + +@dataclass +class H1HandTruckArgs(HumanoidBenchArgs): + env_name: str = "h1hand-truck-v0" + v_min: float = -1000.0 + v_max: float = 1000.0 + + +@dataclass +class MuJoCoPlaygroundArgs(BaseArgs): + # Default hyperparameters for many of Playground environments + v_min: float = -10.0 + v_max: float = 10.0 + buffer_size: int = 1024 * 10 + num_envs: int = 1024 + num_eval_envs: int = 1024 + gamma: float = 0.97 + + +@dataclass +class G1JoystickFlatTerrainArgs(MuJoCoPlaygroundArgs): + env_name: str = "G1JoystickFlatTerrain" + total_timesteps: int = 100000 + + +@dataclass +class G1JoystickRoughTerrainArgs(MuJoCoPlaygroundArgs): + env_name: str = "G1JoystickRoughTerrain" + total_timesteps: int = 100000 + + +@dataclass +class T1JoystickFlatTerrainArgs(MuJoCoPlaygroundArgs): + env_name: str = "T1JoystickFlatTerrain" + total_timesteps: int = 100000 + + +@dataclass +class T1JoystickRoughTerrainArgs(MuJoCoPlaygroundArgs): + env_name: str = "T1JoystickRoughTerrain" + total_timesteps: int = 100000 + + +@dataclass +class T1LowDofJoystickFlatTerrainArgs(MuJoCoPlaygroundArgs): + env_name: str = "T1LowDofJoystickFlatTerrain" + total_timesteps: int = 1000000 + + +@dataclass +class T1LowDofJoystickRoughTerrainArgs(MuJoCoPlaygroundArgs): + env_name: str = "T1LowDofJoystickRoughTerrain" + total_timesteps: int = 1000000 + + +@dataclass +class CheetahRunArgs(MuJoCoPlaygroundArgs): + # NOTE: This config will work for most DMC tasks, though we haven't tested DMC 
extensively. + # Future research can consider using LayerNorm as we find it sometimes works better for DMC tasks. + env_name: str = "CheetahRun" + num_steps: int = 3 + v_min: float = -500.0 + v_max: float = 500.0 + std_min: float = 0.1 + policy_noise: float = 0.1 + + +@dataclass +class Go1JoystickFlatTerrainArgs(MuJoCoPlaygroundArgs): + env_name: str = "Go1JoystickFlatTerrain" + total_timesteps: int = 50000 + std_min: float = 0.2 + std_max: float = 0.8 + policy_noise: float = 0.2 + num_updates: int = 8 + + +@dataclass +class Go1JoystickRoughTerrainArgs(MuJoCoPlaygroundArgs): + env_name: str = "Go1JoystickRoughTerrain" + total_timesteps: int = 50000 + std_min: float = 0.2 + std_max: float = 0.8 + policy_noise: float = 0.2 + num_updates: int = 8 + + +@dataclass +class Go1GetupArgs(MuJoCoPlaygroundArgs): + env_name: str = "Go1Getup" + total_timesteps: int = 50000 + std_min: float = 0.2 + std_max: float = 0.8 + policy_noise: float = 0.2 + num_updates: int = 8 + + +@dataclass +class LeapCubeReorientArgs(MuJoCoPlaygroundArgs): + env_name: str = "LeapCubeReorient" + num_steps: int = 3 + policy_noise: float = 0.2 + v_min: float = -50.0 + v_max: float = 50.0 + use_cdq: bool = False + + +@dataclass +class LeapCubeRotateZAxisArgs(MuJoCoPlaygroundArgs): + env_name: str = "LeapCubeRotateZAxis" + num_steps: int = 1 + policy_noise: float = 0.2 + v_min: float = -10.0 + v_max: float = 10.0 + use_cdq: bool = False + + +@dataclass +class IsaacLabArgs(BaseArgs): + v_min: float = -10.0 + v_max: float = 10.0 + buffer_size: int = 1024 * 10 + num_envs: int = 4096 + num_eval_envs: int = 4096 + action_bounds: float = 1.0 + std_max: float = 0.4 + num_atoms: int = 251 + render_interval: int = 0 # IsaacLab does not support rendering in our codebase + total_timesteps: int = 100000 + + +@dataclass +class IsaacLiftCubeFrankaArgs(IsaacLabArgs): + # Value learning is unstable for the Lift Cube task due to brittle reward shaping + # Therefore, we need to disable bootstrap from 'reset_obs' in IsaacLab +
# Higher UTD works better for manipulation tasks + env_name: str = "Isaac-Lift-Cube-Franka-v0" + num_updates: int = 8 + v_min: float = -50.0 + v_max: float = 50.0 + std_max: float = 0.8 + num_envs: int = 1024 + num_eval_envs: int = 1024 + action_bounds: float = 3.0 + disable_bootstrap: bool = True + total_timesteps: int = 20000 + + +@dataclass +class IsaacOpenDrawerFrankaArgs(IsaacLabArgs): + # Higher UTD works better for manipulation tasks + env_name: str = "Isaac-Open-Drawer-Franka-v0" + v_min: float = -50.0 + v_max: float = 50.0 + num_updates: int = 8 + action_bounds: float = 3.0 + total_timesteps: int = 20000 + + +@dataclass +class IsaacVelocityFlatH1Args(IsaacLabArgs): + env_name: str = "Isaac-Velocity-Flat-H1-v0" + num_steps: int = 3 + total_timesteps: int = 75000 + + +@dataclass +class IsaacVelocityFlatG1Args(IsaacLabArgs): + env_name: str = "Isaac-Velocity-Flat-G1-v0" + num_steps: int = 3 + total_timesteps: int = 50000 + + +@dataclass +class IsaacVelocityRoughH1Args(IsaacLabArgs): + env_name: str = "Isaac-Velocity-Rough-H1-v0" + num_steps: int = 3 + buffer_size: int = 1024 * 5 # To reduce memory usage + total_timesteps: int = 50000 + + +@dataclass +class IsaacVelocityRoughG1Args(IsaacLabArgs): + env_name: str = "Isaac-Velocity-Rough-G1-v0" + num_steps: int = 3 + buffer_size: int = 1024 * 5 # To reduce memory usage + total_timesteps: int = 50000 + + +@dataclass +class IsaacReposeCubeAllegroDirectArgs(IsaacLabArgs): + env_name: str = "Isaac-Repose-Cube-Allegro-Direct-v0" + total_timesteps: int = 100000 + v_min: float = -500.0 + v_max: float = 500.0 + + +@dataclass +class IsaacReposeCubeShadowDirectArgs(IsaacLabArgs): + env_name: str = "Isaac-Repose-Cube-Shadow-Direct-v0" + total_timesteps: int = 100000 + v_min: float = -500.0 + v_max: float = 500.0 diff --git a/fast_td3/train.py b/fast_td3/train.py new file mode 100644 index 0000000..98a52f2 --- /dev/null +++ b/fast_td3/train.py @@ -0,0 +1,602 @@ +import os +import sys + 
+os.environ["TORCHDYNAMO_INLINE_INBUILT_NN_MODULES"] = "1" +os.environ["OMP_NUM_THREADS"] = "1" +if sys.platform != "darwin": + os.environ["MUJOCO_GL"] = "egl" +else: + os.environ["MUJOCO_GL"] = "glfw" +os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false" +os.environ["JAX_DEFAULT_MATMUL_PRECISION"] = "highest" + +import random +import time + +import tqdm +import wandb +import numpy as np + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.amp import autocast, GradScaler + +from tensordict import TensorDict, from_module + +from fast_td3_utils import EmpiricalNormalization, SimpleReplayBuffer, save_params +from hyperparams import get_args +from fast_td3 import Actor, Critic + +torch.set_float32_matmul_precision("high") + +try: + import jax.numpy as jnp +except ImportError: + pass + + +def main(): + args = get_args() + print(args) + run_name = f"{args.env_name}__{args.exp_name}__{args.seed}" + + amp_enabled = args.amp and args.cuda and torch.cuda.is_available() + amp_device_type = ( + "cuda" + if args.cuda and torch.cuda.is_available() + else "mps" if args.cuda and torch.backends.mps.is_available() else "cpu" + ) + amp_dtype = torch.bfloat16 if args.amp_dtype == "bf16" else torch.float16 + + scaler = GradScaler(enabled=amp_enabled and amp_dtype == torch.float16) + + if args.use_wandb: + wandb.init( + project=args.project, + name=run_name, + config=vars(args), + save_code=True, + ) + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.backends.cudnn.deterministic = args.torch_deterministic + + if not args.cuda: + device = torch.device("cpu") + else: + if torch.cuda.is_available(): + device = torch.device(f"cuda:{args.device_rank}") + elif torch.backends.mps.is_available(): + device = torch.device(f"mps:{args.device_rank}") + else: + raise ValueError("No GPU available") + print(f"Using device: {device}") + + if args.env_name.startswith("h1hand-") or 
args.env_name.startswith("h1-"): + from environments.humanoid_bench_env import HumanoidBenchEnv + + env_type = "humanoid_bench" + envs = HumanoidBenchEnv(args.env_name, args.num_envs, device=device) + eval_envs = envs + render_env = HumanoidBenchEnv( + args.env_name, 1, render_mode="rgb_array", device=device + ) + elif args.env_name.startswith("Isaac-"): + from environments.isaaclab_env import IsaacLabEnv + + env_type = "isaaclab" + envs = IsaacLabEnv( + args.env_name, + device.type, + args.num_envs, + args.seed, + action_bounds=args.action_bounds, + ) + eval_envs = envs + render_env = envs + else: + from environments.mujoco_playground_env import make_env + + # TODO: Check if re-using same envs for eval could reduce memory usage + env_type = "mujoco_playground" + envs, eval_envs, render_env = make_env( + args.env_name, + args.seed, + args.num_envs, + args.num_eval_envs, + args.device_rank, + use_tuned_reward=args.use_tuned_reward, + use_domain_randomization=args.use_domain_randomization, + use_push_randomization=args.use_push_randomization, + ) + + n_act = envs.num_actions + n_obs = envs.num_obs if type(envs.num_obs) == int else envs.num_obs[0] + if envs.asymmetric_obs: + n_critic_obs = ( + envs.num_privileged_obs + if type(envs.num_privileged_obs) == int + else envs.num_privileged_obs[0] + ) + else: + n_critic_obs = n_obs + action_low, action_high = -1.0, 1.0 + + if args.obs_normalization: + obs_normalizer = EmpiricalNormalization(shape=n_obs, device=device) + critic_obs_normalizer = EmpiricalNormalization( + shape=n_critic_obs, device=device + ) + else: + obs_normalizer = nn.Identity() + critic_obs_normalizer = nn.Identity() + + actor = Actor( + n_obs=n_obs, + n_act=n_act, + num_envs=args.num_envs, + device=device, + init_scale=args.init_scale, + hidden_dim=args.actor_hidden_dim, + ) + actor_detach = Actor( + n_obs=n_obs, + n_act=n_act, + num_envs=args.num_envs, + device=device, + init_scale=args.init_scale, + hidden_dim=args.actor_hidden_dim, + ) + # Copy params 
to actor_detach without grad + from_module(actor).data.to_module(actor_detach) + policy = actor_detach.explore + + qnet = Critic( + n_obs=n_critic_obs, + n_act=n_act, + num_atoms=args.num_atoms, + v_min=args.v_min, + v_max=args.v_max, + hidden_dim=args.critic_hidden_dim, + device=device, + ) + qnet_target = Critic( + n_obs=n_critic_obs, + n_act=n_act, + num_atoms=args.num_atoms, + v_min=args.v_min, + v_max=args.v_max, + hidden_dim=args.critic_hidden_dim, + device=device, + ) + qnet_target.load_state_dict(qnet.state_dict()) + + q_optimizer = optim.AdamW( + list(qnet.parameters()), + lr=args.critic_learning_rate, + weight_decay=args.weight_decay, + ) + actor_optimizer = optim.AdamW( + list(actor.parameters()), + lr=args.actor_learning_rate, + weight_decay=args.weight_decay, + ) + + rb = SimpleReplayBuffer( + n_env=args.num_envs, + buffer_size=args.buffer_size, + n_obs=n_obs, + n_act=n_act, + n_critic_obs=n_critic_obs, + asymmetric_obs=envs.asymmetric_obs, + n_steps=args.num_steps, + gamma=args.gamma, + device=device, + ) + + policy_noise = args.policy_noise + noise_clip = args.noise_clip + + def evaluate(): + obs_normalizer.eval() + num_eval_envs = eval_envs.num_envs + episode_returns = torch.zeros(num_eval_envs, device=device) + episode_lengths = torch.zeros(num_eval_envs, device=device) + done_masks = torch.zeros(num_eval_envs, dtype=torch.bool, device=device) + + if env_type == "isaaclab": + obs = eval_envs.reset(random_start_init=False) + else: + obs = eval_envs.reset() + + # Run for a fixed number of steps + for _ in range(eval_envs.max_episode_steps): + with torch.no_grad(), autocast( + device_type=amp_device_type, dtype=amp_dtype, enabled=amp_enabled + ): + obs = normalize_obs(obs) + actions = actor(obs) + + next_obs, rewards, dones, _ = eval_envs.step(actions.float()) + episode_returns = torch.where( + ~done_masks, episode_returns + rewards, episode_returns + ) + episode_lengths = torch.where( + ~done_masks, episode_lengths + 1, episode_lengths + ) + 
done_masks = torch.logical_or(done_masks, dones) + if done_masks.all(): + break + obs = next_obs + + obs_normalizer.train() + return episode_returns.mean().item(), episode_lengths.mean().item() + + def render_with_rollout(): + obs_normalizer.eval() + + # Quick rollout for rendering + if env_type == "humanoid_bench": + obs = render_env.reset() + renders = [render_env.render()] + elif env_type == "isaaclab": + raise NotImplementedError( + "We don't support rendering for IsaacLab environments" + ) + else: + obs = render_env.reset() + render_env.state.info["command"] = jnp.array([[1.0, 0.0, 0.0]]) + renders = [render_env.state] + for i in range(render_env.max_episode_steps): + with torch.no_grad(), autocast( + device_type=amp_device_type, dtype=amp_dtype, enabled=amp_enabled + ): + obs = normalize_obs(obs) + actions = actor(obs) + next_obs, _, done, _ = render_env.step(actions.float()) + if env_type == "mujoco_playground": + render_env.state.info["command"] = jnp.array([[1.0, 0.0, 0.0]]) + if i % 2 == 0: + if env_type == "humanoid_bench": + renders.append(render_env.render()) + else: + renders.append(render_env.state) + if done.any(): + break + obs = next_obs + + if env_type == "mujoco_playground": + renders = render_env.render_trajectory(renders) + + obs_normalizer.train() + return renders + + def update_main(data, logs_dict): + with autocast( + device_type=amp_device_type, dtype=amp_dtype, enabled=amp_enabled + ): + observations = data["observations"] + next_observations = data["next"]["observations"] + if envs.asymmetric_obs: + critic_observations = data["critic_observations"] + next_critic_observations = data["next"]["critic_observations"] + else: + critic_observations = observations + next_critic_observations = next_observations + actions = data["actions"] + rewards = data["next"]["rewards"] + dones = data["next"]["dones"].bool() + truncations = data["next"]["truncations"].bool() + if args.disable_bootstrap: + bootstrap = (~dones).float() + else: + bootstrap = 
(truncations | ~dones).float() + + clipped_noise = torch.randn_like(actions) + clipped_noise = clipped_noise.mul(policy_noise).clamp( + -noise_clip, noise_clip + ) + + next_state_actions = (actor(next_observations) + clipped_noise).clamp( + action_low, action_high + ) + + with torch.no_grad(): + qf1_next_target_projected, qf2_next_target_projected = ( + qnet_target.projection( + next_critic_observations, + next_state_actions, + rewards, + bootstrap, + args.gamma, + ) + ) + qf1_next_target_value = qnet_target.get_value(qf1_next_target_projected) + qf2_next_target_value = qnet_target.get_value(qf2_next_target_projected) + if args.use_cdq: + qf_next_target_dist = torch.where( + qf1_next_target_value.unsqueeze(1) + < qf2_next_target_value.unsqueeze(1), + qf1_next_target_projected, + qf2_next_target_projected, + ) + qf1_next_target_dist = qf2_next_target_dist = qf_next_target_dist + else: + qf1_next_target_dist, qf2_next_target_dist = ( + qf1_next_target_projected, + qf2_next_target_projected, + ) + + qf1, qf2 = qnet(critic_observations, actions) + qf1_loss = -torch.sum( + qf1_next_target_dist * F.log_softmax(qf1, dim=1), dim=1 + ).mean() + qf2_loss = -torch.sum( + qf2_next_target_dist * F.log_softmax(qf2, dim=1), dim=1 + ).mean() + qf_loss = qf1_loss + qf2_loss + + q_optimizer.zero_grad(set_to_none=True) + scaler.scale(qf_loss).backward() + scaler.unscale_(q_optimizer) + + critic_grad_norm = torch.nn.utils.clip_grad_norm_( + qnet.parameters(), + max_norm=args.max_grad_norm if args.max_grad_norm > 0 else float("inf"), + ) + scaler.step(q_optimizer) + scaler.update() + + logs_dict["buffer_rewards"] = rewards.mean() + logs_dict["critic_grad_norm"] = critic_grad_norm.detach() + logs_dict["qf_loss"] = qf_loss.detach() + logs_dict["qf_max"] = qf1_next_target_value.max().detach() + logs_dict["qf_min"] = qf1_next_target_value.min().detach() + return logs_dict + + def update_pol(data, logs_dict): + with autocast( + device_type=amp_device_type, dtype=amp_dtype, 
enabled=amp_enabled + ): + critic_observations = ( + data["critic_observations"] + if envs.asymmetric_obs + else data["observations"] + ) + + qf1, qf2 = qnet(critic_observations, actor(data["observations"])) + qf1_value = qnet.get_value(F.softmax(qf1, dim=1)) + qf2_value = qnet.get_value(F.softmax(qf2, dim=1)) + if args.use_cdq: + qf_value = torch.minimum(qf1_value, qf2_value) + else: + qf_value = (qf1_value + qf2_value) / 2.0 + actor_loss = -qf_value.mean() + + actor_optimizer.zero_grad(set_to_none=True) + scaler.scale(actor_loss).backward() + scaler.unscale_(actor_optimizer) + actor_grad_norm = torch.nn.utils.clip_grad_norm_( + actor.parameters(), + max_norm=args.max_grad_norm if args.max_grad_norm > 0 else float("inf"), + ) + scaler.step(actor_optimizer) + scaler.update() + logs_dict["actor_grad_norm"] = actor_grad_norm.detach() + logs_dict["actor_loss"] = actor_loss.detach() + return logs_dict + + if args.compile: + mode = None + update_main = torch.compile(update_main, mode=mode) + update_pol = torch.compile(update_pol, mode=mode) + policy = torch.compile(policy, mode=mode) + normalize_obs = torch.compile(obs_normalizer.forward, mode=mode) + normalize_critic_obs = torch.compile(critic_obs_normalizer.forward, mode=mode) + else: + normalize_obs = obs_normalizer.forward + normalize_critic_obs = critic_obs_normalizer.forward + + if envs.asymmetric_obs: + obs, critic_obs = envs.reset_with_critic_obs() + critic_obs = torch.as_tensor(critic_obs, device=device, dtype=torch.float) + else: + obs = envs.reset() + if args.checkpoint_path: + # Load checkpoint if specified + torch_checkpoint = torch.load( + f"{args.checkpoint_path}", map_location=device, weights_only=False + ) + actor.load_state_dict(torch_checkpoint["actor_state_dict"]) + obs_normalizer.load_state_dict(torch_checkpoint["obs_normalizer_state"]) + critic_obs_normalizer.load_state_dict( + torch_checkpoint["critic_obs_normalizer_state"] + ) + qnet.load_state_dict(torch_checkpoint["qnet_state_dict"]) + 
qnet_target.load_state_dict(torch_checkpoint["qnet_target_state_dict"]) + global_step = torch_checkpoint["global_step"] + else: + global_step = 0 + + dones = None + pbar = tqdm.tqdm(total=args.total_timesteps, initial=global_step) + start_time = None + desc = "" + + while global_step < args.total_timesteps: + logs_dict = TensorDict() + if ( + start_time is None + and global_step >= args.measure_burnin + args.learning_starts + ): + start_time = time.time() + measure_burnin = global_step + + with torch.no_grad(), autocast( + device_type=amp_device_type, dtype=amp_dtype, enabled=amp_enabled + ): + norm_obs = normalize_obs(obs) + actions = policy(obs=norm_obs, dones=dones) + + next_obs, rewards, dones, infos = envs.step(actions.float()) + truncations = infos["time_outs"] + + if envs.asymmetric_obs: + next_critic_obs = infos["observations"]["critic"] + + # Compute 'true' next_obs and next_critic_obs for saving + true_next_obs = torch.where( + dones[:, None] > 0, infos["observations"]["raw"]["obs"], next_obs + ) + if envs.asymmetric_obs: + true_next_critic_obs = torch.where( + dones[:, None] > 0, + infos["observations"]["raw"]["critic_obs"], + next_critic_obs, + ) + transition = TensorDict( + { + "observations": obs, + "actions": torch.as_tensor(actions, device=device, dtype=torch.float), + "next": { + "observations": true_next_obs, + "rewards": torch.as_tensor( + rewards, device=device, dtype=torch.float + ), + "truncations": truncations.long(), + "dones": dones.long(), + }, + }, + batch_size=(envs.num_envs,), + device=device, + ) + if envs.asymmetric_obs: + transition["critic_observations"] = critic_obs + transition["next"]["critic_observations"] = true_next_critic_obs + + obs = next_obs + if envs.asymmetric_obs: + critic_obs = next_critic_obs + + rb.extend(transition) + + batch_size = args.batch_size // args.num_envs + if global_step > args.learning_starts: + for i in range(args.num_updates): + data = rb.sample(batch_size) + data["observations"] = 
normalize_obs(data["observations"]) + data["next"]["observations"] = normalize_obs( + data["next"]["observations"] + ) + if envs.asymmetric_obs: + data["critic_observations"] = normalize_critic_obs( + data["critic_observations"] + ) + data["next"]["critic_observations"] = normalize_critic_obs( + data["next"]["critic_observations"] + ) + logs_dict = update_main(data, logs_dict) + if args.num_updates > 1: + if i % args.policy_frequency == 1: + logs_dict = update_pol(data, logs_dict) + else: + if global_step % args.policy_frequency == 0: + logs_dict = update_pol(data, logs_dict) + + for param, target_param in zip( + qnet.parameters(), qnet_target.parameters() + ): + target_param.data.copy_( + args.tau * param.data + (1 - args.tau) * target_param.data + ) + + if global_step % 100 == 0 and start_time is not None: + speed = (global_step - measure_burnin) / (time.time() - start_time) + pbar.set_description(f"{speed: 4.4f} sps, " + desc) + with torch.no_grad(): + logs = { + "actor_loss": logs_dict["actor_loss"].mean(), + "qf_loss": logs_dict["qf_loss"].mean(), + "qf_max": logs_dict["qf_max"].mean(), + "qf_min": logs_dict["qf_min"].mean(), + "actor_grad_norm": logs_dict["actor_grad_norm"].mean(), + "critic_grad_norm": logs_dict["critic_grad_norm"].mean(), + "buffer_rewards": logs_dict["buffer_rewards"].mean(), + "env_rewards": rewards.mean(), + } + + if args.eval_interval > 0 and global_step % args.eval_interval == 0: + print(f"Evaluating at global step {global_step}") + eval_avg_return, eval_avg_length = evaluate() + if env_type in ["humanoid_bench", "isaaclab"]: + # NOTE: Hacky way of evaluating performance, but just works + obs = envs.reset() + logs["eval_avg_return"] = eval_avg_return + logs["eval_avg_length"] = eval_avg_length + + if ( + args.render_interval > 0 + and global_step % args.render_interval == 0 + ): + renders = render_with_rollout() + if args.use_wandb: + wandb.log( + { + "render_video": wandb.Video( + np.array(renders).transpose( + 0, 3, 1, 2 + ), # 
Convert to (T, C, H, W) format + fps=30, + format="gif", + ) + }, + step=global_step, + ) + if args.use_wandb: + wandb.log( + { + "speed": speed, + "frame": global_step * args.num_envs, + **logs, + }, + step=global_step, + ) + + if ( + args.save_interval > 0 + and global_step > 0 + and global_step % args.save_interval == 0 + ): + print(f"Saving model at global step {global_step}") + save_params( + global_step, + actor, + qnet, + qnet_target, + obs_normalizer, + critic_obs_normalizer, + args, + f"models/{run_name}_{global_step}.pt", + ) + + global_step += 1 + pbar.update(1) + + save_params( + global_step, + actor, + qnet, + qnet_target, + obs_normalizer, + critic_obs_normalizer, + args, + f"models/{run_name}_final.pt", + ) + + +if __name__ == "__main__": + main() diff --git a/fast_td3/training_notebook.ipynb b/fast_td3/training_notebook.ipynb new file mode 100644 index 0000000..1653660 --- /dev/null +++ b/fast_td3/training_notebook.ipynb @@ -0,0 +1,785 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# FastTD3 Training Notebook\n", + "\n", + "Welcome! This notebook lets you execute a series of code blocks that show how FastTD3 works -- the blocks import packages, define arguments, create environments, create the FastTD3 agent, and train the agent.\n", + "\n", + "This notebook also provides the same functionality as `train.py` -- you can use it to train your own agents, upload logs to wandb, render rollouts, and fine-tune pre-trained agents with more environment steps!"
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# NOTE: Set environment variables and import packages\n", + "\n", + "import os\n", + "import sys\n", + "os.environ[\"TORCHDYNAMO_INLINE_INBUILT_NN_MODULES\"] = \"1\"\n", + "os.environ[\"OMP_NUM_THREADS\"] = \"1\"\n", + "if sys.platform != \"darwin\":\n", + " os.environ[\"MUJOCO_GL\"] = \"egl\"\n", + "else:\n", + " os.environ[\"MUJOCO_GL\"] = \"glfw\"\n", + "os.environ[\"XLA_PYTHON_CLIENT_PREALLOCATE\"] = \"false\"\n", + "os.environ[\"JAX_DEFAULT_MATMUL_PRECISION\"] = \"highest\"\n", + "\n", + "import random\n", + "import time\n", + "\n", + "import tqdm\n", + "import wandb\n", + "import numpy as np\n", + "\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "import torch.optim as optim\n", + "from torch.amp import autocast, GradScaler\n", + "from tensordict import TensorDict, from_module\n", + "\n", + "torch.set_float32_matmul_precision(\"high\")\n", + "\n", + "from fast_td3_utils import (\n", + " EmpiricalNormalization,\n", + " SimpleReplayBuffer,\n", + " save_params,\n", + ")\n", + "\n", + "from fast_td3 import Critic, Actor" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "# NOTE: Set checkpoint_path if you want to fine-tune from an existing checkpoint\n", + "# e.g., set checkpoint_path to \"models/h1-walk-v0_notebook_experiment_30000.pt\"\n", + "checkpoint_path = None" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "# NOTE: Customize arguments as needed\n", + "# However, IsaacLab may not work in a notebook setup.\n", + "# We recommend using HumanoidBench or MuJoCo Playground for notebook experiments.\n", + "\n", + "# For quick experiments, let's use a task without dexterous hands\n", + "# But for your research, we recommend using `h1hand` tasks in HumanoidBench!\n", + "from hyperparams import
HumanoidBenchArgs\n", + "\n", + "args = HumanoidBenchArgs(\n", + " env_name=\"h1-walk-v0\",\n", + " total_timesteps=20000,\n", + " render_interval=5000,\n", + " eval_interval=5000,\n", + ")\n", + "run_name = f\"{args.env_name}_notebook_experiment\"" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "# NOTE: GPU-Related Configurations\n", + "\n", + "amp_enabled = args.amp and args.cuda and torch.cuda.is_available()\n", + "amp_device_type = (\n", + " \"cuda\"\n", + " if args.cuda and torch.cuda.is_available()\n", + " else \"mps\" if args.cuda and torch.backends.mps.is_available() else \"cpu\"\n", + ")\n", + "amp_dtype = torch.bfloat16 if args.amp_dtype == \"bf16\" else torch.float16\n", + "\n", + "scaler = GradScaler(enabled=amp_enabled and amp_dtype == torch.float16)\n", + "\n", + "if not args.cuda:\n", + " device = torch.device(\"cpu\")\n", + "else:\n", + " if torch.cuda.is_available():\n", + " device = torch.device(f\"cuda:{args.device_rank}\")\n", + " elif torch.backends.mps.is_available():\n", + " device = torch.device(f\"mps:{args.device_rank}\")\n", + " else:\n", + " raise ValueError(\"No GPU available\")\n", + "print(f\"Using device: {device}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# NOTE: Define Wandb if needed\n", + "\n", + "# Set use_wandb to True if you want to use Wandb\n", + "use_wandb = True\n", + "\n", + "if use_wandb:\n", + " wandb.init(\n", + " project=\"FastTD3\",\n", + " name=run_name,\n", + " config=vars(args),\n", + " save_code=True,\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# NOTE: Initialize Environment and Related Variables\n", + "\n", + "if args.env_name.startswith(\"h1hand-\") or args.env_name.startswith(\"h1-\"):\n", + " from environments.humanoid_bench_env import HumanoidBenchEnv\n", + "\n", + " env_type = 
\"humanoid_bench\"\n", + " envs = HumanoidBenchEnv(args.env_name, args.num_envs, device=device)\n", + " eval_envs = envs\n", + " render_env = HumanoidBenchEnv(args.env_name, 1, render_mode=\"rgb_array\", device=device)\n", + "elif args.env_name.startswith(\"Isaac-\"):\n", + " from environments.isaaclab_env import IsaacLabEnv\n", + "\n", + " env_type = \"isaaclab\"\n", + " envs = IsaacLabEnv(\n", + " args.env_name,\n", + " device.type,\n", + " args.num_envs,\n", + " args.seed,\n", + " action_bounds=args.action_bounds,\n", + " )\n", + " eval_envs = envs\n", + " render_envs = envs\n", + "else:\n", + " from environments.mujoco_playground_env import make_env\n", + " import jax.numpy as jnp\n", + "\n", + " env_type = \"mujoco_playground\"\n", + " envs, eval_envs, render_env = make_env(\n", + " args.env_name,\n", + " args.seed,\n", + " args.num_envs,\n", + " args.num_eval_envs,\n", + " args.device_rank,\n", + " use_tuned_reward=args.use_tuned_reward,\n", + " use_domain_randomization=args.use_domain_randomization,\n", + " )\n", + "\n", + "n_act = envs.num_actions\n", + "n_obs = envs.num_obs if type(envs.num_obs) == int else envs.num_obs[0]\n", + "if envs.asymmetric_obs:\n", + " n_critic_obs = (\n", + " envs.num_privileged_obs\n", + " if type(envs.num_privileged_obs) == int\n", + " else envs.num_privileged_obs[0]\n", + " )\n", + "else:\n", + " n_critic_obs = n_obs\n", + "action_low, action_high = -1.0, 1.0" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "# NOTE: Initialize Normalizer, Actor, and Critic\n", + "\n", + "if args.obs_normalization:\n", + " obs_normalizer = EmpiricalNormalization(shape=n_obs, device=device)\n", + " critic_obs_normalizer = EmpiricalNormalization(shape=n_critic_obs, device=device)\n", + "else:\n", + " obs_normalizer = nn.Identity()\n", + " critic_obs_normalizer = nn.Identity()\n", + "\n", + "normalize_obs = obs_normalizer.forward\n", + "normalize_critic_obs = 
critic_obs_normalizer.forward\n", + "\n", + "# Actor setup\n", + "actor = Actor(\n", + " n_obs=n_obs,\n", + " n_act=n_act,\n", + " num_envs=args.num_envs,\n", + " device=device,\n", + " init_scale=args.init_scale,\n", + " hidden_dim=args.actor_hidden_dim,\n", + ")\n", + "actor_detach = Actor(\n", + " n_obs=n_obs,\n", + " n_act=n_act,\n", + " num_envs=args.num_envs,\n", + " device=device,\n", + " init_scale=args.init_scale,\n", + " hidden_dim=args.actor_hidden_dim,\n", + ")\n", + "# Copy params to actor_detach without grad\n", + "from_module(actor).data.to_module(actor_detach)\n", + "policy = actor_detach.explore\n", + "\n", + "qnet = Critic(\n", + " n_obs=n_critic_obs,\n", + " n_act=n_act,\n", + " num_atoms=args.num_atoms,\n", + " v_min=args.v_min,\n", + " v_max=args.v_max,\n", + " hidden_dim=args.critic_hidden_dim,\n", + " device=device,\n", + ")\n", + "qnet_target = Critic(\n", + " n_obs=n_critic_obs,\n", + " n_act=n_act,\n", + " num_atoms=args.num_atoms,\n", + " v_min=args.v_min,\n", + " v_max=args.v_max,\n", + " hidden_dim=args.critic_hidden_dim,\n", + " device=device,\n", + ")\n", + "qnet_target.load_state_dict(qnet.state_dict())\n", + "\n", + "q_optimizer = optim.AdamW(\n", + " list(qnet.parameters()),\n", + " lr=args.critic_learning_rate,\n", + " weight_decay=args.weight_decay,\n", + ")\n", + "actor_optimizer = optim.AdamW(\n", + " list(actor.parameters()),\n", + " lr=args.actor_learning_rate,\n", + " weight_decay=args.weight_decay,\n", + ")\n", + "\n", + "rb = SimpleReplayBuffer(\n", + " n_env=args.num_envs,\n", + " buffer_size=args.buffer_size,\n", + " n_obs=n_obs,\n", + " n_act=n_act,\n", + " n_critic_obs=n_critic_obs,\n", + " asymmetric_obs=envs.asymmetric_obs,\n", + " n_steps=args.num_steps,\n", + " gamma=args.gamma,\n", + " device=device,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "# NOTE: Define Evaluation & Rendering Functions\n", + "\n", + "\n", + "def evaluate():\n", + " 
obs_normalizer.eval()\n", + " num_eval_envs = eval_envs.num_envs\n", + " episode_returns = torch.zeros(num_eval_envs, device=device)\n", + " episode_lengths = torch.zeros(num_eval_envs, device=device)\n", + " done_masks = torch.zeros(num_eval_envs, dtype=torch.bool, device=device)\n", + "\n", + " if env_type == \"isaaclab\":\n", + " obs = eval_envs.reset(random_start_init=False)\n", + " else:\n", + " obs = eval_envs.reset()\n", + "\n", + " # Run for a fixed number of steps\n", + " for _ in range(eval_envs.max_episode_steps):\n", + " with torch.no_grad(), autocast(\n", + " device_type=amp_device_type, dtype=amp_dtype, enabled=amp_enabled\n", + " ):\n", + " obs = normalize_obs(obs)\n", + " actions = actor(obs)\n", + "\n", + " next_obs, rewards, dones, _ = eval_envs.step(actions.float())\n", + " episode_returns = torch.where(\n", + " ~done_masks, episode_returns + rewards, episode_returns\n", + " )\n", + " episode_lengths = torch.where(~done_masks, episode_lengths + 1, episode_lengths)\n", + " done_masks = torch.logical_or(done_masks, dones)\n", + " if done_masks.all():\n", + " break\n", + " obs = next_obs\n", + "\n", + " obs_normalizer.train()\n", + " return episode_returns.mean().item(), episode_lengths.mean().item()\n", + "\n", + "\n", + "def render_with_rollout():\n", + " obs_normalizer.eval()\n", + "\n", + " # Quick rollout for rendering\n", + " if env_type == \"humanoid_bench\":\n", + " obs = render_env.reset()\n", + " renders = [render_env.render()]\n", + " elif env_type == \"isaaclab\":\n", + " raise NotImplementedError(\n", + " \"We don't support rendering for IsaacLab environments\"\n", + " )\n", + " else:\n", + " obs = render_env.reset()\n", + " render_env.state.info[\"command\"] = jnp.array([[1.0, 0.0, 0.0]])\n", + " renders = [render_env.state]\n", + " for i in range(render_env.max_episode_steps):\n", + " with torch.no_grad(), autocast(\n", + " device_type=amp_device_type, dtype=amp_dtype, enabled=amp_enabled\n", + " ):\n", + " obs = 
normalize_obs(obs)\n", + " actions = actor(obs)\n", + " next_obs, _, done, _ = render_env.step(actions.float())\n", + " if env_type == \"mujoco_playground\":\n", + " render_env.state.info[\"command\"] = jnp.array([[1.0, 0.0, 0.0]])\n", + " if i % 2 == 0:\n", + " if env_type == \"humanoid_bench\":\n", + " renders.append(render_env.render())\n", + " else:\n", + " renders.append(render_env.state)\n", + " if done.any():\n", + " break\n", + " obs = next_obs\n", + "\n", + " if env_type == \"mujoco_playground\":\n", + " renders = render_env.render_trajectory(renders)\n", + "\n", + " obs_normalizer.train()\n", + " return renders" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "# NOTE: Define Update Functions\n", + "\n", + "policy_noise = args.policy_noise\n", + "noise_clip = args.noise_clip\n", + "\n", + "\n", + "def update_main(data, logs_dict):\n", + " with autocast(device_type=amp_device_type, dtype=amp_dtype, enabled=amp_enabled):\n", + " observations = data[\"observations\"]\n", + " next_observations = data[\"next\"][\"observations\"]\n", + " if envs.asymmetric_obs:\n", + " critic_observations = data[\"critic_observations\"]\n", + " next_critic_observations = data[\"next\"][\"critic_observations\"]\n", + " else:\n", + " critic_observations = observations\n", + " next_critic_observations = next_observations\n", + " actions = data[\"actions\"]\n", + " rewards = data[\"next\"][\"rewards\"]\n", + " dones = data[\"next\"][\"dones\"].bool()\n", + " truncations = data[\"next\"][\"truncations\"].bool()\n", + " if args.disable_bootstrap:\n", + " bootstrap = (~dones).float()\n", + " else:\n", + " bootstrap = (truncations | ~dones).float()\n", + "\n", + " clipped_noise = torch.randn_like(actions)\n", + " clipped_noise = clipped_noise.mul(policy_noise).clamp(-noise_clip, noise_clip)\n", + "\n", + " next_state_actions = (actor(next_observations) + clipped_noise).clamp(\n", + " action_low, action_high\n", + " )\n", + 
"\n", + " with torch.no_grad():\n", + " qf1_next_target_projected, qf2_next_target_projected = (\n", + " qnet_target.projection(\n", + " next_critic_observations,\n", + " next_state_actions,\n", + " rewards,\n", + " bootstrap,\n", + " args.gamma,\n", + " )\n", + " )\n", + " qf1_next_target_value = qnet_target.get_value(qf1_next_target_projected)\n", + " qf2_next_target_value = qnet_target.get_value(qf2_next_target_projected)\n", + " if args.use_cdq:\n", + " qf_next_target_dist = torch.where(\n", + " qf1_next_target_value.unsqueeze(1)\n", + " < qf2_next_target_value.unsqueeze(1),\n", + " qf1_next_target_projected,\n", + " qf2_next_target_projected,\n", + " )\n", + " qf1_next_target_dist = qf2_next_target_dist = qf_next_target_dist\n", + " else:\n", + " qf1_next_target_dist, qf2_next_target_dist = (\n", + " qf1_next_target_projected,\n", + " qf2_next_target_projected,\n", + " )\n", + "\n", + " qf1, qf2 = qnet(critic_observations, actions)\n", + " qf1_loss = -torch.sum(\n", + " qf1_next_target_dist * F.log_softmax(qf1, dim=1), dim=1\n", + " ).mean()\n", + " qf2_loss = -torch.sum(\n", + " qf2_next_target_dist * F.log_softmax(qf2, dim=1), dim=1\n", + " ).mean()\n", + " qf_loss = qf1_loss + qf2_loss\n", + "\n", + " q_optimizer.zero_grad(set_to_none=True)\n", + " scaler.scale(qf_loss).backward()\n", + " scaler.unscale_(q_optimizer)\n", + "\n", + " critic_grad_norm = torch.nn.utils.clip_grad_norm_(\n", + " qnet.parameters(),\n", + " max_norm=args.max_grad_norm if args.max_grad_norm > 0 else float(\"inf\"),\n", + " )\n", + " scaler.step(q_optimizer)\n", + " scaler.update()\n", + "\n", + " logs_dict[\"buffer_rewards\"] = rewards.mean()\n", + " logs_dict[\"critic_grad_norm\"] = critic_grad_norm.detach()\n", + " logs_dict[\"qf_loss\"] = qf_loss.detach()\n", + " logs_dict[\"qf_max\"] = qf1_next_target_value.max().detach()\n", + " logs_dict[\"qf_min\"] = qf1_next_target_value.min().detach()\n", + " return logs_dict\n", + "\n", + "\n", + "def update_pol(data, logs_dict):\n", + " 
with autocast(device_type=amp_device_type, dtype=amp_dtype, enabled=amp_enabled):\n", + " critic_observations = (\n", + " data[\"critic_observations\"] if envs.asymmetric_obs else data[\"observations\"]\n", + " )\n", + "\n", + " qf1, qf2 = qnet(critic_observations, actor(data[\"observations\"]))\n", + " qf1_value = qnet.get_value(F.softmax(qf1, dim=1))\n", + " qf2_value = qnet.get_value(F.softmax(qf2, dim=1))\n", + " if args.use_cdq:\n", + " qf_value = torch.minimum(qf1_value, qf2_value)\n", + " else:\n", + " qf_value = (qf1_value + qf2_value) / 2.0\n", + " actor_loss = -qf_value.mean()\n", + "\n", + " actor_optimizer.zero_grad(set_to_none=True)\n", + " scaler.scale(actor_loss).backward()\n", + " scaler.unscale_(actor_optimizer)\n", + " actor_grad_norm = torch.nn.utils.clip_grad_norm_(\n", + " actor.parameters(),\n", + " max_norm=args.max_grad_norm if args.max_grad_norm > 0 else float(\"inf\"),\n", + " )\n", + " scaler.step(actor_optimizer)\n", + " scaler.update()\n", + " logs_dict[\"actor_grad_norm\"] = actor_grad_norm.detach()\n", + " logs_dict[\"actor_loss\"] = actor_loss.detach()\n", + " return logs_dict" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "# NOTE: Compile Functions if Needed\n", + "\n", + "if args.compile:\n", + " mode = None\n", + " update_main = torch.compile(update_main, mode=mode)\n", + " update_pol = torch.compile(update_pol, mode=mode)\n", + " policy = torch.compile(policy, mode=mode)\n", + " normalize_obs = torch.compile(normalize_obs, mode=mode)\n", + " normalize_critic_obs = torch.compile(normalize_critic_obs, mode=mode)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "# NOTE: Load Checkpoint if Needed\n", + "if checkpoint_path is not None:\n", + " torch_checkpoint = torch.load(\n", + " f\"{checkpoint_path}\", map_location=device, weights_only=False\n", + " )\n", + "\n", + " 
actor.load_state_dict(torch_checkpoint[\"actor_state_dict\"])\n", + " obs_normalizer.load_state_dict(torch_checkpoint[\"obs_normalizer_state\"])\n", + " critic_obs_normalizer.load_state_dict(\n", + " torch_checkpoint[\"critic_obs_normalizer_state\"]\n", + " )\n", + " qnet.load_state_dict(torch_checkpoint[\"qnet_state_dict\"])\n", + " qnet_target.load_state_dict(torch_checkpoint[\"qnet_target_state_dict\"])\n", + " global_step = torch_checkpoint[\"global_step\"]\n", + "else:\n", + " global_step = 0" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "# NOTE: Utility functions for displaying videos in notebook\n", + "\n", + "from IPython.display import display, HTML\n", + "import base64\n", + "import imageio\n", + "import tempfile\n", + "import os\n", + "\n", + "\n", + "def frames_to_video_html(frames, fps=30):\n", + " \"\"\"\n", + " Convert a list of numpy arrays to an HTML5 video element.\n", + "\n", + " Args:\n", + " frames (list): List of numpy arrays representing video frames\n", + " fps (int): Frames per second for the video\n", + "\n", + " Returns:\n", + " HTML object containing the video element\n", + " \"\"\"\n", + " # Create a temporary file to store the video\n", + " with tempfile.NamedTemporaryFile(suffix=\".mp4\", delete=False) as temp_file:\n", + " temp_filename = temp_file.name\n", + "\n", + " # Save frames as video\n", + " imageio.mimsave(temp_filename, frames, fps=fps)\n", + "\n", + " # Read the video file and encode it to base64\n", + " with open(temp_filename, \"rb\") as f:\n", + " video_data = f.read()\n", + " video_b64 = base64.b64encode(video_data).decode(\"utf-8\")\n", + "\n", + " # Create HTML video element\n", + " video_html = f\"\"\"\n", + " \n", + " \"\"\"\n", + "\n", + " # Clean up the temporary file\n", + " os.unlink(temp_filename)\n", + "\n", + " return HTML(video_html)\n", + "\n", + "\n", + "def update_video_display(frames, fps=30):\n", + " \"\"\"\n", + " Display video frames 
as an embedded HTML5 video element.\n", + "\n", + " Args:\n", + " frames (list): List of numpy arrays representing video frames\n", + " fps (int): Frames per second for the video\n", + " \"\"\"\n", + " video_html = frames_to_video_html(frames, fps=fps)\n", + " display(video_html)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# NOTE: Main Training Loop\n", + "\n", + "if envs.asymmetric_obs:\n", + " obs, critic_obs = envs.reset_with_critic_obs()\n", + " critic_obs = torch.as_tensor(critic_obs, device=device, dtype=torch.float)\n", + "else:\n", + " obs = envs.reset()\n", + "pbar = tqdm.tqdm(total=args.total_timesteps, initial=global_step)\n", + "\n", + "dones = None\n", + "while global_step < args.total_timesteps:\n", + " logs_dict = TensorDict()\n", + " with torch.no_grad(), autocast(\n", + " device_type=amp_device_type, dtype=amp_dtype, enabled=amp_enabled\n", + " ):\n", + " norm_obs = normalize_obs(obs)\n", + " actions = policy(obs=norm_obs, dones=dones)\n", + "\n", + " next_obs, rewards, dones, infos = envs.step(actions.float())\n", + " truncations = infos[\"time_outs\"]\n", + "\n", + " if envs.asymmetric_obs:\n", + " next_critic_obs = infos[\"observations\"][\"critic\"]\n", + "\n", + " # Compute 'true' next_obs and next_critic_obs for saving\n", + " true_next_obs = torch.where(\n", + " dones[:, None] > 0, infos[\"observations\"][\"raw\"][\"obs\"], next_obs\n", + " )\n", + " if envs.asymmetric_obs:\n", + " true_next_critic_obs = torch.where(\n", + " dones[:, None] > 0,\n", + " infos[\"observations\"][\"raw\"][\"critic_obs\"],\n", + " next_critic_obs,\n", + " )\n", + " transition = TensorDict(\n", + " {\n", + " \"observations\": obs,\n", + " \"actions\": torch.as_tensor(actions, device=device, dtype=torch.float),\n", + " \"next\": {\n", + " \"observations\": true_next_obs,\n", + " \"rewards\": torch.as_tensor(rewards, device=device, dtype=torch.float),\n", + " \"truncations\": 
truncations.long(),\n", + " \"dones\": dones.long(),\n", + " },\n", + " },\n", + " batch_size=(envs.num_envs,),\n", + " device=device,\n", + " )\n", + " if envs.asymmetric_obs:\n", + " transition[\"critic_observations\"] = critic_obs\n", + " transition[\"next\"][\"critic_observations\"] = true_next_critic_obs\n", + "\n", + " obs = next_obs\n", + " if envs.asymmetric_obs:\n", + " critic_obs = next_critic_obs\n", + "\n", + " rb.extend(transition)\n", + "\n", + " batch_size = args.batch_size // args.num_envs\n", + " if global_step > args.learning_starts:\n", + " for i in range(args.num_updates):\n", + " data = rb.sample(batch_size)\n", + " data[\"observations\"] = normalize_obs(data[\"observations\"])\n", + " data[\"next\"][\"observations\"] = normalize_obs(data[\"next\"][\"observations\"])\n", + " if envs.asymmetric_obs:\n", + " data[\"critic_observations\"] = normalize_critic_obs(\n", + " data[\"critic_observations\"]\n", + " )\n", + " data[\"next\"][\"critic_observations\"] = normalize_critic_obs(\n", + " data[\"next\"][\"critic_observations\"]\n", + " )\n", + " logs_dict = update_main(data, logs_dict)\n", + " if args.num_updates > 1:\n", + " if i % args.policy_frequency == 1:\n", + " logs_dict = update_pol(data, logs_dict)\n", + " else:\n", + " if global_step % args.policy_frequency == 0:\n", + " logs_dict = update_pol(data, logs_dict)\n", + "\n", + " for param, target_param in zip(qnet.parameters(), qnet_target.parameters()):\n", + " target_param.data.copy_(\n", + " args.tau * param.data + (1 - args.tau) * target_param.data\n", + " )\n", + "\n", + " if global_step > 0 and global_step % 100 == 0:\n", + " with torch.no_grad():\n", + " logs = {\n", + " \"actor_loss\": logs_dict[\"actor_loss\"].mean(),\n", + " \"qf_loss\": logs_dict[\"qf_loss\"].mean(),\n", + " \"qf_max\": logs_dict[\"qf_max\"].mean(),\n", + " \"qf_min\": logs_dict[\"qf_min\"].mean(),\n", + " \"actor_grad_norm\": logs_dict[\"actor_grad_norm\"].mean(),\n", + " \"critic_grad_norm\": 
logs_dict[\"critic_grad_norm\"].mean(),\n", + " \"buffer_rewards\": logs_dict[\"buffer_rewards\"].mean(),\n", + " \"env_rewards\": rewards.mean(),\n", + " }\n", + "\n", + " if args.eval_interval > 0 and global_step % args.eval_interval == 0:\n", + " eval_avg_return, eval_avg_length = evaluate()\n", + " if env_type in [\"humanoid_bench\", \"isaaclab\"]:\n", + " # NOTE: Hacky way of evaluating performance, but just works\n", + " obs = envs.reset()\n", + " logs[\"eval_avg_return\"] = eval_avg_return\n", + " logs[\"eval_avg_length\"] = eval_avg_length\n", + "\n", + " if args.render_interval > 0 and global_step % args.render_interval == 0:\n", + " renders = render_with_rollout()\n", + " print_logs = {\n", + " k: v.item() if isinstance(v, torch.Tensor) else v\n", + " for k, v in logs.items()\n", + " }\n", + " for k, v in print_logs.items():\n", + " print(f\"{k}: {v:.4f}\")\n", + " update_video_display(renders, fps=30)\n", + " if use_wandb:\n", + " wandb.log(\n", + " {\n", + " \"render_video\": wandb.Video(\n", + " np.array(renders).transpose(\n", + " 0, 3, 1, 2\n", + " ), # Convert to (T, C, H, W) format\n", + " fps=30,\n", + " format=\"gif\",\n", + " )\n", + " },\n", + " step=global_step,\n", + " )\n", + " if use_wandb:\n", + " wandb.log(\n", + " {\n", + " \"frame\": global_step * args.num_envs,\n", + " **logs,\n", + " },\n", + " step=global_step,\n", + " )\n", + "\n", + " if (\n", + " args.save_interval > 0\n", + " and global_step > 0\n", + " and global_step % args.save_interval == 0\n", + " ):\n", + " save_params(\n", + " global_step,\n", + " actor,\n", + " qnet,\n", + " qnet_target,\n", + " obs_normalizer,\n", + " critic_obs_normalizer,\n", + " args,\n", + " f\"models/{run_name}_{global_step}.pt\",\n", + " )\n", + "\n", + " global_step += 1\n", + " pbar.update(1)\n", + "\n", + "save_params(\n", + " global_step,\n", + " actor,\n", + " qnet,\n", + " qnet_target,\n", + " obs_normalizer,\n", + " critic_obs_normalizer,\n", + " args,\n", + " 
f\"models/{run_name}_final.pt\",\n", + ")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "fasttd3_hb", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.17" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/requirements/requirements.txt b/requirements/requirements.txt new file mode 100644 index 0000000..9366a6f --- /dev/null +++ b/requirements/requirements.txt @@ -0,0 +1,18 @@ +gymnasium<1.0.0 +jax-jumpy==1.0.0 ; python_version >= "3.8" and python_version < "3.11" +matplotlib +moviepy +numpy<2.0 +pandas +protobuf +pygame +stable-baselines3 +tqdm +wandb +torchrl==0.7.2 +tensordict==0.7.2 +tyro +loguru +torch==2.6.0 --index-url https://download.pytorch.org/whl/cu124 +torchvision==0.21.0 --index-url https://download.pytorch.org/whl/cu124 +torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cu124 diff --git a/requirements/requirements_playground.txt b/requirements/requirements_playground.txt new file mode 100644 index 0000000..b851ace --- /dev/null +++ b/requirements/requirements_playground.txt @@ -0,0 +1,34 @@ +gymnasium<1.0.0 +jax-jumpy==1.0.0 ; python_version >= "3.8" and python_version < "3.11" +matplotlib +moviepy +numpy<2.0 +pandas +protobuf +pygame +stable-baselines3 +tqdm +wandb +torchrl==0.7.2 +tensordict==0.7.2 +tyro +loguru +git+https://github.com/younggyoseo/mujoco_playground.git +torch==2.6.0 --index-url https://download.pytorch.org/whl/cu124 +torchvision==0.21.0 --index-url https://download.pytorch.org/whl/cu124 +torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cu124 +jax[cuda12]==0.4.35 +nvidia-cublas-cu12==12.4.5.8 +nvidia-cuda-cupti-cu12==12.4.127 +nvidia-cuda-nvcc-cu12==12.8.93 +nvidia-cuda-nvrtc-cu12==12.4.127 +nvidia-cuda-runtime-cu12==12.4.127 
+nvidia-cudnn-cu12==9.1.0.70 +nvidia-cufft-cu12==11.2.1.3 +nvidia-curand-cu12==10.3.5.147 +nvidia-cusolver-cu12==11.6.1.9 +nvidia-cusparse-cu12==12.3.1.170 +nvidia-cusparselt-cu12==0.6.2 +nvidia-nccl-cu12==2.21.5 +nvidia-nvjitlink-cu12==12.4.127 +nvidia-nvtx-cu12==12.4.127 diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..6402a1e --- /dev/null +++ b/setup.py @@ -0,0 +1,11 @@ +from setuptools import setup, find_packages + +setup( + name="fast_td3", + version="0.1.0", + description="FastTD3 implementation", + author="", + author_email="", + url="", + packages=find_packages(), +) diff --git a/sim2real.md b/sim2real.md new file mode 100644 index 0000000..3e326d0 --- /dev/null +++ b/sim2real.md @@ -0,0 +1,101 @@ +# Guide for Sim2Real Training & Deployment + +This guide explains how to run sim-to-real experiments using FastTD3 and BoosterGym. + +**⚠️ Warning:** Deploying RL policies to real hardware can be very dangerous. Please make sure that you understand everything, check that your policies work well in simulation, set every robot configuration value correctly (e.g., damping, stiffness, torque limits, etc.), and follow proper safety protocols. **Simply copy-pasting commands in this README is not safe**. + +## ⚙️ Prerequisites + +Install dependencies for Playground experiments (see `README.md`). + +Then, install the `fast_td3` package with `pip install -e .` so you can import its classes in BoosterGym (see `fast_td3_deploy.py`). + +**⚠️ Note:** Our sim-to-real experiments depend on our customized MuJoCo Playground, which supports `T1LowDimJoystick` tasks for 12-DOF T1 control instead of the 23-DOF T1 control in `T1Joystick` tasks. 
+ +## 🚀 Training in simulation + +Users can train deployable policies for Booster T1 with FastTD3 using the script below: + +```bash +python fast_td3/train.py --env_name T1LowDimJoystickRoughTerrain --exp_name FastTD3 --use_domain_randomization --use_push_randomization --total_timesteps 1000000 --render_interval 0 --seed 2 +``` + +**⚠️ Note:** There is no 'guaranteed' number of training steps that can ensure safe real-world deployment. Usually, the gait becomes more stable with longer training. Please check the quality of gaits via sim-to-sim transfer, and fine-tune the policy to fix the issues. Use the checkpoints in the `models` directory for sim-to-sim or sim-to-real transfer. + +**⚠️ Note:** We set `render_interval` to 0 to avoid dumping a lot of videos into wandb. Make sure to set it to a non-zero value if you want to render videos during training. + + + +### (Optional) 2-Stage Training + +For faster convergence, users can consider introducing a curriculum to the training -- so that the robot first learns to walk on flat terrain without push perturbations. For this, train policies with the script below: + +```bash +# Shell variable assignments must not have spaces around '=' +STAGE1_STEPS=100000 +STAGE2_STEPS=300000 # Stage 2 resumes from the Stage 1 step count, so effective steps: 300000 - 100000 = 200000 +SEED=2 +# Checkpoints are written to the `models` directory +CHECKPOINT_PATH="models/T1LowDimJoystickFlatTerrain__FastTD3-Stage1__${SEED}_final.pt" + +conda activate fasttd3_playground + +# Stage 1 training +python fast_td3/train.py --env_name T1LowDimJoystickFlatTerrain --exp_name FastTD3-Stage1 --use_domain_randomization --no_use_push_randomization --total_timesteps ${STAGE1_STEPS} --render_interval 0 --seed ${SEED} + +# Stage 2 training +python fast_td3/train.py --env_name T1LowDimJoystickRoughTerrain --exp_name FastTD3-Stage2 --use_domain_randomization --use_push_randomization --total_timesteps ${STAGE2_STEPS} --render_interval 0 --checkpoint_path ${CHECKPOINT_PATH} --seed ${SEED} +``` + +Again, 100K and 200K steps do not guarantee safe real-world deployment. 
Please check the quality of gaits via sim-to-sim transfer, and fine-tune the policy to fix the issues. Use the final checkpoint (`models/T1LowDimJoystickRoughTerrain__FastTD3-Stage2__${SEED}_final.pt`) for sim-to-sim or sim-to-real transfer. + +## 🛝 Deployment with BoosterGym + +We use a customized version of [BoosterGym](https://github.com/BoosterRobotics/booster_gym) for deployment with FastTD3. + +First, clone our fork of BoosterGym. + +```bash +git clone https://github.com/carlosferrazza/booster_gym.git +``` + +Then, follow the [guide](https://github.com/carlosferrazza/booster_gym) to install dependencies for BoosterGym. + +### Sim-to-Sim Transfer + +You can check whether the trained policy transfers to the non-MJX version of MuJoCo. +Use the following commands on a machine that supports rendering to test sim-to-sim transfer: + +```bash +cd /booster_gym + +# Activate your BoosterGym virtual environment + +# Launch MuJoCo simulation +python play_mujoco.py --task=T1 --checkpoint= +mjpython play_mujoco.py --task=T1 --checkpoint= # for Mac +``` + + + +### Sim-to-Real Transfer + +First, prepare a JIT-scripted checkpoint: + +```python +# Python snippet for JIT-scripting checkpoints +import torch +from fast_td3 import load_policy +policy = load_policy() +scripted_policy = torch.jit.script(policy) +scripted_policy.save() +``` + +Then, deploy this JIT-scripted checkpoint by following the guide on [Booster T1 Deployment](https://github.com/carlosferrazza/booster_gym/tree/main/deploy). + + +**⚠️ Warning:** Please double-check that every value in the robot configuration (`booster_gym/deploy/configs/T1.yaml`) is set correctly! If values for position control such as `damping` or `stiffness` differ from the ones used in training, your robot may behave dangerously. + +**⚠️ Warning:** You may want to use a different configuration (e.g., `damping`, `stiffness`, etc.) for your own experiments. Just make sure to test it thoroughly in simulation and to set the values correctly. 
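
One way to make the double-check above less error-prone is to compare the gains used in simulation against the gains in the deployment configuration programmatically. The snippet below is only a minimal sketch under stated assumptions: `diff_gains` is a hypothetical helper (it is not part of `fast_td3` or BoosterGym), the flat key names are illustrative, and the numbers are made up rather than real T1 values. It does not parse `T1.yaml`, whose structure is not described in this guide, so you would need to load both sets of values into flat dictionaries yourself.

```python
import math


def diff_gains(reference, deployed, rel_tol=1e-6):
    """Return human-readable mismatches between two flat {name: value} dicts.

    `reference` holds the values used in simulation, `deployed` holds the
    values read from the robot configuration. Both layouts are assumptions.
    """
    problems = []
    for key in sorted(set(reference) | set(deployed)):
        if key not in deployed:
            problems.append(f"{key}: missing from deployed config")
        elif key not in reference:
            problems.append(f"{key}: not present in reference config")
        elif not math.isclose(reference[key], deployed[key], rel_tol=rel_tol):
            problems.append(
                f"{key}: reference={reference[key]} deployed={deployed[key]}"
            )
    return problems


# Hypothetical numbers for illustration only, NOT real Booster T1 gains.
reference = {"stiffness": 50.0, "damping": 1.0, "torque_limit": 30.0}
deployed = {"stiffness": 50.0, "damping": 2.0}

for line in diff_gains(reference, deployed):
    print(line)
```

An empty result only means the two dictionaries agree; it does not replace testing the policy in simulation or following safety protocols on the real robot.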
+ +--- + +🚀 That's it! Hope everything went smoothly, and be aware of your safety. \ No newline at end of file