Bug fix -- MTBench evaluation and missing code (#18)
This PR includes the following changes: - Fix a bug in MTBench evaluation - Add the missing `critic_cls` assignment in `train.py` (resolves issue https://github.com/younggyoseo/FastTD3/issues/17) - Update hyperparameters for MTBench
This commit is contained in:
parent
cef44108d8
commit
799624b202
@ -10,7 +10,7 @@ For more information, please see our [project webpage](https://younggyo.me/fast_
|
|||||||
|
|
||||||
|
|
||||||
## ❗ Updates
|
## ❗ Updates
|
||||||
- **[Jun/20/2025]** Added support for [MTBench](https://github.com/Viraj-Joshi/MTBench)
|
- **[Jun/20/2025]** Added support for [MTBench](https://github.com/Viraj-Joshi/MTBench) with the help of [Viraj Joshi](https://viraj-joshi.github.io/).
|
||||||
|
|
||||||
- **[Jun/15/2025]** Added support for FastTD3 + [SimbaV2](https://dojeon-ai.github.io/SimbaV2/)! It's faster to train, and often achieves better asymptotic performance. We recommend using FastTD3 + SimbaV2 for most cases.
|
- **[Jun/15/2025]** Added support for FastTD3 + [SimbaV2](https://dojeon-ai.github.io/SimbaV2/)! It's faster to train, and often achieves better asymptotic performance. We recommend using FastTD3 + SimbaV2 for most cases.
|
||||||
|
|
||||||
@ -218,9 +218,9 @@ python fast_td3/train.py \
|
|||||||
--critic_learning_rate_end 3e-5 \
|
--critic_learning_rate_end 3e-5 \
|
||||||
--actor_learning_rate_end 3e-5 \
|
--actor_learning_rate_end 3e-5 \
|
||||||
--weight_decay 0.0 \
|
--weight_decay 0.0 \
|
||||||
--critic_hidden_dim 512 \
|
--critic_hidden_dim 1024 \
|
||||||
--critic_num_blocks 2 \
|
--critic_num_blocks 2 \
|
||||||
--actor_hidden_dim 256 \
|
--actor_hidden_dim 512 \
|
||||||
--actor_num_blocks 1 \
|
--actor_num_blocks 1 \
|
||||||
--seed 1
|
--seed 1
|
||||||
```
|
```
|
||||||
|
@ -64,6 +64,8 @@ class MTBenchEnv:
|
|||||||
# TODO: Check if we need no_grad and detach here
|
# TODO: Check if we need no_grad and detach here
|
||||||
with torch.no_grad(): # do we need this?
|
with torch.no_grad(): # do we need this?
|
||||||
self.env.reset_idx(torch.arange(self.num_envs, device=self.env.device))
|
self.env.reset_idx(torch.arange(self.num_envs, device=self.env.device))
|
||||||
|
self.env.cumulatives["rewards"][:] = 0
|
||||||
|
self.env.cumulatives["success"][:] = 0
|
||||||
obs_dict = self.env.reset()
|
obs_dict = self.env.reset()
|
||||||
return obs_dict["obs"].detach()
|
return obs_dict["obs"].detach()
|
||||||
|
|
||||||
|
@ -297,7 +297,7 @@ class MTBenchArgs(BaseArgs):
|
|||||||
buffer_size: int = 2048 # 2K is usually enough for MTBench
|
buffer_size: int = 2048 # 2K is usually enough for MTBench
|
||||||
num_envs: int = 4096
|
num_envs: int = 4096
|
||||||
num_eval_envs: int = 4096
|
num_eval_envs: int = 4096
|
||||||
gamma: float = 0.99
|
gamma: float = 0.97
|
||||||
num_steps: int = 8
|
num_steps: int = 8
|
||||||
|
|
||||||
|
|
||||||
@ -308,6 +308,7 @@ class MetaWorldMT10Args(MTBenchArgs):
|
|||||||
num_envs: int = 4096
|
num_envs: int = 4096
|
||||||
num_eval_envs: int = 4096
|
num_eval_envs: int = 4096
|
||||||
num_steps: int = 8
|
num_steps: int = 8
|
||||||
|
gamma: float = 0.97
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -318,6 +319,7 @@ class MetaWorldMT50Args(MTBenchArgs):
|
|||||||
num_envs: int = 8192
|
num_envs: int = 8192
|
||||||
num_eval_envs: int = 8192
|
num_eval_envs: int = 8192
|
||||||
num_steps: int = 8
|
num_steps: int = 8
|
||||||
|
gamma: float = 0.99
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
@ -221,6 +221,7 @@ def main():
|
|||||||
from fast_td3_simbav2 import Actor, Critic
|
from fast_td3_simbav2 import Actor, Critic
|
||||||
|
|
||||||
actor_cls = Actor
|
actor_cls = Actor
|
||||||
|
critic_cls = Critic
|
||||||
|
|
||||||
print("Using FastTD3 + SimbaV2")
|
print("Using FastTD3 + SimbaV2")
|
||||||
actor_kwargs.pop("init_scale")
|
actor_kwargs.pop("init_scale")
|
||||||
|
@ -153,7 +153,9 @@
|
|||||||
" env_type = \"humanoid_bench\"\n",
|
" env_type = \"humanoid_bench\"\n",
|
||||||
" envs = HumanoidBenchEnv(args.env_name, args.num_envs, device=device)\n",
|
" envs = HumanoidBenchEnv(args.env_name, args.num_envs, device=device)\n",
|
||||||
" eval_envs = envs\n",
|
" eval_envs = envs\n",
|
||||||
" render_env = HumanoidBenchEnv(args.env_name, 1, render_mode=\"rgb_array\", device=device)\n",
|
" render_env = HumanoidBenchEnv(\n",
|
||||||
|
" args.env_name, 1, render_mode=\"rgb_array\", device=device\n",
|
||||||
|
" )\n",
|
||||||
"elif args.env_name.startswith(\"Isaac-\"):\n",
|
"elif args.env_name.startswith(\"Isaac-\"):\n",
|
||||||
" from environments.isaaclab_env import IsaacLabEnv\n",
|
" from environments.isaaclab_env import IsaacLabEnv\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
Loading…
Reference in New Issue
Block a user