Support Multi-GPU Training (#22)

- Change in isaaclab_env wrapper to explicitly state GPU for each simulation
- Removing jax cache to support multi-gpu environment launch in MuJoCo Playground
- Removing .train() and .eval() in evaluation and rendering to avoid deadlock in multi-gpu training
- Supporting synchronous normalization for multi-gpu training
This commit is contained in:
Younggyo Seo 2025-07-07 10:24:42 -07:00 committed by GitHub
parent 83907422a3
commit 51c55d4a8a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 970 additions and 76 deletions

View File

@ -11,6 +11,8 @@ For more information, please see our [project webpage](https://younggyo.me/fast_
## ❗ Updates ## ❗ Updates
- **[Jul/07/2025]** Added support for multi-GPU training! See [Multi-GPU Training](#multi-gpu-training) section for details.
- **[Jul/02/2025]** Optimized codebase to speed up training around 10-30% when using a single RTX 4090 GPU. - **[Jul/02/2025]** Optimized codebase to speed up training around 10-30% when using a single RTX 4090 GPU.
- **[Jun/20/2025]** Added support for [MTBench](https://github.com/Viraj-Joshi/MTBench) with the help of [Viraj Joshi](https://viraj-joshi.github.io/). - **[Jun/20/2025]** Added support for [MTBench](https://github.com/Viraj-Joshi/MTBench) with the help of [Viraj Joshi](https://viraj-joshi.github.io/).
@ -242,6 +244,18 @@ We used a single Nvidia A100 80GB GPU for all experiments. Here are some remarks
- When you encounter out-of-memory error with your GPU, our recommendation for reducing GPU usage is (i) smaller `buffer_size`, (ii) smaller `batch_size`, and then (iii) smaller `num_envs`. Because our codebase is assigning the whole replay buffer in GPU to reduce CPU-GPU transfer bottleneck, it usually has the largest GPU consumption, but usually less harmful to reduce. - When you encounter out-of-memory error with your GPU, our recommendation for reducing GPU usage is (i) smaller `buffer_size`, (ii) smaller `batch_size`, and then (iii) smaller `num_envs`. Because our codebase is assigning the whole replay buffer in GPU to reduce CPU-GPU transfer bottleneck, it usually has the largest GPU consumption, but usually less harmful to reduce.
- Consider using `--compile_mode max-autotune` if you plan to run for many training steps. This may speed up training by up to 10% at the cost of a few additional minutes of heavy compilation. - Consider using `--compile_mode max-autotune` if you plan to run for many training steps. This may speed up training by up to 10% at the cost of a few additional minutes of heavy compilation.
## Multi-GPU Training
We support multi-GPU training. If your machine supports multiple GPUs, or specify multiple GPUs using `CUDA_VISIBLE_DEVICES`, and run `train_multigpu.py`, it will automatically use all GPUs to scale up training.
**Important:** Our multi-GPU implementation launches the **same experiment independently on each GPU** rather than distributing parameters across GPUs. This means:
- Effective number of environments: `num_envs × num_gpus`
- Effective batch size: `batch_size × num_gpus`
- Effective buffer size: `buffer_size × num_gpus`
Each GPU runs a complete copy of the training process, which scales up data collection and training throughput proportionally to the number of GPUs.
For instance, running IsaacLab experiments with 4 GPUs and `num_envs=1024` will end up in similar results as experiments with 1 GPU with `num_envs=4096`.
## 🛝 Playing with the FastTD3 training ## 🛝 Playing with the FastTD3 training
A Jupyter notebook (`training_notebook.ipynb`) is available to help you get started with: A Jupyter notebook (`training_notebook.ipynb`) is available to help you get started with:

View File

@ -2,13 +2,6 @@ from typing import Optional
import gymnasium as gym import gymnasium as gym
import torch import torch
from isaaclab.app import AppLauncher
app_launcher = AppLauncher(headless=True)
simulation_app = app_launcher.app
import isaaclab_tasks
from isaaclab_tasks.utils.parse_cfg import parse_env_cfg
class IsaacLabEnv: class IsaacLabEnv:
@ -22,6 +15,14 @@ class IsaacLabEnv:
seed: int, seed: int,
action_bounds: Optional[float] = None, action_bounds: Optional[float] = None,
): ):
from isaaclab.app import AppLauncher
app_launcher = AppLauncher(headless=True, device=device)
simulation_app = app_launcher.app
import isaaclab_tasks
from isaaclab_tasks.utils.parse_cfg import parse_env_cfg
env_cfg = parse_env_cfg( env_cfg = parse_env_cfg(
task_name, task_name,
device=device, device=device,

View File

@ -6,7 +6,15 @@ import mujoco
class PlaygroundEvalEnvWrapper: class PlaygroundEvalEnvWrapper:
def __init__(self, eval_env, max_episode_steps, env_name, num_eval_envs, seed): def __init__(
self,
eval_env,
max_episode_steps,
env_name,
num_eval_envs,
seed,
device_rank=None,
):
""" """
Wrapper used for evaluation / rendering environments. Wrapper used for evaluation / rendering environments.
Note that this is different from training environments that are Note that this is different from training environments that are
@ -24,6 +32,11 @@ class PlaygroundEvalEnvWrapper:
self.asymmetric_obs = False self.asymmetric_obs = False
self.key = jax.random.PRNGKey(seed) self.key = jax.random.PRNGKey(seed)
if device_rank is not None:
gpu_devices = jax.devices("gpu")
self.key = jax.device_put(self.key, gpu_devices[device_rank])
self.key_reset = jax.random.split(self.key, num_eval_envs) self.key_reset = jax.random.split(self.key, num_eval_envs)
self.max_episode_steps = max_episode_steps self.max_episode_steps = max_episode_steps
@ -118,7 +131,12 @@ def make_env(
eval_env_cfg.push_config.magnitude_range = [0.0, 0.0] eval_env_cfg.push_config.magnitude_range = [0.0, 0.0]
eval_env = registry.load(env_name, config=eval_env_cfg) eval_env = registry.load(env_name, config=eval_env_cfg)
eval_env = PlaygroundEvalEnvWrapper( eval_env = PlaygroundEvalEnvWrapper(
eval_env, eval_env_cfg.episode_length, env_name, num_eval_envs, seed eval_env,
eval_env_cfg.episode_length,
env_name,
num_eval_envs,
seed,
device_rank=device_rank,
) )
render_env_cfg = registry.get_default_config(env_name) render_env_cfg = registry.get_default_config(env_name)
@ -127,7 +145,12 @@ def make_env(
render_env_cfg.push_config.magnitude_range = [0.0, 0.0] render_env_cfg.push_config.magnitude_range = [0.0, 0.0]
render_env = registry.load(env_name, config=render_env_cfg) render_env = registry.load(env_name, config=render_env_cfg)
render_env = PlaygroundEvalEnvWrapper( render_env = PlaygroundEvalEnvWrapper(
render_env, render_env_cfg.episode_length, env_name, 1, seed render_env,
render_env_cfg.episode_length,
env_name,
1,
seed,
device_rank=device_rank,
) )
return train_env, eval_env, render_env return train_env, eval_env, render_env

View File

@ -234,6 +234,8 @@ class MultiTaskActor(Actor):
) )
def forward(self, obs: torch.Tensor) -> torch.Tensor: def forward(self, obs: torch.Tensor) -> torch.Tensor:
# TODO: Optimize the code to be compatible with cudagraphs
# Currently in-place creation of task_indices is not compatible with cudagraphs
task_ids_one_hot = obs[..., -self.num_tasks :] task_ids_one_hot = obs[..., -self.num_tasks :]
task_indices = torch.argmax(task_ids_one_hot, dim=1) task_indices = torch.argmax(task_ids_one_hot, dim=1)
task_embeddings = self.task_embedding(task_indices) task_embeddings = self.task_embedding(task_indices)
@ -251,6 +253,8 @@ class MultiTaskCritic(Critic):
) )
def forward(self, obs: torch.Tensor, actions: torch.Tensor) -> torch.Tensor: def forward(self, obs: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
# TODO: Optimize the code to be compatible with cudagraphs
# Currently in-place creation of task_indices is not compatible with cudagraphs
task_ids_one_hot = obs[..., -self.num_tasks :] task_ids_one_hot = obs[..., -self.num_tasks :]
task_indices = torch.argmax(task_ids_one_hot, dim=1) task_indices = torch.argmax(task_ids_one_hot, dim=1)
task_embeddings = self.task_embedding(task_indices) task_embeddings = self.task_embedding(task_indices)

View File

@ -510,6 +510,8 @@ class MultiTaskActor(Actor):
) )
def forward(self, obs: torch.Tensor) -> torch.Tensor: def forward(self, obs: torch.Tensor) -> torch.Tensor:
# TODO: Optimize the code to be compatible with cudagraphs
# Currently in-place creation of task_indices is not compatible with cudagraphs
task_ids_one_hot = obs[..., -self.num_tasks :] task_ids_one_hot = obs[..., -self.num_tasks :]
task_indices = torch.argmax(task_ids_one_hot, dim=1) task_indices = torch.argmax(task_ids_one_hot, dim=1)
task_embeddings = self.task_embedding(task_indices) task_embeddings = self.task_embedding(task_indices)
@ -527,6 +529,8 @@ class MultiTaskCritic(Critic):
) )
def forward(self, obs: torch.Tensor, actions: torch.Tensor) -> torch.Tensor: def forward(self, obs: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
# TODO: Optimize the code to be compatible with cudagraphs
# Currently in-place creation of task_indices is not compatible with cudagraphs
task_ids_one_hot = obs[..., -self.num_tasks :] task_ids_one_hot = obs[..., -self.num_tasks :]
task_indices = torch.argmax(task_ids_one_hot, dim=1) task_indices = torch.argmax(task_ids_one_hot, dim=1)
task_embeddings = self.task_embedding(task_indices) task_embeddings = self.task_embedding(task_indices)

View File

@ -4,6 +4,7 @@ from typing import Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.distributed as dist
from tensordict import TensorDict from tensordict import TensorDict
@ -428,13 +429,15 @@ class EmpiricalNormalization(nn.Module):
return self._std.squeeze(0).clone() return self._std.squeeze(0).clone()
@torch.no_grad() @torch.no_grad()
def forward(self, x: torch.Tensor, center: bool = True) -> torch.Tensor: def forward(
self, x: torch.Tensor, center: bool = True, update: bool = True
) -> torch.Tensor:
if x.shape[1:] != self._mean.shape[1:]: if x.shape[1:] != self._mean.shape[1:]:
raise ValueError( raise ValueError(
f"Expected input of shape (*,{self._mean.shape[1:]}), got {x.shape}" f"Expected input of shape (*,{self._mean.shape[1:]}), got {x.shape}"
) )
if self.training: if self.training and update:
self.update(x) self.update(x)
if center: if center:
return (x - self._mean) / (self._std + self.eps) return (x - self._mean) / (self._std + self.eps)
@ -446,27 +449,46 @@ class EmpiricalNormalization(nn.Module):
if self.until is not None and self.count >= self.until: if self.until is not None and self.count >= self.until:
return return
batch_size = x.shape[0] if dist.is_available() and dist.is_initialized():
batch_mean = torch.mean(x, dim=0, keepdim=True) # Calculate global batch size arithmetically
local_batch_size = x.shape[0]
world_size = dist.get_world_size()
global_batch_size = world_size * local_batch_size
# Update count # Calculate the stats
new_count = self.count + batch_size x_shifted = x - self._mean
local_sum_shifted = torch.sum(x_shifted, dim=0, keepdim=True)
local_sum_sq_shifted = torch.sum(x_shifted.pow(2), dim=0, keepdim=True)
# Sync the stats across all processes
stats_to_sync = torch.cat([local_sum_shifted, local_sum_sq_shifted], dim=0)
dist.all_reduce(stats_to_sync, op=dist.ReduceOp.SUM)
global_sum_shifted, global_sum_sq_shifted = stats_to_sync
# Calculate the mean and variance of the global batch
batch_mean_shifted = global_sum_shifted / global_batch_size
batch_var = (
global_sum_sq_shifted / global_batch_size - batch_mean_shifted.pow(2)
)
batch_mean = batch_mean_shifted + self._mean
else:
global_batch_size = x.shape[0]
batch_mean = torch.mean(x, dim=0, keepdim=True)
batch_var = torch.var(x, dim=0, keepdim=True, unbiased=False)
new_count = self.count + global_batch_size
# Update mean # Update mean
delta = batch_mean - self._mean delta = batch_mean - self._mean
self._mean += (batch_size / new_count) * delta self._mean.copy_(self._mean + delta * (global_batch_size / new_count))
# Compute batch variance # Update variance
batch_var = torch.mean((x - batch_mean) ** 2, dim=0, keepdim=True) delta2 = batch_mean - self._mean
delta2 = batch_mean - self._mean # uses updated mean m_a = self._var * self.count
m_b = batch_var * global_batch_size
# Parallel variance update (works even when previous count == 0) M2 = m_a + m_b + delta2.pow(2) * (self.count * global_batch_size / new_count)
m_a = self._var * self.count # previous aggregated M2
m_b = batch_var * batch_size
M2 = m_a + m_b + delta2.pow(2) * (self.count * batch_size / new_count)
self._var.copy_(M2 / new_count) self._var.copy_(M2 / new_count)
# Update std and count in-place to avoid expensive __setattr__
self._std.copy_(self._var.sqrt()) self._std.copy_(self._var.sqrt())
self.count.copy_(new_count) self.count.copy_(new_count)
@ -507,7 +529,13 @@ class RewardNormalizer(nn.Module):
): ):
self.G = self.gamma * (1 - dones) * self.G + rewards self.G = self.gamma * (1 - dones) * self.G + rewards
self.G_rms.update(self.G.view(-1, 1)) self.G_rms.update(self.G.view(-1, 1))
self.G_r_max = max(self.G_r_max, max(abs(self.G)))
local_max = torch.max(torch.abs(self.G))
if dist.is_available() and dist.is_initialized():
dist.all_reduce(local_max, op=dist.ReduceOp.MAX)
self.G_r_max = max(self.G_r_max, local_max)
def forward(self, rewards: torch.Tensor) -> torch.Tensor: def forward(self, rewards: torch.Tensor) -> torch.Tensor:
return self._scale_reward(rewards) return self._scale_reward(rewards)
@ -608,7 +636,7 @@ class PerTaskEmpiricalNormalization(nn.Module):
task_mean = self._mean[task_id] task_mean = self._mean[task_id]
batch_mean = torch.mean(x_task, dim=0) batch_mean = torch.mean(x_task, dim=0)
delta = batch_mean - task_mean delta = batch_mean - task_mean
self._mean[task_id] = task_mean + (batch_size / new_count) * delta self._mean[task_id].copy_(task_mean + (batch_size / new_count) * delta)
# Update variance using Chan's parallel algorithm # Update variance using Chan's parallel algorithm
if old_count > 0: if old_count > 0:
@ -616,13 +644,13 @@ class PerTaskEmpiricalNormalization(nn.Module):
m_a = self._var[task_id] * old_count m_a = self._var[task_id] * old_count
m_b = batch_var * batch_size m_b = batch_var * batch_size
M2 = m_a + m_b + (delta**2) * (old_count * batch_size / new_count) M2 = m_a + m_b + (delta**2) * (old_count * batch_size / new_count)
self._var[task_id] = M2 / new_count self._var[task_id].copy_(M2 / new_count)
else: else:
# For the first batch of this task # For the first batch of this task
self._var[task_id] = torch.var(x_task, dim=0, unbiased=False) self._var[task_id].copy_(torch.var(x_task, dim=0, unbiased=False))
self._std[task_id] = torch.sqrt(self._var[task_id]) self._std[task_id].copy_(torch.sqrt(self._var[task_id]))
self.count[task_id] = new_count self.count[task_id].copy_(new_count)
class PerTaskRewardNormalizer(nn.Module): class PerTaskRewardNormalizer(nn.Module):
@ -735,11 +763,18 @@ def save_params(
save_path, save_path,
): ):
"""Save model parameters and training configuration to disk.""" """Save model parameters and training configuration to disk."""
def get_ddp_state_dict(model):
"""Get state dict from model, handling DDP wrapper if present."""
if hasattr(model, "module"):
return model.module.state_dict()
return model.state_dict()
os.makedirs(os.path.dirname(save_path), exist_ok=True) os.makedirs(os.path.dirname(save_path), exist_ok=True)
save_dict = { save_dict = {
"actor_state_dict": cpu_state(actor.state_dict()), "actor_state_dict": cpu_state(get_ddp_state_dict(actor)),
"qnet_state_dict": cpu_state(qnet.state_dict()), "qnet_state_dict": cpu_state(get_ddp_state_dict(qnet)),
"qnet_target_state_dict": cpu_state(qnet_target.state_dict()), "qnet_target_state_dict": cpu_state(get_ddp_state_dict(qnet_target)),
"obs_normalizer_state": ( "obs_normalizer_state": (
cpu_state(obs_normalizer.state_dict()) cpu_state(obs_normalizer.state_dict())
if hasattr(obs_normalizer, "state_dict") if hasattr(obs_normalizer, "state_dict")
@ -755,3 +790,24 @@ def save_params(
} }
torch.save(save_dict, save_path, _use_new_zipfile_serialization=True) torch.save(save_dict, save_path, _use_new_zipfile_serialization=True)
print(f"Saved parameters and configuration to {save_path}") print(f"Saved parameters and configuration to {save_path}")
def get_ddp_state_dict(model):
"""Get state dict from model, handling DDP wrapper if present."""
if hasattr(model, "module"):
return model.module.state_dict()
return model.state_dict()
def load_ddp_state_dict(model, state_dict):
"""Load state dict into model, handling DDP wrapper if present."""
if hasattr(model, "module"):
model.module.load_state_dict(state_dict)
else:
model.load_state_dict(state_dict)
@torch.no_grad()
def mark_step():
# call this once per iteration *before* any compiled function
torch.compiler.cudagraph_mark_step_begin()

View File

@ -303,6 +303,7 @@ class MTBenchArgs(BaseArgs):
num_eval_envs: int = 4096 num_eval_envs: int = 4096
gamma: float = 0.97 gamma: float = 0.97
num_steps: int = 8 num_steps: int = 8
compile_mode: str = "default" # Multi-task training is not compatible with cudagraphs
@dataclass @dataclass
@ -313,6 +314,7 @@ class MetaWorldMT10Args(MTBenchArgs):
num_eval_envs: int = 4096 num_eval_envs: int = 4096
num_steps: int = 8 num_steps: int = 8
gamma: float = 0.97 gamma: float = 0.97
compile_mode: str = "default" # Multi-task training is not compatible with cudagraphs
@dataclass @dataclass
@ -324,6 +326,7 @@ class MetaWorldMT50Args(MTBenchArgs):
num_eval_envs: int = 8192 num_eval_envs: int = 8192
num_steps: int = 8 num_steps: int = 8
gamma: float = 0.99 gamma: float = 0.99
compile_mode: str = "default" # Multi-task training is not compatible with cudagraphs
@dataclass @dataclass

View File

@ -38,6 +38,7 @@ from fast_td3_utils import (
PerTaskRewardNormalizer, PerTaskRewardNormalizer,
SimpleReplayBuffer, SimpleReplayBuffer,
save_params, save_params,
mark_step,
) )
from hyperparams import get_args from hyperparams import get_args
@ -49,12 +50,6 @@ except ImportError:
pass pass
@torch.no_grad()
def mark_step():
# call this once per iteration *before* any compiled function
torch.compiler.cudagraph_mark_step_begin()
def main(): def main():
args = get_args() args = get_args()
print(args) print(args)
@ -313,7 +308,6 @@ def main():
noise_clip = args.noise_clip noise_clip = args.noise_clip
def evaluate(): def evaluate():
obs_normalizer.eval()
num_eval_envs = eval_envs.num_envs num_eval_envs = eval_envs.num_envs
episode_returns = torch.zeros(num_eval_envs, device=device) episode_returns = torch.zeros(num_eval_envs, device=device)
episode_lengths = torch.zeros(num_eval_envs, device=device) episode_lengths = torch.zeros(num_eval_envs, device=device)
@ -329,7 +323,7 @@ def main():
with torch.no_grad(), autocast( with torch.no_grad(), autocast(
device_type=amp_device_type, dtype=amp_dtype, enabled=amp_enabled device_type=amp_device_type, dtype=amp_dtype, enabled=amp_enabled
): ):
obs = normalize_obs(obs) obs = normalize_obs(obs, update=False)
actions = actor(obs) actions = actor(obs)
next_obs, rewards, dones, infos = eval_envs.step(actions.float()) next_obs, rewards, dones, infos = eval_envs.step(actions.float())
@ -352,12 +346,9 @@ def main():
break break
obs = next_obs obs = next_obs
obs_normalizer.train()
return episode_returns.mean().item(), episode_lengths.mean().item() return episode_returns.mean().item(), episode_lengths.mean().item()
def render_with_rollout(): def render_with_rollout():
obs_normalizer.eval()
# Quick rollout for rendering # Quick rollout for rendering
if env_type == "humanoid_bench": if env_type == "humanoid_bench":
obs = render_env.reset() obs = render_env.reset()
@ -374,7 +365,7 @@ def main():
with torch.no_grad(), autocast( with torch.no_grad(), autocast(
device_type=amp_device_type, dtype=amp_dtype, enabled=amp_enabled device_type=amp_device_type, dtype=amp_dtype, enabled=amp_enabled
): ):
obs = normalize_obs(obs) obs = normalize_obs(obs, update=False)
actions = actor(obs) actions = actor(obs)
next_obs, _, done, _ = render_env.step(actions.float()) next_obs, _, done, _ = render_env.step(actions.float())
if env_type == "mujoco_playground": if env_type == "mujoco_playground":
@ -390,8 +381,6 @@ def main():
if env_type == "mujoco_playground": if env_type == "mujoco_playground":
renders = render_env.render_trajectory(renders) renders = render_env.render_trajectory(renders)
obs_normalizer.train()
return renders return renders
def update_main(data, logs_dict): def update_main(data, logs_dict):
@ -473,7 +462,6 @@ def main():
critic_grad_norm = torch.tensor(0.0, device=device) critic_grad_norm = torch.tensor(0.0, device=device)
scaler.step(q_optimizer) scaler.step(q_optimizer)
scaler.update() scaler.update()
q_scheduler.step()
logs_dict["critic_grad_norm"] = critic_grad_norm.detach() logs_dict["critic_grad_norm"] = critic_grad_norm.detach()
logs_dict["qf_loss"] = qf_loss.detach() logs_dict["qf_loss"] = qf_loss.detach()
@ -512,7 +500,6 @@ def main():
actor_grad_norm = torch.tensor(0.0, device=device) actor_grad_norm = torch.tensor(0.0, device=device)
scaler.step(actor_optimizer) scaler.step(actor_optimizer)
scaler.update() scaler.update()
actor_scheduler.step()
logs_dict["actor_grad_norm"] = actor_grad_norm.detach() logs_dict["actor_grad_norm"] = actor_grad_norm.detach()
logs_dict["actor_loss"] = actor_loss.detach() logs_dict["actor_loss"] = actor_loss.detach()
return logs_dict return logs_dict
@ -529,16 +516,12 @@ def main():
compile_mode = args.compile_mode compile_mode = args.compile_mode
update_main = torch.compile(update_main, mode=compile_mode) update_main = torch.compile(update_main, mode=compile_mode)
update_pol = torch.compile(update_pol, mode=compile_mode) update_pol = torch.compile(update_pol, mode=compile_mode)
policy = torch.compile(policy, mode=compile_mode) policy = torch.compile(policy, mode=None)
normalize_obs = torch.compile(obs_normalizer.forward, mode=compile_mode) normalize_obs = torch.compile(obs_normalizer.forward, mode=None)
normalize_critic_obs = torch.compile( normalize_critic_obs = torch.compile(critic_obs_normalizer.forward, mode=None)
critic_obs_normalizer.forward, mode=compile_mode
)
if args.reward_normalization: if args.reward_normalization:
update_stats = torch.compile( update_stats = torch.compile(reward_normalizer.update_stats, mode=None)
reward_normalizer.update_stats, mode=compile_mode normalize_reward = torch.compile(reward_normalizer.forward, mode=None)
)
normalize_reward = torch.compile(reward_normalizer.forward, mode=compile_mode)
else: else:
normalize_obs = obs_normalizer.forward normalize_obs = obs_normalizer.forward
normalize_critic_obs = critic_obs_normalizer.forward normalize_critic_obs = critic_obs_normalizer.forward
@ -637,10 +620,9 @@ def main():
if envs.asymmetric_obs: if envs.asymmetric_obs:
critic_obs = next_critic_obs critic_obs = next_critic_obs
batch_size = args.batch_size // args.num_envs
if global_step > args.learning_starts: if global_step > args.learning_starts:
for i in range(args.num_updates): for i in range(args.num_updates):
data = rb.sample(batch_size) data = rb.sample(max(1, args.batch_size // args.num_envs))
data["observations"] = normalize_obs(data["observations"]) data["observations"] = normalize_obs(data["observations"])
data["next"]["observations"] = normalize_obs( data["next"]["observations"] = normalize_obs(
data["next"]["observations"] data["next"]["observations"]
@ -702,19 +684,14 @@ def main():
and global_step % args.render_interval == 0 and global_step % args.render_interval == 0
): ):
renders = render_with_rollout() renders = render_with_rollout()
if args.use_wandb: render_video = wandb.Video(
wandb.log( np.array(renders).transpose(
{ 0, 3, 1, 2
"render_video": wandb.Video( ), # Convert to (T, C, H, W) format
np.array(renders).transpose( fps=30,
0, 3, 1, 2 format="gif",
), # Convert to (T, C, H, W) format )
fps=30, logs["render_video"] = render_video
format="gif",
)
},
step=global_step,
)
if args.use_wandb: if args.use_wandb:
wandb.log( wandb.log(
{ {
@ -745,6 +722,8 @@ def main():
) )
global_step += 1 global_step += 1
actor_scheduler.step()
q_scheduler.step()
pbar.update(1) pbar.update(1)
save_params( save_params(

810
fast_td3/train_multigpu.py Normal file
View File

@ -0,0 +1,810 @@
import os
import sys
os.environ["TORCHDYNAMO_INLINE_INBUILT_NN_MODULES"] = "1"
os.environ["OMP_NUM_THREADS"] = "1"
if sys.platform != "darwin":
os.environ["MUJOCO_GL"] = "egl"
else:
os.environ["MUJOCO_GL"] = "glfw"
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
os.environ["JAX_DEFAULT_MATMUL_PRECISION"] = "highest"
import random
import time
import math
import tqdm
import wandb
import numpy as np
try:
# Required for avoiding IsaacGym import error
import isaacgym
except ImportError:
pass
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.amp import autocast, GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.multiprocessing as mp
from tensordict import TensorDict
from fast_td3_utils import (
EmpiricalNormalization,
RewardNormalizer,
PerTaskRewardNormalizer,
SimpleReplayBuffer,
save_params,
get_ddp_state_dict,
load_ddp_state_dict,
mark_step,
)
from hyperparams import get_args
torch.set_float32_matmul_precision("high")
try:
import jax.numpy as jnp
except ImportError:
pass
def setup_distributed(rank: int, world_size: int):
os.environ["MASTER_ADDR"] = os.getenv("MASTER_ADDR", "localhost")
os.environ["MASTER_PORT"] = os.getenv("MASTER_PORT", "12355")
is_distributed = world_size > 1
if is_distributed:
print(
f"Initializing distributed training with rank {rank}, world size {world_size}"
)
torch.distributed.init_process_group(
backend="nccl", init_method="env://", world_size=world_size, rank=rank
)
torch.cuda.set_device(rank)
return is_distributed
def main(rank: int, world_size: int):
is_distributed = setup_distributed(rank, world_size)
args = get_args()
if rank == 0:
print(args)
run_name = f"{args.env_name}__{args.exp_name}__{args.seed}"
amp_enabled = args.amp and args.cuda and torch.cuda.is_available()
amp_device_type = (
f"cuda:{rank}"
if args.cuda and torch.cuda.is_available()
else "mps" if args.cuda and torch.backends.mps.is_available() else "cpu"
)
amp_dtype = torch.bfloat16 if args.amp_dtype == "bf16" else torch.float16
scaler = GradScaler(enabled=amp_enabled and amp_dtype == torch.float16)
if args.use_wandb and rank == 0:
wandb.init(
project=args.project,
name=run_name,
config=vars(args),
save_code=True,
)
# Use different seeds per rank to avoid synchronization issues
random.seed(args.seed + rank)
np.random.seed(args.seed + rank)
torch.manual_seed(args.seed + rank)
torch.backends.cudnn.deterministic = args.torch_deterministic
if not args.cuda:
device = torch.device("cpu")
else:
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
elif torch.backends.mps.is_available():
device = torch.device(f"mps:{rank}")
else:
raise ValueError("No GPU available")
print(f"Using device: {device}")
if args.env_name.startswith("h1hand-") or args.env_name.startswith("h1-"):
from environments.humanoid_bench_env import HumanoidBenchEnv
env_type = "humanoid_bench"
envs = HumanoidBenchEnv(args.env_name, args.num_envs, device=device)
eval_envs = envs
render_env = HumanoidBenchEnv(
args.env_name, 1, render_mode="rgb_array", device=device
)
elif args.env_name.startswith("Isaac-"):
from environments.isaaclab_env import IsaacLabEnv
env_type = "isaaclab"
envs = IsaacLabEnv(
args.env_name,
f"cuda:{rank}",
args.num_envs,
args.seed + rank,
action_bounds=args.action_bounds,
)
eval_envs = envs
render_env = envs
elif args.env_name.startswith("MTBench-"):
from environments.mtbench_env import MTBenchEnv
env_name = "-".join(args.env_name.split("-")[1:])
env_type = "mtbench"
envs = MTBenchEnv(env_name, rank, args.num_envs, args.seed + rank)
eval_envs = envs
render_env = envs
else:
from environments.mujoco_playground_env import make_env
# TODO: Check if re-using same envs for eval could reduce memory usage
env_type = "mujoco_playground"
envs, eval_envs, render_env = make_env(
args.env_name,
args.seed + rank,
args.num_envs,
args.num_eval_envs,
rank,
use_tuned_reward=args.use_tuned_reward,
use_domain_randomization=args.use_domain_randomization,
use_push_randomization=args.use_push_randomization,
)
n_act = envs.num_actions
n_obs = envs.num_obs if type(envs.num_obs) == int else envs.num_obs[0]
if envs.asymmetric_obs:
n_critic_obs = (
envs.num_privileged_obs
if type(envs.num_privileged_obs) == int
else envs.num_privileged_obs[0]
)
else:
n_critic_obs = n_obs
action_low, action_high = -1.0, 1.0
if args.obs_normalization:
obs_normalizer = EmpiricalNormalization(shape=n_obs, device=device)
critic_obs_normalizer = EmpiricalNormalization(
shape=n_critic_obs, device=device
)
else:
obs_normalizer = nn.Identity()
critic_obs_normalizer = nn.Identity()
if args.reward_normalization:
if env_type in ["mtbench"]:
reward_normalizer = PerTaskRewardNormalizer(
num_tasks=envs.num_tasks,
gamma=args.gamma,
device=device,
g_max=min(abs(args.v_min), abs(args.v_max)),
)
else:
reward_normalizer = RewardNormalizer(
gamma=args.gamma,
device=device,
g_max=min(abs(args.v_min), abs(args.v_max)),
)
else:
reward_normalizer = nn.Identity()
actor_kwargs = {
"n_obs": n_obs,
"n_act": n_act,
"num_envs": args.num_envs,
"device": device,
"init_scale": args.init_scale,
"hidden_dim": args.actor_hidden_dim,
}
critic_kwargs = {
"n_obs": n_critic_obs,
"n_act": n_act,
"num_atoms": args.num_atoms,
"v_min": args.v_min,
"v_max": args.v_max,
"hidden_dim": args.critic_hidden_dim,
"device": device,
}
if env_type == "mtbench":
actor_kwargs["n_obs"] = n_obs - envs.num_tasks + args.task_embedding_dim
critic_kwargs["n_obs"] = n_critic_obs - envs.num_tasks + args.task_embedding_dim
actor_kwargs["num_tasks"] = envs.num_tasks
actor_kwargs["task_embedding_dim"] = args.task_embedding_dim
critic_kwargs["num_tasks"] = envs.num_tasks
critic_kwargs["task_embedding_dim"] = args.task_embedding_dim
if args.agent == "fasttd3":
if env_type in ["mtbench"]:
from fast_td3 import MultiTaskActor, MultiTaskCritic
actor_cls = MultiTaskActor
critic_cls = MultiTaskCritic
else:
from fast_td3 import Actor, Critic
actor_cls = Actor
critic_cls = Critic
if rank == 0:
print("Using FastTD3")
elif args.agent == "fasttd3_simbav2":
if env_type in ["mtbench"]:
from fast_td3_simbav2 import MultiTaskActor, MultiTaskCritic
actor_cls = MultiTaskActor
critic_cls = MultiTaskCritic
else:
from fast_td3_simbav2 import Actor, Critic
actor_cls = Actor
critic_cls = Critic
if rank == 0:
print("Using FastTD3 + SimbaV2")
actor_kwargs.pop("init_scale")
actor_kwargs.update(
{
"scaler_init": math.sqrt(2.0 / args.actor_hidden_dim),
"scaler_scale": math.sqrt(2.0 / args.actor_hidden_dim),
"alpha_init": 1.0 / (args.actor_num_blocks + 1),
"alpha_scale": 1.0 / math.sqrt(args.actor_hidden_dim),
"expansion": 4,
"c_shift": 3.0,
"num_blocks": args.actor_num_blocks,
}
)
critic_kwargs.update(
{
"scaler_init": math.sqrt(2.0 / args.critic_hidden_dim),
"scaler_scale": math.sqrt(2.0 / args.critic_hidden_dim),
"alpha_init": 1.0 / (args.critic_num_blocks + 1),
"alpha_scale": 1.0 / math.sqrt(args.critic_hidden_dim),
"num_blocks": args.critic_num_blocks,
"expansion": 4,
"c_shift": 3.0,
}
)
else:
raise ValueError(f"Agent {args.agent} not supported")
actor = actor_cls(**actor_kwargs)
if is_distributed:
actor = DDP(actor, device_ids=[rank])
if env_type in ["mtbench"]:
# Python 3.8 doesn't support 'from_module' in tensordict
policy = actor.module.explore if hasattr(actor, "module") else actor.explore
else:
from tensordict import from_module
actor_detach = actor_cls(**actor_kwargs)
# Copy params to actor_detach without grad
from_module(actor.module if hasattr(actor, "module") else actor).data.to_module(
actor_detach
)
policy = actor_detach.explore
qnet = critic_cls(**critic_kwargs)
if is_distributed:
qnet = DDP(qnet, device_ids=[rank])
qnet_target = critic_cls(**critic_kwargs) # Create a separate instance
qnet_target.load_state_dict(get_ddp_state_dict(qnet))
q_optimizer = optim.AdamW(
list(qnet.parameters()),
lr=torch.tensor(args.critic_learning_rate, device=device),
weight_decay=args.weight_decay,
)
actor_optimizer = optim.AdamW(
list(actor.parameters()),
lr=torch.tensor(args.actor_learning_rate, device=device),
weight_decay=args.weight_decay,
)
# Add learning rate schedulers
q_scheduler = optim.lr_scheduler.CosineAnnealingLR(
q_optimizer,
T_max=args.total_timesteps,
eta_min=torch.tensor(args.critic_learning_rate_end, device=device),
)
actor_scheduler = optim.lr_scheduler.CosineAnnealingLR(
actor_optimizer,
T_max=args.total_timesteps,
eta_min=torch.tensor(args.actor_learning_rate_end, device=device),
)
rb = SimpleReplayBuffer(
n_env=args.num_envs,
buffer_size=args.buffer_size,
n_obs=n_obs,
n_act=n_act,
n_critic_obs=n_critic_obs,
asymmetric_obs=envs.asymmetric_obs,
playground_mode=env_type == "mujoco_playground",
n_steps=args.num_steps,
gamma=args.gamma,
device=device,
)
policy_noise = args.policy_noise
noise_clip = args.noise_clip
def evaluate():
num_eval_envs = eval_envs.num_envs
episode_returns = torch.zeros(num_eval_envs, device=device)
episode_lengths = torch.zeros(num_eval_envs, device=device)
done_masks = torch.zeros(num_eval_envs, dtype=torch.bool, device=device)
if env_type == "isaaclab":
obs = eval_envs.reset(random_start_init=False)
else:
obs = eval_envs.reset()
# Run for a fixed number of steps
for i in range(eval_envs.max_episode_steps):
with torch.no_grad(), autocast(
device_type=amp_device_type, dtype=amp_dtype, enabled=amp_enabled
):
obs = normalize_obs(obs, update=False)
actions = actor(obs)
next_obs, rewards, dones, infos = eval_envs.step(actions.float())
if env_type == "mtbench":
# We only report success rate in MTBench evaluation
rewards = (
infos["episode"]["success"].float() if "episode" in infos else 0.0
)
episode_returns = torch.where(
~done_masks, episode_returns + rewards, episode_returns
)
episode_lengths = torch.where(
~done_masks, episode_lengths + 1, episode_lengths
)
if env_type == "mtbench" and "episode" in infos:
dones = dones | infos["episode"]["success"]
done_masks = torch.logical_or(done_masks, dones)
if done_masks.all():
break
obs = next_obs
return episode_returns.mean(), episode_lengths.mean()
def render_with_rollout():
# Quick rollout for rendering
if env_type == "humanoid_bench":
obs = render_env.reset()
renders = [render_env.render()]
elif env_type in ["isaaclab", "mtbench"]:
raise NotImplementedError(
"We don't support rendering for IsaacLab and MTBench environments"
)
else:
obs = render_env.reset()
render_env.state.info["command"] = jnp.array([[1.0, 0.0, 0.0]])
renders = [render_env.state]
for i in range(render_env.max_episode_steps):
with torch.no_grad(), autocast(
device_type=amp_device_type, dtype=amp_dtype, enabled=amp_enabled
):
obs = normalize_obs(obs, update=False)
actions = actor(obs)
next_obs, _, done, _ = render_env.step(actions.float())
if env_type == "mujoco_playground":
render_env.state.info["command"] = jnp.array([[1.0, 0.0, 0.0]])
if i % 2 == 0:
if env_type == "humanoid_bench":
renders.append(render_env.render())
else:
renders.append(render_env.state)
if done.any():
break
obs = next_obs
if env_type == "mujoco_playground":
renders = render_env.render_trajectory(renders)
return renders
def update_main(data, logs_dict):
with autocast(
device_type=amp_device_type, dtype=amp_dtype, enabled=amp_enabled
):
observations = data["observations"]
next_observations = data["next"]["observations"]
if envs.asymmetric_obs:
critic_observations = data["critic_observations"]
next_critic_observations = data["next"]["critic_observations"]
else:
critic_observations = observations
next_critic_observations = next_observations
actions = data["actions"]
rewards = data["next"]["rewards"]
dones = data["next"]["dones"].bool()
truncations = data["next"]["truncations"].bool()
if args.disable_bootstrap:
bootstrap = (~dones).float()
else:
bootstrap = (truncations | ~dones).float()
clipped_noise = torch.randn_like(actions)
clipped_noise = clipped_noise.mul(policy_noise).clamp(
-noise_clip, noise_clip
)
next_state_actions = (actor(next_observations) + clipped_noise).clamp(
action_low, action_high
)
discount = args.gamma ** data["next"]["effective_n_steps"]
with torch.no_grad():
qf1_next_target_projected, qf2_next_target_projected = (
qnet_target.projection(
next_critic_observations,
next_state_actions,
rewards,
bootstrap,
discount,
)
)
qf1_next_target_value = qnet_target.get_value(qf1_next_target_projected)
qf2_next_target_value = qnet_target.get_value(qf2_next_target_projected)
if args.use_cdq:
qf_next_target_dist = torch.where(
qf1_next_target_value.unsqueeze(1)
< qf2_next_target_value.unsqueeze(1),
qf1_next_target_projected,
qf2_next_target_projected,
)
qf1_next_target_dist = qf2_next_target_dist = qf_next_target_dist
else:
qf1_next_target_dist, qf2_next_target_dist = (
qf1_next_target_projected,
qf2_next_target_projected,
)
qf1, qf2 = qnet(critic_observations, actions)
qf1_loss = -torch.sum(
qf1_next_target_dist * F.log_softmax(qf1, dim=1), dim=1
).mean()
qf2_loss = -torch.sum(
qf2_next_target_dist * F.log_softmax(qf2, dim=1), dim=1
).mean()
qf_loss = qf1_loss + qf2_loss
q_optimizer.zero_grad(set_to_none=True)
scaler.scale(qf_loss).backward()
scaler.unscale_(q_optimizer)
if args.use_grad_norm_clipping:
critic_grad_norm = torch.nn.utils.clip_grad_norm_(
qnet.parameters(),
max_norm=args.max_grad_norm if args.max_grad_norm > 0 else float("inf"),
)
else:
critic_grad_norm = torch.tensor(0.0, device=device)
scaler.step(q_optimizer)
scaler.update()
logs_dict["critic_grad_norm"] = critic_grad_norm.detach()
logs_dict["qf_loss"] = qf_loss.detach()
logs_dict["qf_max"] = qf1_next_target_value.max().detach()
logs_dict["qf_min"] = qf1_next_target_value.min().detach()
return logs_dict
def update_pol(data, logs_dict):
with autocast(
device_type=amp_device_type, dtype=amp_dtype, enabled=amp_enabled
):
critic_observations = (
data["critic_observations"]
if envs.asymmetric_obs
else data["observations"]
)
qf1, qf2 = qnet(critic_observations, actor(data["observations"]))
qf1_value = (
qnet.module.get_value(F.softmax(qf1, dim=1))
if hasattr(qnet, "module")
else qnet.get_value(F.softmax(qf1, dim=1))
)
qf2_value = (
qnet.module.get_value(F.softmax(qf2, dim=1))
if hasattr(qnet, "module")
else qnet.get_value(F.softmax(qf2, dim=1))
)
if args.use_cdq:
qf_value = torch.minimum(qf1_value, qf2_value)
else:
qf_value = (qf1_value + qf2_value) / 2.0
actor_loss = -qf_value.mean()
actor_optimizer.zero_grad(set_to_none=True)
scaler.scale(actor_loss).backward()
scaler.unscale_(actor_optimizer)
if args.use_grad_norm_clipping:
actor_grad_norm = torch.nn.utils.clip_grad_norm_(
actor.parameters(),
max_norm=args.max_grad_norm if args.max_grad_norm > 0 else float("inf"),
)
else:
actor_grad_norm = torch.tensor(0.0, device=device)
scaler.step(actor_optimizer)
scaler.update()
logs_dict["actor_grad_norm"] = actor_grad_norm.detach()
logs_dict["actor_loss"] = actor_loss.detach()
return logs_dict
@torch.no_grad()
def soft_update(src, tgt, tau: float):
# Handle DDP module by accessing .module attribute
src_module = src.module if hasattr(src, "module") else src
tgt_module = tgt.module if hasattr(tgt, "module") else tgt
src_ps = [p.data for p in src_module.parameters()]
tgt_ps = [p.data for p in tgt_module.parameters()]
torch._foreach_mul_(tgt_ps, 1.0 - tau)
torch._foreach_add_(tgt_ps, src_ps, alpha=tau)
if args.compile:
compile_mode = args.compile_mode
update_main = torch.compile(update_main, mode=compile_mode)
update_pol = torch.compile(update_pol, mode=compile_mode)
policy = torch.compile(policy, mode=None)
normalize_obs = torch.compile(obs_normalizer.forward, mode=None)
normalize_critic_obs = torch.compile(critic_obs_normalizer.forward, mode=None)
if args.reward_normalization:
update_stats = torch.compile(reward_normalizer.update_stats, mode=None)
normalize_reward = torch.compile(reward_normalizer.forward, mode=None)
else:
normalize_obs = obs_normalizer.forward
normalize_critic_obs = critic_obs_normalizer.forward
if args.reward_normalization:
update_stats = reward_normalizer.update_stats
normalize_reward = reward_normalizer.forward
if envs.asymmetric_obs:
obs, critic_obs = envs.reset_with_critic_obs()
critic_obs = torch.as_tensor(critic_obs, device=device, dtype=torch.float)
else:
obs = envs.reset()
if args.checkpoint_path:
# Load checkpoint if specified
torch_checkpoint = torch.load(
f"{args.checkpoint_path}", map_location=device, weights_only=False
)
load_ddp_state_dict(actor, torch_checkpoint["actor_state_dict"])
if torch_checkpoint["obs_normalizer_state"] is not None:
obs_normalizer.load_state_dict(torch_checkpoint["obs_normalizer_state"])
if torch_checkpoint["critic_obs_normalizer_state"] is not None:
critic_obs_normalizer.load_state_dict(
torch_checkpoint["critic_obs_normalizer_state"]
)
load_ddp_state_dict(qnet, torch_checkpoint["qnet_state_dict"])
qnet_target.load_state_dict(torch_checkpoint["qnet_target_state_dict"])
global_step = torch_checkpoint["global_step"]
else:
global_step = 0
dones = None
pbar = tqdm.tqdm(total=args.total_timesteps, initial=global_step)
start_time = None
desc = ""
while global_step < args.total_timesteps:
mark_step()
logs_dict = TensorDict()
if (
start_time is None
and global_step >= args.measure_burnin + args.learning_starts
):
start_time = time.time()
measure_burnin = global_step
with torch.no_grad(), autocast(
device_type=amp_device_type, dtype=amp_dtype, enabled=amp_enabled
):
norm_obs = normalize_obs(obs)
actions = policy(obs=norm_obs, dones=dones)
next_obs, rewards, dones, infos = envs.step(actions.float())
truncations = infos["time_outs"]
if args.reward_normalization:
if env_type == "mtbench":
task_ids_one_hot = obs[..., -envs.num_tasks :]
task_indices = torch.argmax(task_ids_one_hot, dim=1)
update_stats(rewards, dones.float(), task_ids=task_indices)
else:
update_stats(rewards, dones.float())
if envs.asymmetric_obs:
next_critic_obs = infos["observations"]["critic"]
# Compute 'true' next_obs and next_critic_obs for saving
true_next_obs = torch.where(
dones[:, None] > 0, infos["observations"]["raw"]["obs"], next_obs
)
if envs.asymmetric_obs:
true_next_critic_obs = torch.where(
dones[:, None] > 0,
infos["observations"]["raw"]["critic_obs"],
next_critic_obs,
)
transition = TensorDict(
{
"observations": obs,
"actions": torch.as_tensor(actions, device=device, dtype=torch.float),
"next": {
"observations": true_next_obs,
"rewards": torch.as_tensor(
rewards, device=device, dtype=torch.float
),
"truncations": truncations.long(),
"dones": dones.long(),
},
},
batch_size=(envs.num_envs,),
device=device,
)
if envs.asymmetric_obs:
transition["critic_observations"] = critic_obs
transition["next"]["critic_observations"] = true_next_critic_obs
rb.extend(transition)
obs = next_obs
if envs.asymmetric_obs:
critic_obs = next_critic_obs
if global_step > args.learning_starts:
for i in range(args.num_updates):
data = rb.sample(max(1, args.batch_size // args.num_envs))
data["observations"] = normalize_obs(data["observations"])
data["next"]["observations"] = normalize_obs(
data["next"]["observations"]
)
if envs.asymmetric_obs:
data["critic_observations"] = normalize_critic_obs(
data["critic_observations"]
)
data["next"]["critic_observations"] = normalize_critic_obs(
data["next"]["critic_observations"]
)
raw_rewards = data["next"]["rewards"]
if env_type in ["mtbench"] and args.reward_normalization:
# Multi-task reward normalization
task_ids_one_hot = data["observations"][..., -envs.num_tasks :]
task_indices = torch.argmax(task_ids_one_hot, dim=1)
data["next"]["rewards"] = normalize_reward(
raw_rewards, task_ids=task_indices
)
else:
data["next"]["rewards"] = normalize_reward(raw_rewards)
logs_dict = update_main(data, logs_dict)
if args.num_updates > 1:
if i % args.policy_frequency == 1:
logs_dict = update_pol(data, logs_dict)
else:
if global_step % args.policy_frequency == 0:
logs_dict = update_pol(data, logs_dict)
soft_update(qnet, qnet_target, args.tau)
if global_step % 100 == 0 and start_time is not None:
speed = (global_step - measure_burnin) / (time.time() - start_time)
if rank == 0:
pbar.set_description(f"{speed: 4.4f} sps, " + desc)
with torch.no_grad():
logs = {
"actor_loss": logs_dict["actor_loss"].mean(),
"qf_loss": logs_dict["qf_loss"].mean(),
"qf_max": logs_dict["qf_max"].mean(),
"qf_min": logs_dict["qf_min"].mean(),
"actor_grad_norm": logs_dict["actor_grad_norm"].mean(),
"critic_grad_norm": logs_dict["critic_grad_norm"].mean(),
"env_rewards": rewards.mean(),
"buffer_rewards": raw_rewards.mean(),
}
if args.eval_interval > 0 and global_step % args.eval_interval == 0:
local_eval_avg_return, local_eval_avg_length = evaluate()
eval_results = torch.tensor(
[local_eval_avg_return, local_eval_avg_length],
device=device,
)
if is_distributed:
torch.distributed.all_reduce(
eval_results, op=torch.distributed.ReduceOp.AVG
)
if rank == 0:
global_avg_return = eval_results[0].item()
global_avg_length = eval_results[1].item()
print(
f"Evaluating at global step {global_step}: Avg Return={global_avg_return:.2f}"
)
logs["eval_avg_return"] = global_avg_return
logs["eval_avg_length"] = global_avg_length
if env_type in ["humanoid_bench", "isaaclab", "mtbench"]:
# NOTE: Hacky way of evaluating performance, but just works
obs = envs.reset()
if (
args.render_interval > 0
and global_step % args.render_interval == 0
):
renders = render_with_rollout()
render_video = wandb.Video(
np.array(renders).transpose(
0, 3, 1, 2
), # Convert to (T, C, H, W) format
fps=30,
format="gif",
)
logs["render_video"] = render_video
if args.use_wandb and rank == 0:
wandb.log(
{
"speed": speed,
"frame": global_step * args.num_envs,
"critic_lr": q_scheduler.get_last_lr()[0],
"actor_lr": actor_scheduler.get_last_lr()[0],
**logs,
},
step=global_step,
)
if (
args.save_interval > 0
and global_step > 0
and global_step % args.save_interval == 0
and rank == 0
):
print(f"Saving model at global step {global_step}")
save_params(
global_step,
actor,
qnet,
qnet_target,
obs_normalizer,
critic_obs_normalizer,
args,
f"models/{run_name}_{global_step}.pt",
)
global_step += 1
actor_scheduler.step()
q_scheduler.step()
if rank == 0:
pbar.update(1)
save_params(
global_step,
actor,
qnet,
qnet_target,
obs_normalizer,
critic_obs_normalizer,
args,
f"models/{run_name}_final.pt",
)
# Cleanup distributed training
if is_distributed:
torch.distributed.destroy_process_group()
if __name__ == "__main__":
world_size = torch.cuda.device_count()
mp.spawn(main, args=(world_size,), nprocs=world_size)