diff --git a/README.md b/README.md
index 054b848..9ccadf7 100644
--- a/README.md
+++ b/README.md
@@ -11,6 +11,8 @@ For more information, please see our [project webpage](https://younggyo.me/fast_
 
 ## ❗ Updates
 
+- **[Jul/07/2025]** Added support for multi-GPU training! See the [Multi-GPU Training](#multi-gpu-training) section for details.
+
 - **[Jul/02/2025]** Optimized codebase to speed up training around 10-30% when using a single RTX 4090 GPU.
 
 - **[Jun/20/2025]** Added support for [MTBench](https://github.com/Viraj-Joshi/MTBench) with the help of [Viraj Joshi](https://viraj-joshi.github.io/).
@@ -242,6 +244,18 @@ We used a single Nvidia A100 80GB GPU for all experiments. Here are some remarks
 - When you encounter out-of-memory error with your GPU, our recommendation for reducing GPU usage is (i) smaller `buffer_size`, (ii) smaller `batch_size`, and then (iii) smaller `num_envs`. Because our codebase is assigning the whole replay buffer in GPU to reduce CPU-GPU transfer bottleneck, it usually has the largest GPU consumption, but usually less harmful to reduce.
 - Consider using `--compile_mode max-autotune` if you plan to run for many training steps. This may speed up training by up to 10% at the cost of a few additional minutes of heavy compilation.
 
+## Multi-GPU Training
+We support multi-GPU training. If your machine has multiple GPUs (or you expose a subset of them via `CUDA_VISIBLE_DEVICES`), running `train_multigpu.py` will automatically launch one training process per visible GPU to scale up training.
+
+**Important:** Our multi-GPU implementation launches the **same experiment independently on each GPU** rather than distributing parameters across GPUs. This means:
+- Effective number of environments: `num_envs × num_gpus`
+- Effective batch size: `batch_size × num_gpus`
+- Effective buffer size: `buffer_size × num_gpus`
+
+Each GPU runs a complete copy of the training process, which scales up data collection and training throughput proportionally to the number of GPUs.
+
+For instance, running IsaacLab experiments with 4 GPUs and `num_envs=1024` will produce results similar to a single-GPU run with `num_envs=4096`.
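+A typical launch looks like the following. This is a minimal sketch: it assumes the same command-line flags as `train.py` (e.g. `--env_name`, `--exp_name`, `--seed`), and the task name is only illustrative.
+
+```bash
+# Restrict training to GPUs 0 and 1; train_multigpu.py spawns one process per visible GPU.
+CUDA_VISIBLE_DEVICES=0,1 python fast_td3/train_multigpu.py \
+    --env_name Isaac-Repose-Cube-Allegro-Direct-v0 \
+    --exp_name FastTD3 \
+    --seed 0
+```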
+ ## 🛝 Playing with the FastTD3 training A Jupyter notebook (`training_notebook.ipynb`) is available to help you get started with: diff --git a/fast_td3/environments/isaaclab_env.py b/fast_td3/environments/isaaclab_env.py index 876fc69..2374ebc 100644 --- a/fast_td3/environments/isaaclab_env.py +++ b/fast_td3/environments/isaaclab_env.py @@ -2,13 +2,6 @@ from typing import Optional import gymnasium as gym import torch -from isaaclab.app import AppLauncher - -app_launcher = AppLauncher(headless=True) -simulation_app = app_launcher.app - -import isaaclab_tasks -from isaaclab_tasks.utils.parse_cfg import parse_env_cfg class IsaacLabEnv: @@ -22,6 +15,14 @@ class IsaacLabEnv: seed: int, action_bounds: Optional[float] = None, ): + from isaaclab.app import AppLauncher + + app_launcher = AppLauncher(headless=True, device=device) + simulation_app = app_launcher.app + + import isaaclab_tasks + from isaaclab_tasks.utils.parse_cfg import parse_env_cfg + env_cfg = parse_env_cfg( task_name, device=device, diff --git a/fast_td3/environments/mujoco_playground_env.py b/fast_td3/environments/mujoco_playground_env.py index 2200389..7f1d368 100644 --- a/fast_td3/environments/mujoco_playground_env.py +++ b/fast_td3/environments/mujoco_playground_env.py @@ -6,7 +6,15 @@ import mujoco class PlaygroundEvalEnvWrapper: - def __init__(self, eval_env, max_episode_steps, env_name, num_eval_envs, seed): + def __init__( + self, + eval_env, + max_episode_steps, + env_name, + num_eval_envs, + seed, + device_rank=None, + ): """ Wrapper used for evaluation / rendering environments. Note that this is different from training environments that are @@ -24,6 +32,11 @@ class PlaygroundEvalEnvWrapper: self.asymmetric_obs = False self.key = jax.random.PRNGKey(seed) + + if device_rank is not None: + gpu_devices = jax.devices("gpu") + self.key = jax.device_put(self.key, gpu_devices[device_rank]) + self.key_reset = jax.random.split(self.key, num_eval_envs) self.max_episode_steps = max_episode_steps @@ -118,7 +131,12 @@ def make_env( eval_env_cfg.push_config.magnitude_range = [0.0, 0.0] eval_env = registry.load(env_name, config=eval_env_cfg) eval_env = PlaygroundEvalEnvWrapper( - eval_env, eval_env_cfg.episode_length, env_name, num_eval_envs, seed + eval_env, + eval_env_cfg.episode_length, + env_name, + num_eval_envs, + seed, + device_rank=device_rank, ) render_env_cfg = registry.get_default_config(env_name) @@ -127,7 +145,12 @@ def make_env( render_env_cfg.push_config.magnitude_range = [0.0, 0.0] render_env = registry.load(env_name, config=render_env_cfg) render_env = PlaygroundEvalEnvWrapper( - render_env, render_env_cfg.episode_length, env_name, 1, seed + render_env, + render_env_cfg.episode_length, + env_name, + 1, + seed, + device_rank=device_rank, ) return train_env, eval_env, render_env diff --git a/fast_td3/fast_td3.py b/fast_td3/fast_td3.py index fcd6c37..fc4030f 100644 --- a/fast_td3/fast_td3.py +++ b/fast_td3/fast_td3.py @@ -234,6 +234,8 @@ class MultiTaskActor(Actor): ) def forward(self, obs: torch.Tensor) -> torch.Tensor: + # TODO: Optimize the code to be compatible with cudagraphs + # Currently in-place creation of task_indices is not compatible with cudagraphs task_ids_one_hot = obs[..., -self.num_tasks :] task_indices = torch.argmax(task_ids_one_hot, dim=1) task_embeddings = self.task_embedding(task_indices) @@ -251,6 +253,8 @@ class MultiTaskCritic(Critic): ) def forward(self, obs: torch.Tensor, actions: torch.Tensor) -> torch.Tensor: + # TODO: Optimize the code to be compatible with cudagraphs + # Currently in-place 
creation of task_indices is not compatible with cudagraphs task_ids_one_hot = obs[..., -self.num_tasks :] task_indices = torch.argmax(task_ids_one_hot, dim=1) task_embeddings = self.task_embedding(task_indices) diff --git a/fast_td3/fast_td3_simbav2.py b/fast_td3/fast_td3_simbav2.py index 0ca3820..2e1896d 100644 --- a/fast_td3/fast_td3_simbav2.py +++ b/fast_td3/fast_td3_simbav2.py @@ -510,6 +510,8 @@ class MultiTaskActor(Actor): ) def forward(self, obs: torch.Tensor) -> torch.Tensor: + # TODO: Optimize the code to be compatible with cudagraphs + # Currently in-place creation of task_indices is not compatible with cudagraphs task_ids_one_hot = obs[..., -self.num_tasks :] task_indices = torch.argmax(task_ids_one_hot, dim=1) task_embeddings = self.task_embedding(task_indices) @@ -527,6 +529,8 @@ class MultiTaskCritic(Critic): ) def forward(self, obs: torch.Tensor, actions: torch.Tensor) -> torch.Tensor: + # TODO: Optimize the code to be compatible with cudagraphs + # Currently in-place creation of task_indices is not compatible with cudagraphs task_ids_one_hot = obs[..., -self.num_tasks :] task_indices = torch.argmax(task_ids_one_hot, dim=1) task_embeddings = self.task_embedding(task_indices) diff --git a/fast_td3/fast_td3_utils.py b/fast_td3/fast_td3_utils.py index f16a4c0..273ac0e 100644 --- a/fast_td3/fast_td3_utils.py +++ b/fast_td3/fast_td3_utils.py @@ -4,6 +4,7 @@ from typing import Optional import torch import torch.nn as nn +import torch.distributed as dist from tensordict import TensorDict @@ -428,13 +429,15 @@ class EmpiricalNormalization(nn.Module): return self._std.squeeze(0).clone() @torch.no_grad() - def forward(self, x: torch.Tensor, center: bool = True) -> torch.Tensor: + def forward( + self, x: torch.Tensor, center: bool = True, update: bool = True + ) -> torch.Tensor: if x.shape[1:] != self._mean.shape[1:]: raise ValueError( f"Expected input of shape (*,{self._mean.shape[1:]}), got {x.shape}" ) - if self.training: + if self.training and update: self.update(x) if center: return (x - self._mean) / (self._std + self.eps) @@ -446,27 +449,46 @@ class EmpiricalNormalization(nn.Module): if self.until is not None and self.count >= self.until: return - batch_size = x.shape[0] - batch_mean = torch.mean(x, dim=0, keepdim=True) + if dist.is_available() and dist.is_initialized(): + # Calculate global batch size arithmetically + local_batch_size = x.shape[0] + world_size = dist.get_world_size() + global_batch_size = world_size * local_batch_size - # Update count - new_count = self.count + batch_size + # Calculate the stats + x_shifted = x - self._mean + local_sum_shifted = torch.sum(x_shifted, dim=0, keepdim=True) + local_sum_sq_shifted = torch.sum(x_shifted.pow(2), dim=0, keepdim=True) + + # Sync the stats across all processes + stats_to_sync = torch.cat([local_sum_shifted, local_sum_sq_shifted], dim=0) + dist.all_reduce(stats_to_sync, op=dist.ReduceOp.SUM) + global_sum_shifted, global_sum_sq_shifted = stats_to_sync + + # Calculate the mean and variance of the global batch + batch_mean_shifted = global_sum_shifted / global_batch_size + batch_var = ( + global_sum_sq_shifted / global_batch_size - batch_mean_shifted.pow(2) + ) + batch_mean = batch_mean_shifted + self._mean + + else: + global_batch_size = x.shape[0] + batch_mean = torch.mean(x, dim=0, keepdim=True) + batch_var = torch.var(x, dim=0, keepdim=True, unbiased=False) + + new_count = self.count + global_batch_size # Update mean delta = batch_mean - self._mean - self._mean += (batch_size / new_count) * delta + 
self._mean.copy_(self._mean + delta * (global_batch_size / new_count)) - # Compute batch variance - batch_var = torch.mean((x - batch_mean) ** 2, dim=0, keepdim=True) - delta2 = batch_mean - self._mean # uses updated mean - - # Parallel variance update (works even when previous count == 0) - m_a = self._var * self.count # previous aggregated M2 - m_b = batch_var * batch_size - M2 = m_a + m_b + delta2.pow(2) * (self.count * batch_size / new_count) + # Update variance + delta2 = batch_mean - self._mean + m_a = self._var * self.count + m_b = batch_var * global_batch_size + M2 = m_a + m_b + delta2.pow(2) * (self.count * global_batch_size / new_count) self._var.copy_(M2 / new_count) - - # Update std and count in-place to avoid expensive __setattr__ self._std.copy_(self._var.sqrt()) self.count.copy_(new_count) @@ -507,7 +529,13 @@ class RewardNormalizer(nn.Module): ): self.G = self.gamma * (1 - dones) * self.G + rewards self.G_rms.update(self.G.view(-1, 1)) - self.G_r_max = max(self.G_r_max, max(abs(self.G))) + + local_max = torch.max(torch.abs(self.G)) + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(local_max, op=dist.ReduceOp.MAX) + + self.G_r_max = max(self.G_r_max, local_max) def forward(self, rewards: torch.Tensor) -> torch.Tensor: return self._scale_reward(rewards) @@ -608,7 +636,7 @@ class PerTaskEmpiricalNormalization(nn.Module): task_mean = self._mean[task_id] batch_mean = torch.mean(x_task, dim=0) delta = batch_mean - task_mean - self._mean[task_id] = task_mean + (batch_size / new_count) * delta + self._mean[task_id].copy_(task_mean + (batch_size / new_count) * delta) # Update variance using Chan's parallel algorithm if old_count > 0: @@ -616,13 +644,13 @@ class PerTaskEmpiricalNormalization(nn.Module): m_a = self._var[task_id] * old_count m_b = batch_var * batch_size M2 = m_a + m_b + (delta**2) * (old_count * batch_size / new_count) - self._var[task_id] = M2 / new_count + self._var[task_id].copy_(M2 / new_count) else: # For the first batch of this task - self._var[task_id] = torch.var(x_task, dim=0, unbiased=False) + self._var[task_id].copy_(torch.var(x_task, dim=0, unbiased=False)) - self._std[task_id] = torch.sqrt(self._var[task_id]) - self.count[task_id] = new_count + self._std[task_id].copy_(torch.sqrt(self._var[task_id])) + self.count[task_id].copy_(new_count) class PerTaskRewardNormalizer(nn.Module): @@ -735,11 +763,18 @@ def save_params( save_path, ): """Save model parameters and training configuration to disk.""" + + def get_ddp_state_dict(model): + """Get state dict from model, handling DDP wrapper if present.""" + if hasattr(model, "module"): + return model.module.state_dict() + return model.state_dict() + os.makedirs(os.path.dirname(save_path), exist_ok=True) save_dict = { - "actor_state_dict": cpu_state(actor.state_dict()), - "qnet_state_dict": cpu_state(qnet.state_dict()), - "qnet_target_state_dict": cpu_state(qnet_target.state_dict()), + "actor_state_dict": cpu_state(get_ddp_state_dict(actor)), + "qnet_state_dict": cpu_state(get_ddp_state_dict(qnet)), + "qnet_target_state_dict": cpu_state(get_ddp_state_dict(qnet_target)), "obs_normalizer_state": ( cpu_state(obs_normalizer.state_dict()) if hasattr(obs_normalizer, "state_dict") @@ -755,3 +790,24 @@ def save_params( } torch.save(save_dict, save_path, _use_new_zipfile_serialization=True) print(f"Saved parameters and configuration to {save_path}") + + +def get_ddp_state_dict(model): + """Get state dict from model, handling DDP wrapper if present.""" + if hasattr(model, "module"): + return 
model.module.state_dict() + return model.state_dict() + + +def load_ddp_state_dict(model, state_dict): + """Load state dict into model, handling DDP wrapper if present.""" + if hasattr(model, "module"): + model.module.load_state_dict(state_dict) + else: + model.load_state_dict(state_dict) + + +@torch.no_grad() +def mark_step(): + # call this once per iteration *before* any compiled function + torch.compiler.cudagraph_mark_step_begin() diff --git a/fast_td3/hyperparams.py b/fast_td3/hyperparams.py index 8605322..3b02c30 100644 --- a/fast_td3/hyperparams.py +++ b/fast_td3/hyperparams.py @@ -303,6 +303,7 @@ class MTBenchArgs(BaseArgs): num_eval_envs: int = 4096 gamma: float = 0.97 num_steps: int = 8 + compile_mode: str = "default" # Multi-task training is not compatible with cudagraphs @dataclass @@ -313,6 +314,7 @@ class MetaWorldMT10Args(MTBenchArgs): num_eval_envs: int = 4096 num_steps: int = 8 gamma: float = 0.97 + compile_mode: str = "default" # Multi-task training is not compatible with cudagraphs @dataclass @@ -324,6 +326,7 @@ class MetaWorldMT50Args(MTBenchArgs): num_eval_envs: int = 8192 num_steps: int = 8 gamma: float = 0.99 + compile_mode: str = "default" # Multi-task training is not compatible with cudagraphs @dataclass diff --git a/fast_td3/train.py b/fast_td3/train.py index ab4e689..f28701c 100644 --- a/fast_td3/train.py +++ b/fast_td3/train.py @@ -38,6 +38,7 @@ from fast_td3_utils import ( PerTaskRewardNormalizer, SimpleReplayBuffer, save_params, + mark_step, ) from hyperparams import get_args @@ -49,12 +50,6 @@ except ImportError: pass -@torch.no_grad() -def mark_step(): - # call this once per iteration *before* any compiled function - torch.compiler.cudagraph_mark_step_begin() - - def main(): args = get_args() print(args) @@ -313,7 +308,6 @@ def main(): noise_clip = args.noise_clip def evaluate(): - obs_normalizer.eval() num_eval_envs = eval_envs.num_envs episode_returns = torch.zeros(num_eval_envs, device=device) episode_lengths = torch.zeros(num_eval_envs, device=device) @@ -329,7 +323,7 @@ def main(): with torch.no_grad(), autocast( device_type=amp_device_type, dtype=amp_dtype, enabled=amp_enabled ): - obs = normalize_obs(obs) + obs = normalize_obs(obs, update=False) actions = actor(obs) next_obs, rewards, dones, infos = eval_envs.step(actions.float()) @@ -352,12 +346,9 @@ def main(): break obs = next_obs - obs_normalizer.train() return episode_returns.mean().item(), episode_lengths.mean().item() def render_with_rollout(): - obs_normalizer.eval() - # Quick rollout for rendering if env_type == "humanoid_bench": obs = render_env.reset() @@ -374,7 +365,7 @@ def main(): with torch.no_grad(), autocast( device_type=amp_device_type, dtype=amp_dtype, enabled=amp_enabled ): - obs = normalize_obs(obs) + obs = normalize_obs(obs, update=False) actions = actor(obs) next_obs, _, done, _ = render_env.step(actions.float()) if env_type == "mujoco_playground": @@ -390,8 +381,6 @@ def main(): if env_type == "mujoco_playground": renders = render_env.render_trajectory(renders) - - obs_normalizer.train() return renders def update_main(data, logs_dict): @@ -473,7 +462,6 @@ def main(): critic_grad_norm = torch.tensor(0.0, device=device) scaler.step(q_optimizer) scaler.update() - q_scheduler.step() logs_dict["critic_grad_norm"] = critic_grad_norm.detach() logs_dict["qf_loss"] = qf_loss.detach() @@ -512,7 +500,6 @@ def main(): actor_grad_norm = torch.tensor(0.0, device=device) scaler.step(actor_optimizer) scaler.update() - actor_scheduler.step() logs_dict["actor_grad_norm"] = 
actor_grad_norm.detach() logs_dict["actor_loss"] = actor_loss.detach() return logs_dict @@ -529,16 +516,12 @@ def main(): compile_mode = args.compile_mode update_main = torch.compile(update_main, mode=compile_mode) update_pol = torch.compile(update_pol, mode=compile_mode) - policy = torch.compile(policy, mode=compile_mode) - normalize_obs = torch.compile(obs_normalizer.forward, mode=compile_mode) - normalize_critic_obs = torch.compile( - critic_obs_normalizer.forward, mode=compile_mode - ) + policy = torch.compile(policy, mode=None) + normalize_obs = torch.compile(obs_normalizer.forward, mode=None) + normalize_critic_obs = torch.compile(critic_obs_normalizer.forward, mode=None) if args.reward_normalization: - update_stats = torch.compile( - reward_normalizer.update_stats, mode=compile_mode - ) - normalize_reward = torch.compile(reward_normalizer.forward, mode=compile_mode) + update_stats = torch.compile(reward_normalizer.update_stats, mode=None) + normalize_reward = torch.compile(reward_normalizer.forward, mode=None) else: normalize_obs = obs_normalizer.forward normalize_critic_obs = critic_obs_normalizer.forward @@ -637,10 +620,9 @@ def main(): if envs.asymmetric_obs: critic_obs = next_critic_obs - batch_size = args.batch_size // args.num_envs if global_step > args.learning_starts: for i in range(args.num_updates): - data = rb.sample(batch_size) + data = rb.sample(max(1, args.batch_size // args.num_envs)) data["observations"] = normalize_obs(data["observations"]) data["next"]["observations"] = normalize_obs( data["next"]["observations"] @@ -702,19 +684,14 @@ def main(): and global_step % args.render_interval == 0 ): renders = render_with_rollout() - if args.use_wandb: - wandb.log( - { - "render_video": wandb.Video( - np.array(renders).transpose( - 0, 3, 1, 2 - ), # Convert to (T, C, H, W) format - fps=30, - format="gif", - ) - }, - step=global_step, - ) + render_video = wandb.Video( + np.array(renders).transpose( + 0, 3, 1, 2 + ), # Convert to (T, C, H, W) format + fps=30, + format="gif", + ) + logs["render_video"] = render_video if args.use_wandb: wandb.log( { @@ -745,6 +722,8 @@ def main(): ) global_step += 1 + actor_scheduler.step() + q_scheduler.step() pbar.update(1) save_params( diff --git a/fast_td3/train_multigpu.py b/fast_td3/train_multigpu.py new file mode 100644 index 0000000..aa65c59 --- /dev/null +++ b/fast_td3/train_multigpu.py @@ -0,0 +1,810 @@ +import os +import sys + +os.environ["TORCHDYNAMO_INLINE_INBUILT_NN_MODULES"] = "1" +os.environ["OMP_NUM_THREADS"] = "1" +if sys.platform != "darwin": + os.environ["MUJOCO_GL"] = "egl" +else: + os.environ["MUJOCO_GL"] = "glfw" +os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false" +os.environ["JAX_DEFAULT_MATMUL_PRECISION"] = "highest" + +import random +import time +import math + +import tqdm +import wandb +import numpy as np + +try: + # Required for avoiding IsaacGym import error + import isaacgym +except ImportError: + pass + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.amp import autocast, GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +import torch.multiprocessing as mp + +from tensordict import TensorDict + +from fast_td3_utils import ( + EmpiricalNormalization, + RewardNormalizer, + PerTaskRewardNormalizer, + SimpleReplayBuffer, + save_params, + get_ddp_state_dict, + load_ddp_state_dict, + mark_step, +) +from hyperparams import get_args + +torch.set_float32_matmul_precision("high") + +try: + import jax.numpy as jnp +except ImportError: + 
pass + + +def setup_distributed(rank: int, world_size: int): + os.environ["MASTER_ADDR"] = os.getenv("MASTER_ADDR", "localhost") + os.environ["MASTER_PORT"] = os.getenv("MASTER_PORT", "12355") + is_distributed = world_size > 1 + if is_distributed: + print( + f"Initializing distributed training with rank {rank}, world size {world_size}" + ) + torch.distributed.init_process_group( + backend="nccl", init_method="env://", world_size=world_size, rank=rank + ) + torch.cuda.set_device(rank) + return is_distributed + + +def main(rank: int, world_size: int): + is_distributed = setup_distributed(rank, world_size) + + args = get_args() + if rank == 0: + print(args) + run_name = f"{args.env_name}__{args.exp_name}__{args.seed}" + + amp_enabled = args.amp and args.cuda and torch.cuda.is_available() + amp_device_type = ( + f"cuda:{rank}" + if args.cuda and torch.cuda.is_available() + else "mps" if args.cuda and torch.backends.mps.is_available() else "cpu" + ) + amp_dtype = torch.bfloat16 if args.amp_dtype == "bf16" else torch.float16 + + scaler = GradScaler(enabled=amp_enabled and amp_dtype == torch.float16) + + if args.use_wandb and rank == 0: + wandb.init( + project=args.project, + name=run_name, + config=vars(args), + save_code=True, + ) + + # Use different seeds per rank to avoid synchronization issues + random.seed(args.seed + rank) + np.random.seed(args.seed + rank) + torch.manual_seed(args.seed + rank) + torch.backends.cudnn.deterministic = args.torch_deterministic + + if not args.cuda: + device = torch.device("cpu") + else: + if torch.cuda.is_available(): + device = torch.device(f"cuda:{rank}") + elif torch.backends.mps.is_available(): + device = torch.device(f"mps:{rank}") + else: + raise ValueError("No GPU available") + print(f"Using device: {device}") + + if args.env_name.startswith("h1hand-") or args.env_name.startswith("h1-"): + from environments.humanoid_bench_env import HumanoidBenchEnv + + env_type = "humanoid_bench" + envs = HumanoidBenchEnv(args.env_name, args.num_envs, device=device) + eval_envs = envs + render_env = HumanoidBenchEnv( + args.env_name, 1, render_mode="rgb_array", device=device + ) + elif args.env_name.startswith("Isaac-"): + from environments.isaaclab_env import IsaacLabEnv + + env_type = "isaaclab" + envs = IsaacLabEnv( + args.env_name, + f"cuda:{rank}", + args.num_envs, + args.seed + rank, + action_bounds=args.action_bounds, + ) + eval_envs = envs + render_env = envs + elif args.env_name.startswith("MTBench-"): + from environments.mtbench_env import MTBenchEnv + + env_name = "-".join(args.env_name.split("-")[1:]) + env_type = "mtbench" + envs = MTBenchEnv(env_name, rank, args.num_envs, args.seed + rank) + eval_envs = envs + render_env = envs + else: + from environments.mujoco_playground_env import make_env + + # TODO: Check if re-using same envs for eval could reduce memory usage + env_type = "mujoco_playground" + envs, eval_envs, render_env = make_env( + args.env_name, + args.seed + rank, + args.num_envs, + args.num_eval_envs, + rank, + use_tuned_reward=args.use_tuned_reward, + use_domain_randomization=args.use_domain_randomization, + use_push_randomization=args.use_push_randomization, + ) + + n_act = envs.num_actions + n_obs = envs.num_obs if type(envs.num_obs) == int else envs.num_obs[0] + if envs.asymmetric_obs: + n_critic_obs = ( + envs.num_privileged_obs + if type(envs.num_privileged_obs) == int + else envs.num_privileged_obs[0] + ) + else: + n_critic_obs = n_obs + action_low, action_high = -1.0, 1.0 + + if args.obs_normalization: + obs_normalizer = 
EmpiricalNormalization(shape=n_obs, device=device) + critic_obs_normalizer = EmpiricalNormalization( + shape=n_critic_obs, device=device + ) + else: + obs_normalizer = nn.Identity() + critic_obs_normalizer = nn.Identity() + + if args.reward_normalization: + if env_type in ["mtbench"]: + reward_normalizer = PerTaskRewardNormalizer( + num_tasks=envs.num_tasks, + gamma=args.gamma, + device=device, + g_max=min(abs(args.v_min), abs(args.v_max)), + ) + else: + reward_normalizer = RewardNormalizer( + gamma=args.gamma, + device=device, + g_max=min(abs(args.v_min), abs(args.v_max)), + ) + else: + reward_normalizer = nn.Identity() + + actor_kwargs = { + "n_obs": n_obs, + "n_act": n_act, + "num_envs": args.num_envs, + "device": device, + "init_scale": args.init_scale, + "hidden_dim": args.actor_hidden_dim, + } + critic_kwargs = { + "n_obs": n_critic_obs, + "n_act": n_act, + "num_atoms": args.num_atoms, + "v_min": args.v_min, + "v_max": args.v_max, + "hidden_dim": args.critic_hidden_dim, + "device": device, + } + + if env_type == "mtbench": + actor_kwargs["n_obs"] = n_obs - envs.num_tasks + args.task_embedding_dim + critic_kwargs["n_obs"] = n_critic_obs - envs.num_tasks + args.task_embedding_dim + actor_kwargs["num_tasks"] = envs.num_tasks + actor_kwargs["task_embedding_dim"] = args.task_embedding_dim + critic_kwargs["num_tasks"] = envs.num_tasks + critic_kwargs["task_embedding_dim"] = args.task_embedding_dim + + if args.agent == "fasttd3": + if env_type in ["mtbench"]: + from fast_td3 import MultiTaskActor, MultiTaskCritic + + actor_cls = MultiTaskActor + critic_cls = MultiTaskCritic + else: + from fast_td3 import Actor, Critic + + actor_cls = Actor + critic_cls = Critic + + if rank == 0: + print("Using FastTD3") + elif args.agent == "fasttd3_simbav2": + if env_type in ["mtbench"]: + from fast_td3_simbav2 import MultiTaskActor, MultiTaskCritic + + actor_cls = MultiTaskActor + critic_cls = MultiTaskCritic + else: + from fast_td3_simbav2 import Actor, Critic + + actor_cls = Actor + critic_cls = Critic + + if rank == 0: + print("Using FastTD3 + SimbaV2") + actor_kwargs.pop("init_scale") + actor_kwargs.update( + { + "scaler_init": math.sqrt(2.0 / args.actor_hidden_dim), + "scaler_scale": math.sqrt(2.0 / args.actor_hidden_dim), + "alpha_init": 1.0 / (args.actor_num_blocks + 1), + "alpha_scale": 1.0 / math.sqrt(args.actor_hidden_dim), + "expansion": 4, + "c_shift": 3.0, + "num_blocks": args.actor_num_blocks, + } + ) + critic_kwargs.update( + { + "scaler_init": math.sqrt(2.0 / args.critic_hidden_dim), + "scaler_scale": math.sqrt(2.0 / args.critic_hidden_dim), + "alpha_init": 1.0 / (args.critic_num_blocks + 1), + "alpha_scale": 1.0 / math.sqrt(args.critic_hidden_dim), + "num_blocks": args.critic_num_blocks, + "expansion": 4, + "c_shift": 3.0, + } + ) + else: + raise ValueError(f"Agent {args.agent} not supported") + + actor = actor_cls(**actor_kwargs) + if is_distributed: + actor = DDP(actor, device_ids=[rank]) + if env_type in ["mtbench"]: + # Python 3.8 doesn't support 'from_module' in tensordict + policy = actor.module.explore if hasattr(actor, "module") else actor.explore + else: + from tensordict import from_module + + actor_detach = actor_cls(**actor_kwargs) + # Copy params to actor_detach without grad + from_module(actor.module if hasattr(actor, "module") else actor).data.to_module( + actor_detach + ) + policy = actor_detach.explore + + qnet = critic_cls(**critic_kwargs) + if is_distributed: + qnet = DDP(qnet, device_ids=[rank]) + qnet_target = critic_cls(**critic_kwargs) # Create a separate instance + 
qnet_target.load_state_dict(get_ddp_state_dict(qnet)) + + q_optimizer = optim.AdamW( + list(qnet.parameters()), + lr=torch.tensor(args.critic_learning_rate, device=device), + weight_decay=args.weight_decay, + ) + actor_optimizer = optim.AdamW( + list(actor.parameters()), + lr=torch.tensor(args.actor_learning_rate, device=device), + weight_decay=args.weight_decay, + ) + + # Add learning rate schedulers + q_scheduler = optim.lr_scheduler.CosineAnnealingLR( + q_optimizer, + T_max=args.total_timesteps, + eta_min=torch.tensor(args.critic_learning_rate_end, device=device), + ) + actor_scheduler = optim.lr_scheduler.CosineAnnealingLR( + actor_optimizer, + T_max=args.total_timesteps, + eta_min=torch.tensor(args.actor_learning_rate_end, device=device), + ) + + rb = SimpleReplayBuffer( + n_env=args.num_envs, + buffer_size=args.buffer_size, + n_obs=n_obs, + n_act=n_act, + n_critic_obs=n_critic_obs, + asymmetric_obs=envs.asymmetric_obs, + playground_mode=env_type == "mujoco_playground", + n_steps=args.num_steps, + gamma=args.gamma, + device=device, + ) + + policy_noise = args.policy_noise + noise_clip = args.noise_clip + + def evaluate(): + num_eval_envs = eval_envs.num_envs + episode_returns = torch.zeros(num_eval_envs, device=device) + episode_lengths = torch.zeros(num_eval_envs, device=device) + done_masks = torch.zeros(num_eval_envs, dtype=torch.bool, device=device) + + if env_type == "isaaclab": + obs = eval_envs.reset(random_start_init=False) + else: + obs = eval_envs.reset() + + # Run for a fixed number of steps + for i in range(eval_envs.max_episode_steps): + with torch.no_grad(), autocast( + device_type=amp_device_type, dtype=amp_dtype, enabled=amp_enabled + ): + obs = normalize_obs(obs, update=False) + actions = actor(obs) + + next_obs, rewards, dones, infos = eval_envs.step(actions.float()) + + if env_type == "mtbench": + # We only report success rate in MTBench evaluation + rewards = ( + infos["episode"]["success"].float() if "episode" in infos else 0.0 + ) + episode_returns = torch.where( + ~done_masks, episode_returns + rewards, episode_returns + ) + episode_lengths = torch.where( + ~done_masks, episode_lengths + 1, episode_lengths + ) + if env_type == "mtbench" and "episode" in infos: + dones = dones | infos["episode"]["success"] + done_masks = torch.logical_or(done_masks, dones) + if done_masks.all(): + break + obs = next_obs + + return episode_returns.mean(), episode_lengths.mean() + + def render_with_rollout(): + # Quick rollout for rendering + if env_type == "humanoid_bench": + obs = render_env.reset() + renders = [render_env.render()] + elif env_type in ["isaaclab", "mtbench"]: + raise NotImplementedError( + "We don't support rendering for IsaacLab and MTBench environments" + ) + else: + obs = render_env.reset() + render_env.state.info["command"] = jnp.array([[1.0, 0.0, 0.0]]) + renders = [render_env.state] + for i in range(render_env.max_episode_steps): + with torch.no_grad(), autocast( + device_type=amp_device_type, dtype=amp_dtype, enabled=amp_enabled + ): + obs = normalize_obs(obs, update=False) + actions = actor(obs) + next_obs, _, done, _ = render_env.step(actions.float()) + if env_type == "mujoco_playground": + render_env.state.info["command"] = jnp.array([[1.0, 0.0, 0.0]]) + if i % 2 == 0: + if env_type == "humanoid_bench": + renders.append(render_env.render()) + else: + renders.append(render_env.state) + if done.any(): + break + obs = next_obs + + if env_type == "mujoco_playground": + renders = render_env.render_trajectory(renders) + return renders + + def 
update_main(data, logs_dict): + with autocast( + device_type=amp_device_type, dtype=amp_dtype, enabled=amp_enabled + ): + observations = data["observations"] + next_observations = data["next"]["observations"] + if envs.asymmetric_obs: + critic_observations = data["critic_observations"] + next_critic_observations = data["next"]["critic_observations"] + else: + critic_observations = observations + next_critic_observations = next_observations + actions = data["actions"] + rewards = data["next"]["rewards"] + dones = data["next"]["dones"].bool() + truncations = data["next"]["truncations"].bool() + if args.disable_bootstrap: + bootstrap = (~dones).float() + else: + bootstrap = (truncations | ~dones).float() + + clipped_noise = torch.randn_like(actions) + clipped_noise = clipped_noise.mul(policy_noise).clamp( + -noise_clip, noise_clip + ) + + next_state_actions = (actor(next_observations) + clipped_noise).clamp( + action_low, action_high + ) + discount = args.gamma ** data["next"]["effective_n_steps"] + + with torch.no_grad(): + qf1_next_target_projected, qf2_next_target_projected = ( + qnet_target.projection( + next_critic_observations, + next_state_actions, + rewards, + bootstrap, + discount, + ) + ) + qf1_next_target_value = qnet_target.get_value(qf1_next_target_projected) + qf2_next_target_value = qnet_target.get_value(qf2_next_target_projected) + if args.use_cdq: + qf_next_target_dist = torch.where( + qf1_next_target_value.unsqueeze(1) + < qf2_next_target_value.unsqueeze(1), + qf1_next_target_projected, + qf2_next_target_projected, + ) + qf1_next_target_dist = qf2_next_target_dist = qf_next_target_dist + else: + qf1_next_target_dist, qf2_next_target_dist = ( + qf1_next_target_projected, + qf2_next_target_projected, + ) + + qf1, qf2 = qnet(critic_observations, actions) + qf1_loss = -torch.sum( + qf1_next_target_dist * F.log_softmax(qf1, dim=1), dim=1 + ).mean() + qf2_loss = -torch.sum( + qf2_next_target_dist * F.log_softmax(qf2, dim=1), dim=1 + ).mean() + qf_loss = qf1_loss + qf2_loss + + q_optimizer.zero_grad(set_to_none=True) + scaler.scale(qf_loss).backward() + scaler.unscale_(q_optimizer) + + if args.use_grad_norm_clipping: + critic_grad_norm = torch.nn.utils.clip_grad_norm_( + qnet.parameters(), + max_norm=args.max_grad_norm if args.max_grad_norm > 0 else float("inf"), + ) + else: + critic_grad_norm = torch.tensor(0.0, device=device) + scaler.step(q_optimizer) + scaler.update() + + logs_dict["critic_grad_norm"] = critic_grad_norm.detach() + logs_dict["qf_loss"] = qf_loss.detach() + logs_dict["qf_max"] = qf1_next_target_value.max().detach() + logs_dict["qf_min"] = qf1_next_target_value.min().detach() + return logs_dict + + def update_pol(data, logs_dict): + with autocast( + device_type=amp_device_type, dtype=amp_dtype, enabled=amp_enabled + ): + critic_observations = ( + data["critic_observations"] + if envs.asymmetric_obs + else data["observations"] + ) + + qf1, qf2 = qnet(critic_observations, actor(data["observations"])) + qf1_value = ( + qnet.module.get_value(F.softmax(qf1, dim=1)) + if hasattr(qnet, "module") + else qnet.get_value(F.softmax(qf1, dim=1)) + ) + qf2_value = ( + qnet.module.get_value(F.softmax(qf2, dim=1)) + if hasattr(qnet, "module") + else qnet.get_value(F.softmax(qf2, dim=1)) + ) + if args.use_cdq: + qf_value = torch.minimum(qf1_value, qf2_value) + else: + qf_value = (qf1_value + qf2_value) / 2.0 + actor_loss = -qf_value.mean() + + actor_optimizer.zero_grad(set_to_none=True) + scaler.scale(actor_loss).backward() + scaler.unscale_(actor_optimizer) + if 
args.use_grad_norm_clipping: + actor_grad_norm = torch.nn.utils.clip_grad_norm_( + actor.parameters(), + max_norm=args.max_grad_norm if args.max_grad_norm > 0 else float("inf"), + ) + else: + actor_grad_norm = torch.tensor(0.0, device=device) + scaler.step(actor_optimizer) + scaler.update() + logs_dict["actor_grad_norm"] = actor_grad_norm.detach() + logs_dict["actor_loss"] = actor_loss.detach() + return logs_dict + + @torch.no_grad() + def soft_update(src, tgt, tau: float): + # Handle DDP module by accessing .module attribute + src_module = src.module if hasattr(src, "module") else src + tgt_module = tgt.module if hasattr(tgt, "module") else tgt + + src_ps = [p.data for p in src_module.parameters()] + tgt_ps = [p.data for p in tgt_module.parameters()] + + torch._foreach_mul_(tgt_ps, 1.0 - tau) + torch._foreach_add_(tgt_ps, src_ps, alpha=tau) + + if args.compile: + compile_mode = args.compile_mode + update_main = torch.compile(update_main, mode=compile_mode) + update_pol = torch.compile(update_pol, mode=compile_mode) + policy = torch.compile(policy, mode=None) + normalize_obs = torch.compile(obs_normalizer.forward, mode=None) + normalize_critic_obs = torch.compile(critic_obs_normalizer.forward, mode=None) + if args.reward_normalization: + update_stats = torch.compile(reward_normalizer.update_stats, mode=None) + normalize_reward = torch.compile(reward_normalizer.forward, mode=None) + else: + normalize_obs = obs_normalizer.forward + normalize_critic_obs = critic_obs_normalizer.forward + if args.reward_normalization: + update_stats = reward_normalizer.update_stats + normalize_reward = reward_normalizer.forward + + if envs.asymmetric_obs: + obs, critic_obs = envs.reset_with_critic_obs() + critic_obs = torch.as_tensor(critic_obs, device=device, dtype=torch.float) + else: + obs = envs.reset() + if args.checkpoint_path: + # Load checkpoint if specified + torch_checkpoint = torch.load( + f"{args.checkpoint_path}", map_location=device, weights_only=False + ) + load_ddp_state_dict(actor, torch_checkpoint["actor_state_dict"]) + if torch_checkpoint["obs_normalizer_state"] is not None: + obs_normalizer.load_state_dict(torch_checkpoint["obs_normalizer_state"]) + if torch_checkpoint["critic_obs_normalizer_state"] is not None: + critic_obs_normalizer.load_state_dict( + torch_checkpoint["critic_obs_normalizer_state"] + ) + load_ddp_state_dict(qnet, torch_checkpoint["qnet_state_dict"]) + qnet_target.load_state_dict(torch_checkpoint["qnet_target_state_dict"]) + global_step = torch_checkpoint["global_step"] + else: + global_step = 0 + + dones = None + pbar = tqdm.tqdm(total=args.total_timesteps, initial=global_step) + start_time = None + desc = "" + + while global_step < args.total_timesteps: + mark_step() + logs_dict = TensorDict() + if ( + start_time is None + and global_step >= args.measure_burnin + args.learning_starts + ): + start_time = time.time() + measure_burnin = global_step + + with torch.no_grad(), autocast( + device_type=amp_device_type, dtype=amp_dtype, enabled=amp_enabled + ): + norm_obs = normalize_obs(obs) + actions = policy(obs=norm_obs, dones=dones) + + next_obs, rewards, dones, infos = envs.step(actions.float()) + truncations = infos["time_outs"] + + if args.reward_normalization: + if env_type == "mtbench": + task_ids_one_hot = obs[..., -envs.num_tasks :] + task_indices = torch.argmax(task_ids_one_hot, dim=1) + update_stats(rewards, dones.float(), task_ids=task_indices) + else: + update_stats(rewards, dones.float()) + + if envs.asymmetric_obs: + next_critic_obs = 
infos["observations"]["critic"] + # Compute 'true' next_obs and next_critic_obs for saving + true_next_obs = torch.where( + dones[:, None] > 0, infos["observations"]["raw"]["obs"], next_obs + ) + if envs.asymmetric_obs: + true_next_critic_obs = torch.where( + dones[:, None] > 0, + infos["observations"]["raw"]["critic_obs"], + next_critic_obs, + ) + + transition = TensorDict( + { + "observations": obs, + "actions": torch.as_tensor(actions, device=device, dtype=torch.float), + "next": { + "observations": true_next_obs, + "rewards": torch.as_tensor( + rewards, device=device, dtype=torch.float + ), + "truncations": truncations.long(), + "dones": dones.long(), + }, + }, + batch_size=(envs.num_envs,), + device=device, + ) + if envs.asymmetric_obs: + transition["critic_observations"] = critic_obs + transition["next"]["critic_observations"] = true_next_critic_obs + rb.extend(transition) + + obs = next_obs + if envs.asymmetric_obs: + critic_obs = next_critic_obs + + if global_step > args.learning_starts: + for i in range(args.num_updates): + data = rb.sample(max(1, args.batch_size // args.num_envs)) + data["observations"] = normalize_obs(data["observations"]) + data["next"]["observations"] = normalize_obs( + data["next"]["observations"] + ) + if envs.asymmetric_obs: + data["critic_observations"] = normalize_critic_obs( + data["critic_observations"] + ) + data["next"]["critic_observations"] = normalize_critic_obs( + data["next"]["critic_observations"] + ) + raw_rewards = data["next"]["rewards"] + if env_type in ["mtbench"] and args.reward_normalization: + # Multi-task reward normalization + task_ids_one_hot = data["observations"][..., -envs.num_tasks :] + task_indices = torch.argmax(task_ids_one_hot, dim=1) + data["next"]["rewards"] = normalize_reward( + raw_rewards, task_ids=task_indices + ) + else: + data["next"]["rewards"] = normalize_reward(raw_rewards) + + logs_dict = update_main(data, logs_dict) + if args.num_updates > 1: + if i % args.policy_frequency == 1: + logs_dict = update_pol(data, logs_dict) + else: + if global_step % args.policy_frequency == 0: + logs_dict = update_pol(data, logs_dict) + + soft_update(qnet, qnet_target, args.tau) + + if global_step % 100 == 0 and start_time is not None: + speed = (global_step - measure_burnin) / (time.time() - start_time) + if rank == 0: + pbar.set_description(f"{speed: 4.4f} sps, " + desc) + with torch.no_grad(): + logs = { + "actor_loss": logs_dict["actor_loss"].mean(), + "qf_loss": logs_dict["qf_loss"].mean(), + "qf_max": logs_dict["qf_max"].mean(), + "qf_min": logs_dict["qf_min"].mean(), + "actor_grad_norm": logs_dict["actor_grad_norm"].mean(), + "critic_grad_norm": logs_dict["critic_grad_norm"].mean(), + "env_rewards": rewards.mean(), + "buffer_rewards": raw_rewards.mean(), + } + + if args.eval_interval > 0 and global_step % args.eval_interval == 0: + local_eval_avg_return, local_eval_avg_length = evaluate() + eval_results = torch.tensor( + [local_eval_avg_return, local_eval_avg_length], + device=device, + ) + if is_distributed: + torch.distributed.all_reduce( + eval_results, op=torch.distributed.ReduceOp.AVG + ) + + if rank == 0: + global_avg_return = eval_results[0].item() + global_avg_length = eval_results[1].item() + print( + f"Evaluating at global step {global_step}: Avg Return={global_avg_return:.2f}" + ) + logs["eval_avg_return"] = global_avg_return + logs["eval_avg_length"] = global_avg_length + + if env_type in ["humanoid_bench", "isaaclab", "mtbench"]: + # NOTE: Hacky way of evaluating performance, but just works + obs = envs.reset() + 
+ if ( + args.render_interval > 0 + and global_step % args.render_interval == 0 + ): + renders = render_with_rollout() + render_video = wandb.Video( + np.array(renders).transpose( + 0, 3, 1, 2 + ), # Convert to (T, C, H, W) format + fps=30, + format="gif", + ) + logs["render_video"] = render_video + + if args.use_wandb and rank == 0: + wandb.log( + { + "speed": speed, + "frame": global_step * args.num_envs, + "critic_lr": q_scheduler.get_last_lr()[0], + "actor_lr": actor_scheduler.get_last_lr()[0], + **logs, + }, + step=global_step, + ) + + if ( + args.save_interval > 0 + and global_step > 0 + and global_step % args.save_interval == 0 + and rank == 0 + ): + print(f"Saving model at global step {global_step}") + save_params( + global_step, + actor, + qnet, + qnet_target, + obs_normalizer, + critic_obs_normalizer, + args, + f"models/{run_name}_{global_step}.pt", + ) + + global_step += 1 + actor_scheduler.step() + q_scheduler.step() + if rank == 0: + pbar.update(1) + + save_params( + global_step, + actor, + qnet, + qnet_target, + obs_normalizer, + critic_obs_normalizer, + args, + f"models/{run_name}_final.pt", + ) + + # Cleanup distributed training + if is_distributed: + torch.distributed.destroy_process_group() + + +if __name__ == "__main__": + world_size = torch.cuda.device_count() + mp.spawn(main, args=(world_size,), nprocs=world_size)
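
The distributed branch added to `EmpiricalNormalization.update` in `fast_td3/fast_td3_utils.py` rebuilds the global batch mean and variance from per-rank sums of mean-shifted observations before applying the running (Chan's) update. The snippet below is a minimal single-process sketch of that arithmetic — `dist.all_reduce(SUM)` is replaced by an explicit sum over per-rank tensors, and shapes and names are illustrative — showing that it reproduces the statistics of the concatenated global batch.

```python
import torch

torch.manual_seed(0)
n_obs, world_size, local_batch = 3, 2, 5
running_mean = torch.randn(1, n_obs)  # plays the role of self._mean, shared by all ranks

# One local batch per "rank".
local_batches = [torch.randn(local_batch, n_obs) for _ in range(world_size)]

# Each rank computes sums of (x - mean) and (x - mean)^2 over its local batch ...
per_rank_stats = []
for x in local_batches:
    x_shifted = x - running_mean
    per_rank_stats.append(
        torch.stack([x_shifted.sum(dim=0), x_shifted.pow(2).sum(dim=0)])
    )

# ... and dist.all_reduce(SUM) would hand every rank these totals.
global_sum_shifted, global_sum_sq_shifted = sum(per_rank_stats)
global_batch_size = world_size * local_batch

# Recover mean/variance of the (virtual) concatenated global batch.
batch_mean_shifted = global_sum_shifted / global_batch_size
batch_var = global_sum_sq_shifted / global_batch_size - batch_mean_shifted.pow(2)
batch_mean = batch_mean_shifted + running_mean

# Same statistics as computing them directly on the concatenated batch.
full_batch = torch.cat(local_batches, dim=0)
assert torch.allclose(batch_mean.squeeze(0), full_batch.mean(dim=0), atol=1e-5)
assert torch.allclose(batch_var.squeeze(0), full_batch.var(dim=0, unbiased=False), atol=1e-5)
print("per-rank shifted sums reproduce the global batch statistics")
```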