FastTD3/fast_td3/hyperparams.py

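"""Hyperparameter definitions for FastTD3.

`BaseArgs` holds the defaults (tuned for HumanoidBench). Suite-level
subclasses (`HumanoidBenchArgs`, `MuJoCoPlaygroundArgs`, `MTBenchArgs`,
`IsaacLabArgs`) and per-task subclasses override individual fields;
`get_args()` parses the command line with tyro and dispatches to the
appropriate class based on `env_name`.
"""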
import os
from dataclasses import dataclass
from typing import Optional

import tyro
@dataclass
class BaseArgs:
# Default hyperparameters -- specifically for HumanoidBench
# See MuJoCoPlaygroundArgs for default hyperparameters for MuJoCo Playground
# See IsaacLabArgs for default hyperparameters for IsaacLab
env_name: str = "h1hand-stand-v0"
"""the id of the environment"""
agent: str = "fasttd3"
"""the agent to use: currently support [fasttd3, fasttd3_simbav2]"""
seed: int = 1
"""seed of the experiment"""
torch_deterministic: bool = True
"""if toggled, `torch.backends.cudnn.deterministic=False`"""
cuda: bool = True
"""if toggled, cuda will be enabled by default"""
device_rank: int = 0
"""the rank of the device"""
exp_name: str = os.path.basename(__file__)[: -len(".py")]
"""the name of this experiment"""
project: str = "FastTD3"
"""the project name"""
use_wandb: bool = True
"""whether to use wandb"""
    checkpoint_path: Optional[str] = None
"""the path to the checkpoint file"""
num_envs: int = 128
"""the number of environments to run in parallel"""
num_eval_envs: int = 128
"""the number of evaluation environments to run in parallel (only valid for MuJoCo Playground)"""
total_timesteps: int = 150000
"""total timesteps of the experiments"""
critic_learning_rate: float = 3e-4
"""the learning rate of the critic"""
actor_learning_rate: float = 3e-4
"""the learning rate for the actor"""
critic_learning_rate_end: float = 3e-4
"""the learning rate of the critic at the end of training"""
actor_learning_rate_end: float = 3e-4
"""the learning rate for the actor at the end of training"""
buffer_size: int = 1024 * 50
"""the replay memory buffer size"""
num_steps: int = 1
"""the number of steps to use for the multi-step return"""
gamma: float = 0.99
"""the discount factor gamma"""
tau: float = 0.1
"""target smoothing coefficient (default: 0.005)"""
batch_size: int = 32768
"""the batch size of sample from the replay memory"""
policy_noise: float = 0.001
"""the scale of policy noise"""
std_min: float = 0.001
"""the minimum scale of noise"""
std_max: float = 0.4
"""the maximum scale of noise"""
learning_starts: int = 10
"""timestep to start learning"""
policy_frequency: int = 2
"""the frequency of training policy (delayed)"""
noise_clip: float = 0.5
"""noise clip parameter of the Target Policy Smoothing Regularization"""
num_updates: int = 2
"""the number of updates to perform per step"""
init_scale: float = 0.01
"""the scale of the initial parameters"""
num_atoms: int = 101
"""the number of atoms"""
v_min: float = -250.0
"""the minimum value of the support"""
v_max: float = 250.0
"""the maximum value of the support"""
critic_hidden_dim: int = 1024
"""the hidden dimension of the critic network"""
actor_hidden_dim: int = 512
"""the hidden dimension of the actor network"""
critic_num_blocks: int = 2
"""(SimbaV2 only) the number of blocks in the critic network"""
actor_num_blocks: int = 1
"""(SimbaV2 only) the number of blocks in the actor network"""
use_cdq: bool = True
"""whether to use Clipped Double Q-learning"""
measure_burnin: int = 3
"""Number of burn-in iterations for speed measure."""
eval_interval: int = 5000
"""the interval to evaluate the model"""
render_interval: int = 5000
"""the interval to render the model"""
compile: bool = True
"""whether to use torch.compile."""
compile_mode: str = "reduce-overhead"
"""the mode of torch.compile."""
obs_normalization: bool = True
"""whether to enable observation normalization"""
reward_normalization: bool = False
"""whether to enable reward normalization"""
use_grad_norm_clipping: bool = False
"""whether to use gradient norm clipping."""
max_grad_norm: float = 0.0
"""the maximum gradient norm"""
amp: bool = True
"""whether to use amp"""
amp_dtype: str = "bf16"
"""the dtype of the amp"""
disable_bootstrap: bool = False
"""Whether to disable bootstrap in the critic learning"""
use_domain_randomization: bool = False
"""(Playground only) whether to use domain randomization"""
use_push_randomization: bool = False
"""(Playground only) whether to use push randomization"""
use_tuned_reward: bool = False
"""(Playground only) Use tuned reward for G1"""
action_bounds: float = 1.0
"""(IsaacLab only) the bounds of the action space (-action_bounds, action_bounds)"""
task_embedding_dim: int = 32
"""the dimension of the task embedding"""
weight_decay: float = 0.1
"""the weight decay of the optimizer"""
save_interval: int = 5000
"""the interval to save the model"""
def get_args():
"""
Parse command-line arguments and return the appropriate Args instance based on env_name.
"""
# First, parse all arguments using the base Args class
base_args = tyro.cli(BaseArgs)
# Map environment names to their specific Args classes
    # For tasks not listed here, default hyperparameters are used
# See below links for available task list
# - HumanoidBench (https://arxiv.org/abs/2403.10506)
# - IsaacLab (https://isaac-sim.github.io/IsaacLab/main/source/overview/environments.html)
# - MuJoCo Playground (https://arxiv.org/abs/2502.08844)
env_to_args_class = {
# HumanoidBench
        # NOTE: This is not the full list of HumanoidBench tasks
"h1hand-reach-v0": H1HandReachArgs,
"h1hand-balance-simple-v0": H1HandBalanceSimpleArgs,
"h1hand-balance-hard-v0": H1HandBalanceHardArgs,
"h1hand-pole-v0": H1HandPoleArgs,
"h1hand-truck-v0": H1HandTruckArgs,
"h1hand-maze-v0": H1HandMazeArgs,
"h1hand-push-v0": H1HandPushArgs,
"h1hand-basketball-v0": H1HandBasketballArgs,
"h1hand-window-v0": H1HandWindowArgs,
"h1hand-package-v0": H1HandPackageArgs,
"h1hand-truck-v0": H1HandTruckArgs,
# MuJoCo Playground
        # NOTE: This is not the full list of MuJoCo Playground tasks
"G1JoystickFlatTerrain": G1JoystickFlatTerrainArgs,
"G1JoystickRoughTerrain": G1JoystickRoughTerrainArgs,
"T1JoystickFlatTerrain": T1JoystickFlatTerrainArgs,
"T1JoystickRoughTerrain": T1JoystickRoughTerrainArgs,
"LeapCubeReorient": LeapCubeReorientArgs,
"LeapCubeRotateZAxis": LeapCubeRotateZAxisArgs,
"Go1JoystickFlatTerrain": Go1JoystickFlatTerrainArgs,
"Go1JoystickRoughTerrain": Go1JoystickRoughTerrainArgs,
"Go1Getup": Go1GetupArgs,
"CheetahRun": CheetahRunArgs, # NOTE: Example config for DeepMind Control Suite
# IsaacLab
        # NOTE: This is not the full list of IsaacLab tasks
"Isaac-Lift-Cube-Franka-v0": IsaacLiftCubeFrankaArgs,
"Isaac-Open-Drawer-Franka-v0": IsaacOpenDrawerFrankaArgs,
"Isaac-Velocity-Flat-H1-v0": IsaacVelocityFlatH1Args,
"Isaac-Velocity-Flat-G1-v0": IsaacVelocityFlatG1Args,
"Isaac-Velocity-Rough-H1-v0": IsaacVelocityRoughH1Args,
"Isaac-Velocity-Rough-G1-v0": IsaacVelocityRoughG1Args,
"Isaac-Repose-Cube-Allegro-Direct-v0": IsaacReposeCubeAllegroDirectArgs,
"Isaac-Repose-Cube-Shadow-Direct-v0": IsaacReposeCubeShadowDirectArgs,
# MTBench
"MTBench-meta-world-v2-mt10": MetaWorldMT10Args,
"MTBench-meta-world-v2-mt50": MetaWorldMT50Args,
}
# If the provided env_name has a specific Args class, use it
if base_args.env_name in env_to_args_class:
specific_args_class = env_to_args_class[base_args.env_name]
# Re-parse with the specific class, maintaining any user overrides
specific_args = tyro.cli(specific_args_class)
return specific_args
if base_args.env_name.startswith("h1hand-") or base_args.env_name.startswith("h1-"):
# HumanoidBench
specific_args = tyro.cli(HumanoidBenchArgs)
elif base_args.env_name.startswith("Isaac-"):
# IsaacLab
specific_args = tyro.cli(IsaacLabArgs)
elif base_args.env_name.startswith("MTBench-"):
# MTBench
specific_args = tyro.cli(MTBenchArgs)
else:
# MuJoCo Playground
specific_args = tyro.cli(MuJoCoPlaygroundArgs)
return specific_args
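# Example invocation (a minimal sketch -- the entry-point script name is an
# assumption, but the flags mirror the dataclass fields above, as parsed by tyro):
#
#   python fast_td3/train.py --env_name G1JoystickFlatTerrain --seed 0
#
# Tasks with an entry in `env_to_args_class` get their tuned defaults; any other
# task falls back to its suite's defaults based on the `env_name` prefix.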
@dataclass
class HumanoidBenchArgs(BaseArgs):
# See HumanoidBench (https://arxiv.org/abs/2403.10506) for available task list
total_timesteps: int = 100000
@dataclass
class H1HandReachArgs(HumanoidBenchArgs):
env_name: str = "h1hand-reach-v0"
v_min: float = -2000.0
v_max: float = 2000.0
@dataclass
class H1HandBalanceSimpleArgs(HumanoidBenchArgs):
env_name: str = "h1hand-balance-simple-v0"
total_timesteps: int = 200000
@dataclass
class H1HandBalanceHardArgs(HumanoidBenchArgs):
env_name: str = "h1hand-balance-hard-v0"
total_timesteps: int = 1000000
@dataclass
class H1HandPoleArgs(HumanoidBenchArgs):
env_name: str = "h1hand-pole-v0"
total_timesteps: int = 150000
@dataclass
class H1HandTruckArgs(HumanoidBenchArgs):
    env_name: str = "h1hand-truck-v0"
    v_min: float = -1000.0
    v_max: float = 1000.0
    total_timesteps: int = 500000
@dataclass
class H1HandMazeArgs(HumanoidBenchArgs):
env_name: str = "h1hand-maze-v0"
v_min: float = -1000.0
v_max: float = 1000.0
@dataclass
class H1HandPushArgs(HumanoidBenchArgs):
env_name: str = "h1hand-push-v0"
v_min: float = -1000.0
v_max: float = 1000.0
total_timesteps: int = 1000000
@dataclass
class H1HandBasketballArgs(HumanoidBenchArgs):
env_name: str = "h1hand-basketball-v0"
v_min: float = -2000.0
v_max: float = 2000.0
total_timesteps: int = 250000
@dataclass
class H1HandWindowArgs(HumanoidBenchArgs):
env_name: str = "h1hand-window-v0"
total_timesteps: int = 250000
@dataclass
class H1HandPackageArgs(HumanoidBenchArgs):
env_name: str = "h1hand-package-v0"
v_min: float = -10000.0
v_max: float = 10000.0
@dataclass
class MuJoCoPlaygroundArgs(BaseArgs):
    # Default hyperparameters for many of the Playground environments
v_min: float = -10.0
v_max: float = 10.0
buffer_size: int = 1024 * 10
num_envs: int = 1024
num_eval_envs: int = 1024
gamma: float = 0.97
@dataclass
class MTBenchArgs(BaseArgs):
# Default hyperparameters for MTBench
reward_normalization: bool = True
v_min: float = -10.0
v_max: float = 10.0
buffer_size: int = 2048 # 2K is usually enough for MTBench
num_envs: int = 4096
num_eval_envs: int = 4096
gamma: float = 0.97
num_steps: int = 8
compile_mode: str = "default" # Multi-task training is not compatible with cudagraphs
@dataclass
class MetaWorldMT10Args(MTBenchArgs):
    # This config achieves 97-98% success rate within 10k steps (15-20 mins on A100)
env_name: str = "MTBench-meta-world-v2-mt10"
num_envs: int = 4096
num_eval_envs: int = 4096
num_steps: int = 8
gamma: float = 0.97
compile_mode: str = "default" # Multi-task training is not compatible with cudagraphs
@dataclass
class MetaWorldMT50Args(MTBenchArgs):
# FastTD3 + SimbaV2 achieves >90% success rate within 20k steps (80 mins on A100)
    # Performance continues to improve slowly with more training steps.
env_name: str = "MTBench-meta-world-v2-mt50"
num_envs: int = 8192
num_eval_envs: int = 8192
num_steps: int = 8
gamma: float = 0.99
compile_mode: str = "default" # Multi-task training is not compatible with cudagraphs
@dataclass
class G1JoystickFlatTerrainArgs(MuJoCoPlaygroundArgs):
env_name: str = "G1JoystickFlatTerrain"
total_timesteps: int = 100000
@dataclass
class G1JoystickRoughTerrainArgs(MuJoCoPlaygroundArgs):
env_name: str = "G1JoystickRoughTerrain"
total_timesteps: int = 100000
@dataclass
class T1JoystickFlatTerrainArgs(MuJoCoPlaygroundArgs):
env_name: str = "T1JoystickFlatTerrain"
total_timesteps: int = 100000
@dataclass
class T1JoystickRoughTerrainArgs(MuJoCoPlaygroundArgs):
env_name: str = "T1JoystickRoughTerrain"
total_timesteps: int = 100000
@dataclass
class T1LowDofJoystickFlatTerrainArgs(MuJoCoPlaygroundArgs):
env_name: str = "T1LowDofJoystickFlatTerrain"
total_timesteps: int = 1000000
@dataclass
class T1LowDofJoystickRoughTerrainArgs(MuJoCoPlaygroundArgs):
env_name: str = "T1LowDofJoystickRoughTerrain"
total_timesteps: int = 1000000
@dataclass
class CheetahRunArgs(MuJoCoPlaygroundArgs):
# NOTE: This config will work for most DMC tasks, though we haven't tested DMC extensively.
# Future research can consider using LayerNorm as we find it sometimes works better for DMC tasks.
env_name: str = "CheetahRun"
num_steps: int = 3
v_min: float = -500.0
v_max: float = 500.0
std_min: float = 0.1
policy_noise: float = 0.1
@dataclass
class Go1JoystickFlatTerrainArgs(MuJoCoPlaygroundArgs):
env_name: str = "Go1JoystickFlatTerrain"
total_timesteps: int = 50000
std_min: float = 0.2
std_max: float = 0.8
policy_noise: float = 0.2
num_updates: int = 8
@dataclass
class Go1JoystickRoughTerrainArgs(MuJoCoPlaygroundArgs):
env_name: str = "Go1JoystickRoughTerrain"
total_timesteps: int = 50000
std_min: float = 0.2
std_max: float = 0.8
policy_noise: float = 0.2
num_updates: int = 8
@dataclass
class Go1GetupArgs(MuJoCoPlaygroundArgs):
env_name: str = "Go1Getup"
total_timesteps: int = 50000
std_min: float = 0.2
std_max: float = 0.8
policy_noise: float = 0.2
num_updates: int = 8
@dataclass
class LeapCubeReorientArgs(MuJoCoPlaygroundArgs):
env_name: str = "LeapCubeReorient"
num_steps: int = 3
gamma: float = 0.99
policy_noise: float = 0.2
v_min: float = -50.0
v_max: float = 50.0
use_cdq: bool = False
@dataclass
class LeapCubeRotateZAxisArgs(MuJoCoPlaygroundArgs):
env_name: str = "LeapCubeRotateZAxis"
num_steps: int = 1
policy_noise: float = 0.2
gamma: float = 0.99
v_min: float = -10.0
v_max: float = 10.0
use_cdq: bool = False
@dataclass
class IsaacLabArgs(BaseArgs):
v_min: float = -10.0
v_max: float = 10.0
buffer_size: int = 1024 * 10
num_envs: int = 4096
num_eval_envs: int = 4096
action_bounds: float = 1.0
std_max: float = 0.4
num_atoms: int = 251
render_interval: int = 0 # IsaacLab does not support rendering in our codebase
total_timesteps: int = 100000
@dataclass
class IsaacLiftCubeFrankaArgs(IsaacLabArgs):
    # Value learning is unstable for the Lift Cube task due to brittle reward shaping,
    # so we disable bootstrapping from 'reset_obs' in IsaacLab.
    # Higher UTD works better for manipulation tasks.
env_name: str = "Isaac-Lift-Cube-Franka-v0"
num_updates: int = 8
v_min: float = -50.0
v_max: float = 50.0
std_max: float = 0.8
num_envs: int = 1024
num_eval_envs: int = 1024
action_bounds: float = 3.0
disable_bootstrap: bool = True
total_timesteps: int = 20000
@dataclass
class IsaacOpenDrawerFrankaArgs(IsaacLabArgs):
# Higher UTD works better for manipulation tasks
env_name: str = "Isaac-Open-Drawer-Franka-v0"
v_min: float = -50.0
v_max: float = 50.0
num_updates: int = 8
action_bounds: float = 3.0
total_timesteps: int = 20000
@dataclass
class IsaacVelocityFlatH1Args(IsaacLabArgs):
env_name: str = "Isaac-Velocity-Flat-H1-v0"
num_steps: int = 8
num_updates: int = 4
total_timesteps: int = 75000
@dataclass
class IsaacVelocityFlatG1Args(IsaacLabArgs):
env_name: str = "Isaac-Velocity-Flat-G1-v0"
num_steps: int = 8
num_updates: int = 4
total_timesteps: int = 50000
@dataclass
class IsaacVelocityRoughH1Args(IsaacLabArgs):
env_name: str = "Isaac-Velocity-Rough-H1-v0"
num_steps: int = 8
num_updates: int = 4
buffer_size: int = 1024 * 5 # To reduce memory usage
total_timesteps: int = 50000
@dataclass
class IsaacVelocityRoughG1Args(IsaacLabArgs):
env_name: str = "Isaac-Velocity-Rough-G1-v0"
num_steps: int = 8
num_updates: int = 4
buffer_size: int = 1024 * 5 # To reduce memory usage
total_timesteps: int = 50000
@dataclass
class IsaacReposeCubeAllegroDirectArgs(IsaacLabArgs):
env_name: str = "Isaac-Repose-Cube-Allegro-Direct-v0"
total_timesteps: int = 100000
v_min: float = -500.0
v_max: float = 500.0
@dataclass
class IsaacReposeCubeShadowDirectArgs(IsaacLabArgs):
env_name: str = "Isaac-Repose-Cube-Shadow-Direct-v0"
total_timesteps: int = 100000
v_min: float = -500.0
v_max: float = 500.0
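if __name__ == "__main__":
    # Smoke test (an illustrative addition, not part of the training pipeline):
    # print the resolved config to check which Args subclass was selected, e.g.
    # `python hyperparams.py --env_name Go1Getup` should report Go1GetupArgs.
    args = get_args()
    print(type(args).__name__, args)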