Optimized codebase to speed up training (#20)
- Modified codes to be compatible with torch.compile - Modified empirical normalizer to use in-place operator to avoid costly __setattr__ - Parallel soft Q-update - As a default option, we disabled gradient norm clipping as it is quite expensive
This commit is contained in:
parent
799624b202
commit
c354ead107
@ -10,6 +10,9 @@ For more information, please see our [project webpage](https://younggyo.me/fast_
|
||||
|
||||
|
||||
## ❗ Updates
|
||||
|
||||
- **[Jul/02/2025]** Optimized codebase to speed up training around 10-30% when using a single RTX 4090 GPU.
|
||||
|
||||
- **[Jun/20/2025]** Added support for [MTBench](https://github.com/Viraj-Joshi/MTBench) with the help of [Viraj Joshi](https://viraj-joshi.github.io/).
|
||||
|
||||
- **[Jun/15/2025]** Added support for FastTD3 + [SimbaV2](https://dojeon-ai.github.io/SimbaV2/)! It's faster to train, and often achieves better asymptotic performance. We recommend using FastTD3 + SimbaV2 for most cases.
|
||||
@ -237,6 +240,7 @@ We used a single Nvidia A100 80GB GPU for all experiments. Here are some remarks
|
||||
- If the agent is completely stuck or much worse than your expectation, try using `num_steps=3` or disabling `use_cdq`.
|
||||
- For tasks that have penalty reward terms (e.g., torques, energy, action_rate, ..), consider lowering them for initial experiments, and tune the values. In some cases, curriculum learning with lower penalty terms followed by fine-tuning with stronger terms is effective.
|
||||
- When you encounter out-of-memory error with your GPU, our recommendation for reducing GPU usage is (i) smaller `buffer_size`, (ii) smaller `batch_size`, and then (iii) smaller `num_envs`. Because our codebase is assigning the whole replay buffer in GPU to reduce CPU-GPU transfer bottleneck, it usually has the largest GPU consumption, but usually less harmful to reduce.
|
||||
- Consider using `--compile_mode max-autotune` if you plan to run for many training steps. This may speed up training by up to 10% at the cost of a few additional minutes of heavy compilation.
|
||||
|
||||
## 🛝 Playing with the FastTD3 training
|
||||
|
||||
|
@ -212,7 +212,9 @@ class Actor(nn.Module):
|
||||
|
||||
# Update only the noise scales for environments that are done
|
||||
dones_view = dones.view(-1, 1) > 0
|
||||
self.noise_scales = torch.where(dones_view, new_scales, self.noise_scales)
|
||||
self.noise_scales.copy_(
|
||||
torch.where(dones_view, new_scales, self.noise_scales)
|
||||
)
|
||||
|
||||
act = self(obs)
|
||||
if deterministic:
|
||||
|
@ -486,7 +486,9 @@ class Actor(nn.Module):
|
||||
|
||||
# Update only the noise scales for environments that are done
|
||||
dones_view = dones.view(-1, 1) > 0
|
||||
self.noise_scales = torch.where(dones_view, new_scales, self.noise_scales)
|
||||
self.noise_scales.copy_(
|
||||
torch.where(dones_view, new_scales, self.noise_scales)
|
||||
)
|
||||
|
||||
act = self(obs)
|
||||
if deterministic:
|
||||
|
@ -85,6 +85,7 @@ class SimpleReplayBuffer(nn.Module):
|
||||
)
|
||||
self.ptr = 0
|
||||
|
||||
@torch.no_grad()
|
||||
def extend(
|
||||
self,
|
||||
tensor_dict: TensorDict,
|
||||
@ -119,6 +120,7 @@ class SimpleReplayBuffer(nn.Module):
|
||||
self.next_critic_observations[:, ptr] = next_critic_observations
|
||||
self.ptr += 1
|
||||
|
||||
@torch.no_grad()
|
||||
def sample(self, batch_size: int):
|
||||
# we will sample n_env * batch_size transitions
|
||||
|
||||
@ -425,6 +427,7 @@ class EmpiricalNormalization(nn.Module):
|
||||
def std(self):
|
||||
return self._std.squeeze(0).clone()
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, x: torch.Tensor, center: bool = True) -> torch.Tensor:
|
||||
if x.shape[1:] != self._mean.shape[1:]:
|
||||
raise ValueError(
|
||||
@ -453,21 +456,19 @@ class EmpiricalNormalization(nn.Module):
|
||||
delta = batch_mean - self._mean
|
||||
self._mean += (batch_size / new_count) * delta
|
||||
|
||||
# Update variance using Chan's parallel algorithm
|
||||
# https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
|
||||
if self.count > 0: # Ensure we're not dividing by zero
|
||||
batch_var = torch.mean((x - batch_mean) ** 2, dim=0, keepdim=True)
|
||||
delta2 = batch_mean - self._mean
|
||||
m_a = self._var * self.count
|
||||
m_b = batch_var * batch_size
|
||||
M2 = m_a + m_b + (delta2**2) * (self.count * batch_size / new_count)
|
||||
self._var = M2 / new_count
|
||||
else:
|
||||
# For first batch, just use batch variance
|
||||
self._var = torch.mean((x - self._mean) ** 2, dim=0, keepdim=True)
|
||||
# Compute batch variance
|
||||
batch_var = torch.mean((x - batch_mean) ** 2, dim=0, keepdim=True)
|
||||
delta2 = batch_mean - self._mean # uses updated mean
|
||||
|
||||
self._std = torch.sqrt(self._var)
|
||||
self.count = new_count
|
||||
# Parallel variance update (works even when previous count == 0)
|
||||
m_a = self._var * self.count # previous aggregated M2
|
||||
m_b = batch_var * batch_size
|
||||
M2 = m_a + m_b + delta2.pow(2) * (self.count * batch_size / new_count)
|
||||
self._var.copy_(M2 / new_count)
|
||||
|
||||
# Update std and count in-place to avoid expensive __setattr__
|
||||
self._std.copy_(self._var.sqrt())
|
||||
self.count.copy_(new_count)
|
||||
|
||||
@torch.jit.unused
|
||||
def inverse(self, y):
|
||||
|
@ -92,10 +92,14 @@ class BaseArgs:
|
||||
"""the interval to render the model"""
|
||||
compile: bool = True
|
||||
"""whether to use torch.compile."""
|
||||
compile_mode: str = "reduce-overhead"
|
||||
"""the mode of torch.compile."""
|
||||
obs_normalization: bool = True
|
||||
"""whether to enable observation normalization"""
|
||||
reward_normalization: bool = False
|
||||
"""whether to enable reward normalization"""
|
||||
use_grad_norm_clipping: bool = False
|
||||
"""whether to use gradient norm clipping."""
|
||||
max_grad_norm: float = 0.0
|
||||
"""the maximum gradient norm"""
|
||||
amp: bool = True
|
||||
|
@ -49,6 +49,12 @@ except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def mark_step():
|
||||
# call this once per iteration *before* any compiled function
|
||||
torch.compiler.cudagraph_mark_step_begin()
|
||||
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
print(args)
|
||||
@ -458,10 +464,13 @@ def main():
|
||||
scaler.scale(qf_loss).backward()
|
||||
scaler.unscale_(q_optimizer)
|
||||
|
||||
critic_grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
qnet.parameters(),
|
||||
max_norm=args.max_grad_norm if args.max_grad_norm > 0 else float("inf"),
|
||||
)
|
||||
if args.use_grad_norm_clipping:
|
||||
critic_grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
qnet.parameters(),
|
||||
max_norm=args.max_grad_norm if args.max_grad_norm > 0 else float("inf"),
|
||||
)
|
||||
else:
|
||||
critic_grad_norm = torch.tensor(0.0, device=device)
|
||||
scaler.step(q_optimizer)
|
||||
scaler.update()
|
||||
q_scheduler.step()
|
||||
@ -494,10 +503,13 @@ def main():
|
||||
actor_optimizer.zero_grad(set_to_none=True)
|
||||
scaler.scale(actor_loss).backward()
|
||||
scaler.unscale_(actor_optimizer)
|
||||
actor_grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
actor.parameters(),
|
||||
max_norm=args.max_grad_norm if args.max_grad_norm > 0 else float("inf"),
|
||||
)
|
||||
if args.use_grad_norm_clipping:
|
||||
actor_grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
actor.parameters(),
|
||||
max_norm=args.max_grad_norm if args.max_grad_norm > 0 else float("inf"),
|
||||
)
|
||||
else:
|
||||
actor_grad_norm = torch.tensor(0.0, device=device)
|
||||
scaler.step(actor_optimizer)
|
||||
scaler.update()
|
||||
actor_scheduler.step()
|
||||
@ -505,16 +517,28 @@ def main():
|
||||
logs_dict["actor_loss"] = actor_loss.detach()
|
||||
return logs_dict
|
||||
|
||||
@torch.no_grad()
|
||||
def soft_update(src, tgt, tau: float):
|
||||
src_ps = [p.data for p in src.parameters()]
|
||||
tgt_ps = [p.data for p in tgt.parameters()]
|
||||
|
||||
torch._foreach_mul_(tgt_ps, 1.0 - tau)
|
||||
torch._foreach_add_(tgt_ps, src_ps, alpha=tau)
|
||||
|
||||
if args.compile:
|
||||
mode = None
|
||||
update_main = torch.compile(update_main, mode=mode)
|
||||
update_pol = torch.compile(update_pol, mode=mode)
|
||||
policy = torch.compile(policy, mode=mode)
|
||||
normalize_obs = torch.compile(obs_normalizer.forward, mode=mode)
|
||||
normalize_critic_obs = torch.compile(critic_obs_normalizer.forward, mode=mode)
|
||||
compile_mode = args.compile_mode
|
||||
update_main = torch.compile(update_main, mode=compile_mode)
|
||||
update_pol = torch.compile(update_pol, mode=compile_mode)
|
||||
policy = torch.compile(policy, mode=compile_mode)
|
||||
normalize_obs = torch.compile(obs_normalizer.forward, mode=compile_mode)
|
||||
normalize_critic_obs = torch.compile(
|
||||
critic_obs_normalizer.forward, mode=compile_mode
|
||||
)
|
||||
if args.reward_normalization:
|
||||
update_stats = torch.compile(reward_normalizer.update_stats, mode=mode)
|
||||
normalize_reward = torch.compile(reward_normalizer.forward, mode=mode)
|
||||
update_stats = torch.compile(
|
||||
reward_normalizer.update_stats, mode=compile_mode
|
||||
)
|
||||
normalize_reward = torch.compile(reward_normalizer.forward, mode=compile_mode)
|
||||
else:
|
||||
normalize_obs = obs_normalizer.forward
|
||||
normalize_critic_obs = critic_obs_normalizer.forward
|
||||
@ -549,6 +573,7 @@ def main():
|
||||
desc = ""
|
||||
|
||||
while global_step < args.total_timesteps:
|
||||
mark_step()
|
||||
logs_dict = TensorDict()
|
||||
if (
|
||||
start_time is None
|
||||
@ -576,7 +601,6 @@ def main():
|
||||
|
||||
if envs.asymmetric_obs:
|
||||
next_critic_obs = infos["observations"]["critic"]
|
||||
|
||||
# Compute 'true' next_obs and next_critic_obs for saving
|
||||
true_next_obs = torch.where(
|
||||
dones[:, None] > 0, infos["observations"]["raw"]["obs"], next_obs
|
||||
@ -587,6 +611,7 @@ def main():
|
||||
infos["observations"]["raw"]["critic_obs"],
|
||||
next_critic_obs,
|
||||
)
|
||||
|
||||
transition = TensorDict(
|
||||
{
|
||||
"observations": obs,
|
||||
@ -606,13 +631,12 @@ def main():
|
||||
if envs.asymmetric_obs:
|
||||
transition["critic_observations"] = critic_obs
|
||||
transition["next"]["critic_observations"] = true_next_critic_obs
|
||||
rb.extend(transition)
|
||||
|
||||
obs = next_obs
|
||||
if envs.asymmetric_obs:
|
||||
critic_obs = next_critic_obs
|
||||
|
||||
rb.extend(transition)
|
||||
|
||||
batch_size = args.batch_size // args.num_envs
|
||||
if global_step > args.learning_starts:
|
||||
for i in range(args.num_updates):
|
||||
@ -621,6 +645,13 @@ def main():
|
||||
data["next"]["observations"] = normalize_obs(
|
||||
data["next"]["observations"]
|
||||
)
|
||||
if envs.asymmetric_obs:
|
||||
data["critic_observations"] = normalize_critic_obs(
|
||||
data["critic_observations"]
|
||||
)
|
||||
data["next"]["critic_observations"] = normalize_critic_obs(
|
||||
data["next"]["critic_observations"]
|
||||
)
|
||||
raw_rewards = data["next"]["rewards"]
|
||||
if env_type in ["mtbench"] and args.reward_normalization:
|
||||
# Multi-task reward normalization
|
||||
@ -631,13 +662,7 @@ def main():
|
||||
)
|
||||
else:
|
||||
data["next"]["rewards"] = normalize_reward(raw_rewards)
|
||||
if envs.asymmetric_obs:
|
||||
data["critic_observations"] = normalize_critic_obs(
|
||||
data["critic_observations"]
|
||||
)
|
||||
data["next"]["critic_observations"] = normalize_critic_obs(
|
||||
data["next"]["critic_observations"]
|
||||
)
|
||||
|
||||
logs_dict = update_main(data, logs_dict)
|
||||
if args.num_updates > 1:
|
||||
if i % args.policy_frequency == 1:
|
||||
@ -646,12 +671,7 @@ def main():
|
||||
if global_step % args.policy_frequency == 0:
|
||||
logs_dict = update_pol(data, logs_dict)
|
||||
|
||||
for param, target_param in zip(
|
||||
qnet.parameters(), qnet_target.parameters()
|
||||
):
|
||||
target_param.data.copy_(
|
||||
args.tau * param.data + (1 - args.tau) * target_param.data
|
||||
)
|
||||
soft_update(qnet, qnet_target, args.tau)
|
||||
|
||||
if global_step % 100 == 0 and start_time is not None:
|
||||
speed = (global_step - measure_burnin) / (time.time() - start_time)
|
||||
|
Loading…
Reference in New Issue
Block a user