Bug fix -- MTBench evaluation and missing code (#18)
This PR includes the following changes: - Fix a bug in MTBench evaluation - Add a missing `critic_cls` in `train.py` (resolving issue https://github.com/younggyoseo/FastTD3/issues/17) - Update hyperparameters for MTBench
This commit is contained in:
parent
cef44108d8
commit
799624b202
@ -10,7 +10,7 @@ For more information, please see our [project webpage](https://younggyo.me/fast_
|
||||
|
||||
|
||||
## ❗ Updates
|
||||
- **[Jun/20/2025]** Added support for [MTBench](https://github.com/Viraj-Joshi/MTBench)
|
||||
- **[Jun/20/2025]** Added support for [MTBench](https://github.com/Viraj-Joshi/MTBench) with the help of [Viraj Joshi](https://viraj-joshi.github.io/).
|
||||
|
||||
- **[Jun/15/2025]** Added support for FastTD3 + [SimbaV2](https://dojeon-ai.github.io/SimbaV2/)! It's faster to train, and often achieves better asymptotic performance. We recommend using FastTD3 + SimbaV2 for most cases.
|
||||
|
||||
@ -218,9 +218,9 @@ python fast_td3/train.py \
|
||||
--critic_learning_rate_end 3e-5 \
|
||||
--actor_learning_rate_end 3e-5 \
|
||||
--weight_decay 0.0 \
|
||||
--critic_hidden_dim 512 \
|
||||
--critic_hidden_dim 1024 \
|
||||
--critic_num_blocks 2 \
|
||||
--actor_hidden_dim 256 \
|
||||
--actor_hidden_dim 512 \
|
||||
--actor_num_blocks 1 \
|
||||
--seed 1
|
||||
```
|
||||
|
@ -64,6 +64,8 @@ class MTBenchEnv:
|
||||
# TODO: Check if we need no_grad and detach here
|
||||
with torch.no_grad(): # do we need this?
|
||||
self.env.reset_idx(torch.arange(self.num_envs, device=self.env.device))
|
||||
self.env.cumulatives["rewards"][:] = 0
|
||||
self.env.cumulatives["success"][:] = 0
|
||||
obs_dict = self.env.reset()
|
||||
return obs_dict["obs"].detach()
|
||||
|
||||
|
@ -297,7 +297,7 @@ class MTBenchArgs(BaseArgs):
|
||||
buffer_size: int = 2048 # 2K is usually enough for MTBench
|
||||
num_envs: int = 4096
|
||||
num_eval_envs: int = 4096
|
||||
gamma: float = 0.99
|
||||
gamma: float = 0.97
|
||||
num_steps: int = 8
|
||||
|
||||
|
||||
@ -308,6 +308,7 @@ class MetaWorldMT10Args(MTBenchArgs):
|
||||
num_envs: int = 4096
|
||||
num_eval_envs: int = 4096
|
||||
num_steps: int = 8
|
||||
gamma: float = 0.97
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -318,6 +319,7 @@ class MetaWorldMT50Args(MTBenchArgs):
|
||||
num_envs: int = 8192
|
||||
num_eval_envs: int = 8192
|
||||
num_steps: int = 8
|
||||
gamma: float = 0.99
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -221,6 +221,7 @@ def main():
|
||||
from fast_td3_simbav2 import Actor, Critic
|
||||
|
||||
actor_cls = Actor
|
||||
critic_cls = Critic
|
||||
|
||||
print("Using FastTD3 + SimbaV2")
|
||||
actor_kwargs.pop("init_scale")
|
||||
|
@ -153,7 +153,9 @@
|
||||
" env_type = \"humanoid_bench\"\n",
|
||||
" envs = HumanoidBenchEnv(args.env_name, args.num_envs, device=device)\n",
|
||||
" eval_envs = envs\n",
|
||||
" render_env = HumanoidBenchEnv(args.env_name, 1, render_mode=\"rgb_array\", device=device)\n",
|
||||
" render_env = HumanoidBenchEnv(\n",
|
||||
" args.env_name, 1, render_mode=\"rgb_array\", device=device\n",
|
||||
" )\n",
|
||||
"elif args.env_name.startswith(\"Isaac-\"):\n",
|
||||
" from environments.isaaclab_env import IsaacLabEnv\n",
|
||||
"\n",
|
||||
|
Loading…
Reference in New Issue
Block a user