From 799624b202d29d5a4afcf4d3c79cd0b666a44a7e Mon Sep 17 00:00:00 2001 From: Younggyo Seo Date: Wed, 25 Jun 2025 09:21:04 -0700 Subject: [PATCH] Bug fix -- MTBench evaluation and missing code (#18) This PR includes these changes: - Fixing a bug in MTBench evaluation - Add a missing `critic_cls` in `train.py` (resolving an issue https://github.com/younggyoseo/FastTD3/issues/17) - Updating hyperparameters for MTBench --- README.md | 6 +++--- fast_td3/environments/mtbench_env.py | 2 ++ fast_td3/hyperparams.py | 4 +++- fast_td3/train.py | 1 + fast_td3/training_notebook.ipynb | 4 +++- 5 files changed, 12 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 78cd5d2..f5993b8 100644 --- a/README.md +++ b/README.md @@ -10,7 +10,7 @@ For more information, please see our [project webpage](https://younggyo.me/fast_ ## ❗ Updates -- **[Jun/20/2025]** Added support for [MTBench](https://github.com/Viraj-Joshi/MTBench) +- **[Jun/20/2025]** Added support for [MTBench](https://github.com/Viraj-Joshi/MTBench) with the help of [Viraj Joshi](https://viraj-joshi.github.io/). - **[Jun/15/2025]** Added support for FastTD3 + [SimbaV2](https://dojeon-ai.github.io/SimbaV2/)! It's faster to train, and often achieves better asymptotic performance. We recommend using FastTD3 + SimbaV2 for most cases. @@ -218,9 +218,9 @@ python fast_td3/train.py \ --critic_learning_rate_end 3e-5 \ --actor_learning_rate_end 3e-5 \ --weight_decay 0.0 \ - --critic_hidden_dim 512 \ + --critic_hidden_dim 1024 \ --critic_num_blocks 2 \ - --actor_hidden_dim 256 \ + --actor_hidden_dim 512 \ --actor_num_blocks 1 \ --seed 1 ``` diff --git a/fast_td3/environments/mtbench_env.py b/fast_td3/environments/mtbench_env.py index 5e04310..67c2eef 100644 --- a/fast_td3/environments/mtbench_env.py +++ b/fast_td3/environments/mtbench_env.py @@ -64,6 +64,8 @@ class MTBenchEnv: # TODO: Check if we need no_grad and detach here with torch.no_grad(): # do we need this? self.env.reset_idx(torch.arange(self.num_envs, device=self.env.device)) + self.env.cumulatives["rewards"][:] = 0 + self.env.cumulatives["success"][:] = 0 obs_dict = self.env.reset() return obs_dict["obs"].detach() diff --git a/fast_td3/hyperparams.py b/fast_td3/hyperparams.py index c42bef1..04c2d82 100644 --- a/fast_td3/hyperparams.py +++ b/fast_td3/hyperparams.py @@ -297,7 +297,7 @@ class MTBenchArgs(BaseArgs): buffer_size: int = 2048 # 2K is usually enough for MTBench num_envs: int = 4096 num_eval_envs: int = 4096 - gamma: float = 0.99 + gamma: float = 0.97 num_steps: int = 8 @@ -308,6 +308,7 @@ class MetaWorldMT10Args(MTBenchArgs): num_envs: int = 4096 num_eval_envs: int = 4096 num_steps: int = 8 + gamma: float = 0.97 @dataclass @@ -318,6 +319,7 @@ class MetaWorldMT50Args(MTBenchArgs): num_envs: int = 8192 num_eval_envs: int = 8192 num_steps: int = 8 + gamma: float = 0.99 @dataclass diff --git a/fast_td3/train.py b/fast_td3/train.py index e568846..e20d057 100644 --- a/fast_td3/train.py +++ b/fast_td3/train.py @@ -221,6 +221,7 @@ def main(): from fast_td3_simbav2 import Actor, Critic actor_cls = Actor + critic_cls = Critic print("Using FastTD3 + SimbaV2") actor_kwargs.pop("init_scale") diff --git a/fast_td3/training_notebook.ipynb b/fast_td3/training_notebook.ipynb index ec256ab..308fcad 100644 --- a/fast_td3/training_notebook.ipynb +++ b/fast_td3/training_notebook.ipynb @@ -153,7 +153,9 @@ " env_type = \"humanoid_bench\"\n", " envs = HumanoidBenchEnv(args.env_name, args.num_envs, device=device)\n", " eval_envs = envs\n", - " render_env = HumanoidBenchEnv(args.env_name, 1, render_mode=\"rgb_array\", device=device)\n", + " render_env = HumanoidBenchEnv(\n", + " args.env_name, 1, render_mode=\"rgb_array\", device=device\n", + " )\n", "elif args.env_name.startswith(\"Isaac-\"):\n", " from environments.isaaclab_env import IsaacLabEnv\n", "\n",