diff --git a/README.md b/README.md index f5993b8..054b848 100644 --- a/README.md +++ b/README.md @@ -10,6 +10,9 @@ For more information, please see our [project webpage](https://younggyo.me/fast_ ## ❗ Updates + +- **[Jul/02/2025]** Optimized codebase to speed up training by around 10-30% when using a single RTX 4090 GPU. + - **[Jun/20/2025]** Added support for [MTBench](https://github.com/Viraj-Joshi/MTBench) with the help of [Viraj Joshi](https://viraj-joshi.github.io/). - **[Jun/15/2025]** Added support for FastTD3 + [SimbaV2](https://dojeon-ai.github.io/SimbaV2/)! It's faster to train, and often achieves better asymptotic performance. We recommend using FastTD3 + SimbaV2 for most cases. @@ -237,6 +240,7 @@ We used a single Nvidia A100 80GB GPU for all experiments. Here are some remarks - If the agent is completely stuck or much worse than your expectation, try using `num_steps=3` or disabling `use_cdq`. - For tasks that have penalty reward terms (e.g., torques, energy, action_rate, ..), consider lowering them for initial experiments, and tune the values. In some cases, curriculum learning with lower penalty terms followed by fine-tuning with stronger terms is effective. - When you encounter out-of-memory error with your GPU, our recommendation for reducing GPU usage is (i) smaller `buffer_size`, (ii) smaller `batch_size`, and then (iii) smaller `num_envs`. Because our codebase is assigning the whole replay buffer in GPU to reduce CPU-GPU transfer bottleneck, it usually has the largest GPU consumption, but usually less harmful to reduce. +- Consider using `--compile_mode max-autotune` if you plan to run for many training steps. This may speed up training by up to 10% at the cost of a few additional minutes of heavy compilation. 
## 🛝 Playing with the FastTD3 training diff --git a/fast_td3/fast_td3.py b/fast_td3/fast_td3.py index b7d4ff1..fcd6c37 100644 --- a/fast_td3/fast_td3.py +++ b/fast_td3/fast_td3.py @@ -212,7 +212,9 @@ class Actor(nn.Module): # Update only the noise scales for environments that are done dones_view = dones.view(-1, 1) > 0 - self.noise_scales = torch.where(dones_view, new_scales, self.noise_scales) + self.noise_scales.copy_( + torch.where(dones_view, new_scales, self.noise_scales) + ) act = self(obs) if deterministic: diff --git a/fast_td3/fast_td3_simbav2.py b/fast_td3/fast_td3_simbav2.py index f0b17eb..80ae4ff 100644 --- a/fast_td3/fast_td3_simbav2.py +++ b/fast_td3/fast_td3_simbav2.py @@ -486,7 +486,9 @@ class Actor(nn.Module): # Update only the noise scales for environments that are done dones_view = dones.view(-1, 1) > 0 - self.noise_scales = torch.where(dones_view, new_scales, self.noise_scales) + self.noise_scales.copy_( + torch.where(dones_view, new_scales, self.noise_scales) + ) act = self(obs) if deterministic: diff --git a/fast_td3/fast_td3_utils.py b/fast_td3/fast_td3_utils.py index 336ba1c..f16a4c0 100644 --- a/fast_td3/fast_td3_utils.py +++ b/fast_td3/fast_td3_utils.py @@ -85,6 +85,7 @@ class SimpleReplayBuffer(nn.Module): ) self.ptr = 0 + @torch.no_grad() def extend( self, tensor_dict: TensorDict, @@ -119,6 +120,7 @@ class SimpleReplayBuffer(nn.Module): self.next_critic_observations[:, ptr] = next_critic_observations self.ptr += 1 + @torch.no_grad() def sample(self, batch_size: int): # we will sample n_env * batch_size transitions @@ -425,6 +427,7 @@ class EmpiricalNormalization(nn.Module): def std(self): return self._std.squeeze(0).clone() + @torch.no_grad() def forward(self, x: torch.Tensor, center: bool = True) -> torch.Tensor: if x.shape[1:] != self._mean.shape[1:]: raise ValueError( @@ -453,21 +456,19 @@ class EmpiricalNormalization(nn.Module): delta = batch_mean - self._mean self._mean += (batch_size / new_count) * delta - # Update variance using 
Chan's parallel algorithm - # https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm - if self.count > 0: # Ensure we're not dividing by zero - batch_var = torch.mean((x - batch_mean) ** 2, dim=0, keepdim=True) - delta2 = batch_mean - self._mean - m_a = self._var * self.count - m_b = batch_var * batch_size - M2 = m_a + m_b + (delta2**2) * (self.count * batch_size / new_count) - self._var = M2 / new_count - else: - # For first batch, just use batch variance - self._var = torch.mean((x - self._mean) ** 2, dim=0, keepdim=True) + # Compute batch variance + batch_var = torch.mean((x - batch_mean) ** 2, dim=0, keepdim=True) + # Chan's cross term needs the PRE-update mean difference, i.e. `delta` computed above (self._mean was already updated in place) - self._std = torch.sqrt(self._var) - self.count = new_count + # Parallel variance update (works even when previous count == 0) + m_a = self._var * self.count # previous aggregated M2 + m_b = batch_var * batch_size + M2 = m_a + m_b + delta.pow(2) * (self.count * batch_size / new_count) + self._var.copy_(M2 / new_count) + + # Update std and count in-place to avoid expensive __setattr__ + self._std.copy_(self._var.sqrt()) + self.count.copy_(new_count) @torch.jit.unused def inverse(self, y): diff --git a/fast_td3/hyperparams.py b/fast_td3/hyperparams.py index 04c2d82..8605322 100644 --- a/fast_td3/hyperparams.py +++ b/fast_td3/hyperparams.py @@ -92,10 +92,14 @@ class BaseArgs: """the interval to render the model""" compile: bool = True """whether to use torch.compile.""" + compile_mode: str = "reduce-overhead" + """the mode of torch.compile.""" obs_normalization: bool = True """whether to enable observation normalization""" reward_normalization: bool = False """whether to enable reward normalization""" + use_grad_norm_clipping: bool = False + """whether to use gradient norm clipping.""" max_grad_norm: float = 0.0 """the maximum gradient norm""" amp: bool = True diff --git a/fast_td3/train.py b/fast_td3/train.py index e20d057..ab4e689 100644 --- a/fast_td3/train.py +++ 
b/fast_td3/train.py @@ -49,6 +49,12 @@ except ImportError: pass +@torch.no_grad() +def mark_step(): + # call this once per iteration *before* any compiled function + torch.compiler.cudagraph_mark_step_begin() + + def main(): args = get_args() print(args) @@ -458,10 +464,13 @@ def main(): scaler.scale(qf_loss).backward() scaler.unscale_(q_optimizer) - critic_grad_norm = torch.nn.utils.clip_grad_norm_( - qnet.parameters(), - max_norm=args.max_grad_norm if args.max_grad_norm > 0 else float("inf"), - ) + if args.use_grad_norm_clipping or args.max_grad_norm > 0: # a positive max_grad_norm still clips, as before the flag existed + critic_grad_norm = torch.nn.utils.clip_grad_norm_( + qnet.parameters(), + max_norm=args.max_grad_norm if args.max_grad_norm > 0 else float("inf"), + ) + else: + critic_grad_norm = torch.tensor(0.0, device=device) scaler.step(q_optimizer) scaler.update() q_scheduler.step() @@ -494,10 +503,13 @@ def main(): actor_optimizer.zero_grad(set_to_none=True) scaler.scale(actor_loss).backward() scaler.unscale_(actor_optimizer) - actor_grad_norm = torch.nn.utils.clip_grad_norm_( - actor.parameters(), - max_norm=args.max_grad_norm if args.max_grad_norm > 0 else float("inf"), - ) + if args.use_grad_norm_clipping or args.max_grad_norm > 0: # a positive max_grad_norm still clips, as before the flag existed + actor_grad_norm = torch.nn.utils.clip_grad_norm_( + actor.parameters(), + max_norm=args.max_grad_norm if args.max_grad_norm > 0 else float("inf"), + ) + else: + actor_grad_norm = torch.tensor(0.0, device=device) scaler.step(actor_optimizer) scaler.update() actor_scheduler.step() @@ -505,16 +517,28 @@ def main(): logs_dict["actor_loss"] = actor_loss.detach() return logs_dict + @torch.no_grad() + def soft_update(src, tgt, tau: float): + src_ps = [p.data for p in src.parameters()] + tgt_ps = [p.data for p in tgt.parameters()] + + torch._foreach_mul_(tgt_ps, 1.0 - tau) + torch._foreach_add_(tgt_ps, src_ps, alpha=tau) + if args.compile: - mode = None - update_main = torch.compile(update_main, mode=mode) - update_pol = torch.compile(update_pol, mode=mode) - policy = torch.compile(policy, mode=mode) - normalize_obs = 
torch.compile(obs_normalizer.forward, mode=mode) - normalize_critic_obs = torch.compile(critic_obs_normalizer.forward, mode=mode) + compile_mode = args.compile_mode + update_main = torch.compile(update_main, mode=compile_mode) + update_pol = torch.compile(update_pol, mode=compile_mode) + policy = torch.compile(policy, mode=compile_mode) + normalize_obs = torch.compile(obs_normalizer.forward, mode=compile_mode) + normalize_critic_obs = torch.compile( + critic_obs_normalizer.forward, mode=compile_mode + ) if args.reward_normalization: - update_stats = torch.compile(reward_normalizer.update_stats, mode=mode) - normalize_reward = torch.compile(reward_normalizer.forward, mode=mode) + update_stats = torch.compile( + reward_normalizer.update_stats, mode=compile_mode + ) + normalize_reward = torch.compile(reward_normalizer.forward, mode=compile_mode) else: normalize_obs = obs_normalizer.forward normalize_critic_obs = critic_obs_normalizer.forward @@ -549,6 +573,7 @@ def main(): desc = "" while global_step < args.total_timesteps: + mark_step() logs_dict = TensorDict() if ( start_time is None @@ -576,7 +601,6 @@ def main(): if envs.asymmetric_obs: next_critic_obs = infos["observations"]["critic"] - # Compute 'true' next_obs and next_critic_obs for saving true_next_obs = torch.where( dones[:, None] > 0, infos["observations"]["raw"]["obs"], next_obs @@ -587,6 +611,7 @@ def main(): infos["observations"]["raw"]["critic_obs"], next_critic_obs, ) + transition = TensorDict( { "observations": obs, @@ -606,13 +631,12 @@ def main(): if envs.asymmetric_obs: transition["critic_observations"] = critic_obs transition["next"]["critic_observations"] = true_next_critic_obs + rb.extend(transition) obs = next_obs if envs.asymmetric_obs: critic_obs = next_critic_obs - rb.extend(transition) - batch_size = args.batch_size // args.num_envs if global_step > args.learning_starts: for i in range(args.num_updates): @@ -621,6 +645,13 @@ def main(): data["next"]["observations"] = normalize_obs( 
data["next"]["observations"] ) + if envs.asymmetric_obs: + data["critic_observations"] = normalize_critic_obs( + data["critic_observations"] + ) + data["next"]["critic_observations"] = normalize_critic_obs( + data["next"]["critic_observations"] + ) raw_rewards = data["next"]["rewards"] if env_type in ["mtbench"] and args.reward_normalization: # Multi-task reward normalization @@ -631,13 +662,7 @@ def main(): ) else: data["next"]["rewards"] = normalize_reward(raw_rewards) - if envs.asymmetric_obs: - data["critic_observations"] = normalize_critic_obs( - data["critic_observations"] - ) - data["next"]["critic_observations"] = normalize_critic_obs( - data["next"]["critic_observations"] - ) + logs_dict = update_main(data, logs_dict) if args.num_updates > 1: if i % args.policy_frequency == 1: @@ -646,12 +671,7 @@ def main(): if global_step % args.policy_frequency == 0: logs_dict = update_pol(data, logs_dict) - for param, target_param in zip( - qnet.parameters(), qnet_target.parameters() - ): - target_param.data.copy_( - args.tau * param.data + (1 - args.tau) * target_param.data - ) + soft_update(qnet, qnet_target, args.tau) if global_step % 100 == 0 and start_time is not None: speed = (global_step - measure_burnin) / (time.time() - start_time)