Optimized codebase to speed up training (#20)

- Modified codes to be compatible with torch.compile
- Modified empirical normalizer to use in-place operator to avoid costly __setattr__
- Parallel soft Q-update
- As a default option, we disabled gradient norm clipping as it is quite expensive
This commit is contained in:
Younggyo Seo 2025-07-02 19:39:02 -07:00 committed by GitHub
parent 799624b202
commit c354ead107
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 81 additions and 48 deletions

View File

@ -10,6 +10,9 @@ For more information, please see our [project webpage](https://younggyo.me/fast_
## ❗ Updates
- **[Jul/02/2025]** Optimized codebase to speed up training around 10-30% when using a single RTX 4090 GPU.
- **[Jun/20/2025]** Added support for [MTBench](https://github.com/Viraj-Joshi/MTBench) with the help of [Viraj Joshi](https://viraj-joshi.github.io/).
- **[Jun/15/2025]** Added support for FastTD3 + [SimbaV2](https://dojeon-ai.github.io/SimbaV2/)! It's faster to train, and often achieves better asymptotic performance. We recommend using FastTD3 + SimbaV2 for most cases.
@ -237,6 +240,7 @@ We used a single Nvidia A100 80GB GPU for all experiments. Here are some remarks
- If the agent is completely stuck or much worse than your expectation, try using `num_steps=3` or disabling `use_cdq`.
- For tasks that have penalty reward terms (e.g., torques, energy, action_rate, ..), consider lowering them for initial experiments, and tune the values. In some cases, curriculum learning with lower penalty terms followed by fine-tuning with stronger terms is effective.
- When you encounter out-of-memory error with your GPU, our recommendation for reducing GPU usage is (i) smaller `buffer_size`, (ii) smaller `batch_size`, and then (iii) smaller `num_envs`. Because our codebase is assigning the whole replay buffer in GPU to reduce CPU-GPU transfer bottleneck, it usually has the largest GPU consumption, but usually less harmful to reduce.
- Consider using `--compile_mode max-autotune` if you plan to run for many training steps. This may speed up training by up to 10% at the cost of a few additional minutes of heavy compilation.
## 🛝 Playing with the FastTD3 training

View File

@ -212,7 +212,9 @@ class Actor(nn.Module):
# Update only the noise scales for environments that are done
dones_view = dones.view(-1, 1) > 0
self.noise_scales = torch.where(dones_view, new_scales, self.noise_scales)
self.noise_scales.copy_(
torch.where(dones_view, new_scales, self.noise_scales)
)
act = self(obs)
if deterministic:

View File

@ -486,7 +486,9 @@ class Actor(nn.Module):
# Update only the noise scales for environments that are done
dones_view = dones.view(-1, 1) > 0
self.noise_scales = torch.where(dones_view, new_scales, self.noise_scales)
self.noise_scales.copy_(
torch.where(dones_view, new_scales, self.noise_scales)
)
act = self(obs)
if deterministic:

View File

@ -85,6 +85,7 @@ class SimpleReplayBuffer(nn.Module):
)
self.ptr = 0
@torch.no_grad()
def extend(
self,
tensor_dict: TensorDict,
@ -119,6 +120,7 @@ class SimpleReplayBuffer(nn.Module):
self.next_critic_observations[:, ptr] = next_critic_observations
self.ptr += 1
@torch.no_grad()
def sample(self, batch_size: int):
# we will sample n_env * batch_size transitions
@ -425,6 +427,7 @@ class EmpiricalNormalization(nn.Module):
def std(self):
return self._std.squeeze(0).clone()
@torch.no_grad()
def forward(self, x: torch.Tensor, center: bool = True) -> torch.Tensor:
if x.shape[1:] != self._mean.shape[1:]:
raise ValueError(
@ -453,21 +456,19 @@ class EmpiricalNormalization(nn.Module):
delta = batch_mean - self._mean
self._mean += (batch_size / new_count) * delta
# Update variance using Chan's parallel algorithm
# https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
if self.count > 0: # Ensure we're not dividing by zero
batch_var = torch.mean((x - batch_mean) ** 2, dim=0, keepdim=True)
delta2 = batch_mean - self._mean
m_a = self._var * self.count
m_b = batch_var * batch_size
M2 = m_a + m_b + (delta2**2) * (self.count * batch_size / new_count)
self._var = M2 / new_count
else:
# For first batch, just use batch variance
self._var = torch.mean((x - self._mean) ** 2, dim=0, keepdim=True)
# Compute batch variance
batch_var = torch.mean((x - batch_mean) ** 2, dim=0, keepdim=True)
delta2 = batch_mean - self._mean # uses updated mean
self._std = torch.sqrt(self._var)
self.count = new_count
# Parallel variance update (works even when previous count == 0)
m_a = self._var * self.count # previous aggregated M2
m_b = batch_var * batch_size
M2 = m_a + m_b + delta2.pow(2) * (self.count * batch_size / new_count)
self._var.copy_(M2 / new_count)
# Update std and count in-place to avoid expensive __setattr__
self._std.copy_(self._var.sqrt())
self.count.copy_(new_count)
@torch.jit.unused
def inverse(self, y):

View File

@ -92,10 +92,14 @@ class BaseArgs:
"""the interval to render the model"""
compile: bool = True
"""whether to use torch.compile."""
compile_mode: str = "reduce-overhead"
"""the mode of torch.compile."""
obs_normalization: bool = True
"""whether to enable observation normalization"""
reward_normalization: bool = False
"""whether to enable reward normalization"""
use_grad_norm_clipping: bool = False
"""whether to use gradient norm clipping."""
max_grad_norm: float = 0.0
"""the maximum gradient norm"""
amp: bool = True

View File

@ -49,6 +49,12 @@ except ImportError:
pass
@torch.no_grad()
def mark_step():
# call this once per iteration *before* any compiled function
torch.compiler.cudagraph_mark_step_begin()
def main():
args = get_args()
print(args)
@ -458,10 +464,13 @@ def main():
scaler.scale(qf_loss).backward()
scaler.unscale_(q_optimizer)
critic_grad_norm = torch.nn.utils.clip_grad_norm_(
qnet.parameters(),
max_norm=args.max_grad_norm if args.max_grad_norm > 0 else float("inf"),
)
if args.use_grad_norm_clipping:
critic_grad_norm = torch.nn.utils.clip_grad_norm_(
qnet.parameters(),
max_norm=args.max_grad_norm if args.max_grad_norm > 0 else float("inf"),
)
else:
critic_grad_norm = torch.tensor(0.0, device=device)
scaler.step(q_optimizer)
scaler.update()
q_scheduler.step()
@ -494,10 +503,13 @@ def main():
actor_optimizer.zero_grad(set_to_none=True)
scaler.scale(actor_loss).backward()
scaler.unscale_(actor_optimizer)
actor_grad_norm = torch.nn.utils.clip_grad_norm_(
actor.parameters(),
max_norm=args.max_grad_norm if args.max_grad_norm > 0 else float("inf"),
)
if args.use_grad_norm_clipping:
actor_grad_norm = torch.nn.utils.clip_grad_norm_(
actor.parameters(),
max_norm=args.max_grad_norm if args.max_grad_norm > 0 else float("inf"),
)
else:
actor_grad_norm = torch.tensor(0.0, device=device)
scaler.step(actor_optimizer)
scaler.update()
actor_scheduler.step()
@ -505,16 +517,28 @@ def main():
logs_dict["actor_loss"] = actor_loss.detach()
return logs_dict
@torch.no_grad()
def soft_update(src, tgt, tau: float):
src_ps = [p.data for p in src.parameters()]
tgt_ps = [p.data for p in tgt.parameters()]
torch._foreach_mul_(tgt_ps, 1.0 - tau)
torch._foreach_add_(tgt_ps, src_ps, alpha=tau)
if args.compile:
mode = None
update_main = torch.compile(update_main, mode=mode)
update_pol = torch.compile(update_pol, mode=mode)
policy = torch.compile(policy, mode=mode)
normalize_obs = torch.compile(obs_normalizer.forward, mode=mode)
normalize_critic_obs = torch.compile(critic_obs_normalizer.forward, mode=mode)
compile_mode = args.compile_mode
update_main = torch.compile(update_main, mode=compile_mode)
update_pol = torch.compile(update_pol, mode=compile_mode)
policy = torch.compile(policy, mode=compile_mode)
normalize_obs = torch.compile(obs_normalizer.forward, mode=compile_mode)
normalize_critic_obs = torch.compile(
critic_obs_normalizer.forward, mode=compile_mode
)
if args.reward_normalization:
update_stats = torch.compile(reward_normalizer.update_stats, mode=mode)
normalize_reward = torch.compile(reward_normalizer.forward, mode=mode)
update_stats = torch.compile(
reward_normalizer.update_stats, mode=compile_mode
)
normalize_reward = torch.compile(reward_normalizer.forward, mode=compile_mode)
else:
normalize_obs = obs_normalizer.forward
normalize_critic_obs = critic_obs_normalizer.forward
@ -549,6 +573,7 @@ def main():
desc = ""
while global_step < args.total_timesteps:
mark_step()
logs_dict = TensorDict()
if (
start_time is None
@ -576,7 +601,6 @@ def main():
if envs.asymmetric_obs:
next_critic_obs = infos["observations"]["critic"]
# Compute 'true' next_obs and next_critic_obs for saving
true_next_obs = torch.where(
dones[:, None] > 0, infos["observations"]["raw"]["obs"], next_obs
@ -587,6 +611,7 @@ def main():
infos["observations"]["raw"]["critic_obs"],
next_critic_obs,
)
transition = TensorDict(
{
"observations": obs,
@ -606,13 +631,12 @@ def main():
if envs.asymmetric_obs:
transition["critic_observations"] = critic_obs
transition["next"]["critic_observations"] = true_next_critic_obs
rb.extend(transition)
obs = next_obs
if envs.asymmetric_obs:
critic_obs = next_critic_obs
rb.extend(transition)
batch_size = args.batch_size // args.num_envs
if global_step > args.learning_starts:
for i in range(args.num_updates):
@ -621,6 +645,13 @@ def main():
data["next"]["observations"] = normalize_obs(
data["next"]["observations"]
)
if envs.asymmetric_obs:
data["critic_observations"] = normalize_critic_obs(
data["critic_observations"]
)
data["next"]["critic_observations"] = normalize_critic_obs(
data["next"]["critic_observations"]
)
raw_rewards = data["next"]["rewards"]
if env_type in ["mtbench"] and args.reward_normalization:
# Multi-task reward normalization
@ -631,13 +662,7 @@ def main():
)
else:
data["next"]["rewards"] = normalize_reward(raw_rewards)
if envs.asymmetric_obs:
data["critic_observations"] = normalize_critic_obs(
data["critic_observations"]
)
data["next"]["critic_observations"] = normalize_critic_obs(
data["next"]["critic_observations"]
)
logs_dict = update_main(data, logs_dict)
if args.num_updates > 1:
if i % args.policy_frequency == 1:
@ -646,12 +671,7 @@ def main():
if global_step % args.policy_frequency == 0:
logs_dict = update_pol(data, logs_dict)
for param, target_param in zip(
qnet.parameters(), qnet_target.parameters()
):
target_param.data.copy_(
args.tau * param.data + (1 - args.tau) * target_param.data
)
soft_update(qnet, qnet_target, args.tau)
if global_step % 100 == 0 and start_time is not None:
speed = (global_step - measure_burnin) / (time.time() - start_time)