Bug fix -- MTBench evaluation and missing code (#18)

This PR includes these changes:
- Fixing a bug in MTBench evaluation
- Add a missing `critic_cls` in `train.py` (resolving an issue https://github.com/younggyoseo/FastTD3/issues/17)
- Updating hyperparameters for MTBench
This commit is contained in:
Younggyo Seo 2025-06-25 09:21:04 -07:00 committed by GitHub
parent cef44108d8
commit 799624b202
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 12 additions and 5 deletions

View File

@ -10,7 +10,7 @@ For more information, please see our [project webpage](https://younggyo.me/fast_
## ❗ Updates
- **[Jun/20/2025]** Added support for [MTBench](https://github.com/Viraj-Joshi/MTBench)
- **[Jun/20/2025]** Added support for [MTBench](https://github.com/Viraj-Joshi/MTBench) with the help of [Viraj Joshi](https://viraj-joshi.github.io/).
- **[Jun/15/2025]** Added support for FastTD3 + [SimbaV2](https://dojeon-ai.github.io/SimbaV2/)! It's faster to train, and often achieves better asymptotic performance. We recommend using FastTD3 + SimbaV2 for most cases.
@ -218,9 +218,9 @@ python fast_td3/train.py \
--critic_learning_rate_end 3e-5 \
--actor_learning_rate_end 3e-5 \
--weight_decay 0.0 \
--critic_hidden_dim 512 \
--critic_hidden_dim 1024 \
--critic_num_blocks 2 \
--actor_hidden_dim 256 \
--actor_hidden_dim 512 \
--actor_num_blocks 1 \
--seed 1
```

View File

@ -64,6 +64,8 @@ class MTBenchEnv:
# TODO: Check if we need no_grad and detach here
with torch.no_grad(): # do we need this?
self.env.reset_idx(torch.arange(self.num_envs, device=self.env.device))
self.env.cumulatives["rewards"][:] = 0
self.env.cumulatives["success"][:] = 0
obs_dict = self.env.reset()
return obs_dict["obs"].detach()

View File

@ -297,7 +297,7 @@ class MTBenchArgs(BaseArgs):
buffer_size: int = 2048 # 2K is usually enough for MTBench
num_envs: int = 4096
num_eval_envs: int = 4096
gamma: float = 0.99
gamma: float = 0.97
num_steps: int = 8
@ -308,6 +308,7 @@ class MetaWorldMT10Args(MTBenchArgs):
num_envs: int = 4096
num_eval_envs: int = 4096
num_steps: int = 8
gamma: float = 0.97
@dataclass
@ -318,6 +319,7 @@ class MetaWorldMT50Args(MTBenchArgs):
num_envs: int = 8192
num_eval_envs: int = 8192
num_steps: int = 8
gamma: float = 0.99
@dataclass

View File

@ -221,6 +221,7 @@ def main():
from fast_td3_simbav2 import Actor, Critic
actor_cls = Actor
critic_cls = Critic
print("Using FastTD3 + SimbaV2")
actor_kwargs.pop("init_scale")

View File

@ -153,7 +153,9 @@
" env_type = \"humanoid_bench\"\n",
" envs = HumanoidBenchEnv(args.env_name, args.num_envs, device=device)\n",
" eval_envs = envs\n",
" render_env = HumanoidBenchEnv(args.env_name, 1, render_mode=\"rgb_array\", device=device)\n",
" render_env = HumanoidBenchEnv(\n",
" args.env_name, 1, render_mode=\"rgb_array\", device=device\n",
" )\n",
"elif args.env_name.startswith(\"Isaac-\"):\n",
" from environments.isaaclab_env import IsaacLabEnv\n",
"\n",