Optimized codebase to speed up training (#20)
- Modified the code to be compatible with torch.compile - Modified the empirical normalizer to use in-place operations to avoid costly __setattr__ calls - Parallelized the soft Q-network target update - Disabled gradient norm clipping by default, as it is quite expensive
This commit is contained in:
parent
799624b202
commit
c354ead107
@ -10,6 +10,9 @@ For more information, please see our [project webpage](https://younggyo.me/fast_
|
|||||||
|
|
||||||
|
|
||||||
## ❗ Updates
|
## ❗ Updates
|
||||||
|
|
||||||
|
- **[Jul/02/2025]** Optimized the codebase to speed up training by around 10-30% when using a single RTX 4090 GPU.
|
||||||
|
|
||||||
- **[Jun/20/2025]** Added support for [MTBench](https://github.com/Viraj-Joshi/MTBench) with the help of [Viraj Joshi](https://viraj-joshi.github.io/).
|
- **[Jun/20/2025]** Added support for [MTBench](https://github.com/Viraj-Joshi/MTBench) with the help of [Viraj Joshi](https://viraj-joshi.github.io/).
|
||||||
|
|
||||||
- **[Jun/15/2025]** Added support for FastTD3 + [SimbaV2](https://dojeon-ai.github.io/SimbaV2/)! It's faster to train, and often achieves better asymptotic performance. We recommend using FastTD3 + SimbaV2 for most cases.
|
- **[Jun/15/2025]** Added support for FastTD3 + [SimbaV2](https://dojeon-ai.github.io/SimbaV2/)! It's faster to train, and often achieves better asymptotic performance. We recommend using FastTD3 + SimbaV2 for most cases.
|
||||||
@ -237,6 +240,7 @@ We used a single Nvidia A100 80GB GPU for all experiments. Here are some remarks
|
|||||||
- If the agent is completely stuck or much worse than your expectation, try using `num_steps=3` or disabling `use_cdq`.
|
- If the agent is completely stuck or much worse than your expectation, try using `num_steps=3` or disabling `use_cdq`.
|
||||||
- For tasks that have penalty reward terms (e.g., torques, energy, action_rate, ...), consider lowering them for initial experiments, and then tune the values. In some cases, curriculum learning with lower penalty terms followed by fine-tuning with stronger terms is effective.
|
- For tasks that have penalty reward terms (e.g., torques, energy, action_rate, ...), consider lowering them for initial experiments, and then tune the values. In some cases, curriculum learning with lower penalty terms followed by fine-tuning with stronger terms is effective.
|
||||||
- When you encounter an out-of-memory error on your GPU, our recommendation for reducing GPU usage is (i) a smaller `buffer_size`, (ii) a smaller `batch_size`, and then (iii) a smaller `num_envs`. Because our codebase allocates the whole replay buffer on the GPU to reduce the CPU-GPU transfer bottleneck, the buffer usually has the largest GPU memory consumption, but reducing it is usually less harmful.
|
- When you encounter an out-of-memory error on your GPU, our recommendation for reducing GPU usage is (i) a smaller `buffer_size`, (ii) a smaller `batch_size`, and then (iii) a smaller `num_envs`. Because our codebase allocates the whole replay buffer on the GPU to reduce the CPU-GPU transfer bottleneck, the buffer usually has the largest GPU memory consumption, but reducing it is usually less harmful.
|
||||||
|
- Consider using `--compile_mode max-autotune` if you plan to run for many training steps. This may speed up training by up to 10% at the cost of a few additional minutes of heavy compilation.
|
||||||
|
|
||||||
## 🛝 Playing with the FastTD3 training
|
## 🛝 Playing with the FastTD3 training
|
||||||
|
|
||||||
|
@ -212,7 +212,9 @@ class Actor(nn.Module):
|
|||||||
|
|
||||||
# Update only the noise scales for environments that are done
|
# Update only the noise scales for environments that are done
|
||||||
dones_view = dones.view(-1, 1) > 0
|
dones_view = dones.view(-1, 1) > 0
|
||||||
self.noise_scales = torch.where(dones_view, new_scales, self.noise_scales)
|
self.noise_scales.copy_(
|
||||||
|
torch.where(dones_view, new_scales, self.noise_scales)
|
||||||
|
)
|
||||||
|
|
||||||
act = self(obs)
|
act = self(obs)
|
||||||
if deterministic:
|
if deterministic:
|
||||||
|
@ -486,7 +486,9 @@ class Actor(nn.Module):
|
|||||||
|
|
||||||
# Update only the noise scales for environments that are done
|
# Update only the noise scales for environments that are done
|
||||||
dones_view = dones.view(-1, 1) > 0
|
dones_view = dones.view(-1, 1) > 0
|
||||||
self.noise_scales = torch.where(dones_view, new_scales, self.noise_scales)
|
self.noise_scales.copy_(
|
||||||
|
torch.where(dones_view, new_scales, self.noise_scales)
|
||||||
|
)
|
||||||
|
|
||||||
act = self(obs)
|
act = self(obs)
|
||||||
if deterministic:
|
if deterministic:
|
||||||
|
@ -85,6 +85,7 @@ class SimpleReplayBuffer(nn.Module):
|
|||||||
)
|
)
|
||||||
self.ptr = 0
|
self.ptr = 0
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
def extend(
|
def extend(
|
||||||
self,
|
self,
|
||||||
tensor_dict: TensorDict,
|
tensor_dict: TensorDict,
|
||||||
@ -119,6 +120,7 @@ class SimpleReplayBuffer(nn.Module):
|
|||||||
self.next_critic_observations[:, ptr] = next_critic_observations
|
self.next_critic_observations[:, ptr] = next_critic_observations
|
||||||
self.ptr += 1
|
self.ptr += 1
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
def sample(self, batch_size: int):
|
def sample(self, batch_size: int):
|
||||||
# we will sample n_env * batch_size transitions
|
# we will sample n_env * batch_size transitions
|
||||||
|
|
||||||
@ -425,6 +427,7 @@ class EmpiricalNormalization(nn.Module):
|
|||||||
def std(self):
|
def std(self):
|
||||||
return self._std.squeeze(0).clone()
|
return self._std.squeeze(0).clone()
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
def forward(self, x: torch.Tensor, center: bool = True) -> torch.Tensor:
|
def forward(self, x: torch.Tensor, center: bool = True) -> torch.Tensor:
|
||||||
if x.shape[1:] != self._mean.shape[1:]:
|
if x.shape[1:] != self._mean.shape[1:]:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@ -453,21 +456,19 @@ class EmpiricalNormalization(nn.Module):
|
|||||||
delta = batch_mean - self._mean
|
delta = batch_mean - self._mean
|
||||||
self._mean += (batch_size / new_count) * delta
|
self._mean += (batch_size / new_count) * delta
|
||||||
|
|
||||||
# Update variance using Chan's parallel algorithm
|
# Compute batch variance
|
||||||
# https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
|
batch_var = torch.mean((x - batch_mean) ** 2, dim=0, keepdim=True)
|
||||||
if self.count > 0: # Ensure we're not dividing by zero
|
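delta2 = batch_mean - self._mean  # NOTE(review): self._mean was already updated in place above, so this is not the (batch_mean - previous mean) delta that Chan's parallel formula expects; confirm this is intended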
|
||||||
batch_var = torch.mean((x - batch_mean) ** 2, dim=0, keepdim=True)
|
|
||||||
delta2 = batch_mean - self._mean
|
|
||||||
m_a = self._var * self.count
|
|
||||||
m_b = batch_var * batch_size
|
|
||||||
M2 = m_a + m_b + (delta2**2) * (self.count * batch_size / new_count)
|
|
||||||
self._var = M2 / new_count
|
|
||||||
else:
|
|
||||||
# For first batch, just use batch variance
|
|
||||||
self._var = torch.mean((x - self._mean) ** 2, dim=0, keepdim=True)
|
|
||||||
|
|
||||||
self._std = torch.sqrt(self._var)
|
# Parallel variance update (works even when previous count == 0)
|
||||||
self.count = new_count
|
m_a = self._var * self.count # previous aggregated M2
|
||||||
|
m_b = batch_var * batch_size
|
||||||
|
M2 = m_a + m_b + delta2.pow(2) * (self.count * batch_size / new_count)
|
||||||
|
self._var.copy_(M2 / new_count)
|
||||||
|
|
||||||
|
# Update std and count in-place to avoid expensive __setattr__
|
||||||
|
self._std.copy_(self._var.sqrt())
|
||||||
|
self.count.copy_(new_count)
|
||||||
|
|
||||||
@torch.jit.unused
|
@torch.jit.unused
|
||||||
def inverse(self, y):
|
def inverse(self, y):
|
||||||
|
@ -92,10 +92,14 @@ class BaseArgs:
|
|||||||
"""the interval to render the model"""
|
"""the interval to render the model"""
|
||||||
compile: bool = True
|
compile: bool = True
|
||||||
"""whether to use torch.compile."""
|
"""whether to use torch.compile."""
|
||||||
|
compile_mode: str = "reduce-overhead"
|
||||||
|
"""the mode of torch.compile."""
|
||||||
obs_normalization: bool = True
|
obs_normalization: bool = True
|
||||||
"""whether to enable observation normalization"""
|
"""whether to enable observation normalization"""
|
||||||
reward_normalization: bool = False
|
reward_normalization: bool = False
|
||||||
"""whether to enable reward normalization"""
|
"""whether to enable reward normalization"""
|
||||||
|
use_grad_norm_clipping: bool = False
|
||||||
|
"""whether to use gradient norm clipping."""
|
||||||
max_grad_norm: float = 0.0
|
max_grad_norm: float = 0.0
|
||||||
"""the maximum gradient norm"""
|
"""the maximum gradient norm"""
|
||||||
amp: bool = True
|
amp: bool = True
|
||||||
|
@ -49,6 +49,12 @@ except ImportError:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def mark_step():
|
||||||
|
# call this once per iteration *before* any compiled function
|
||||||
|
torch.compiler.cudagraph_mark_step_begin()
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
args = get_args()
|
args = get_args()
|
||||||
print(args)
|
print(args)
|
||||||
@ -458,10 +464,13 @@ def main():
|
|||||||
scaler.scale(qf_loss).backward()
|
scaler.scale(qf_loss).backward()
|
||||||
scaler.unscale_(q_optimizer)
|
scaler.unscale_(q_optimizer)
|
||||||
|
|
||||||
critic_grad_norm = torch.nn.utils.clip_grad_norm_(
|
if args.use_grad_norm_clipping:
|
||||||
qnet.parameters(),
|
critic_grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||||
max_norm=args.max_grad_norm if args.max_grad_norm > 0 else float("inf"),
|
qnet.parameters(),
|
||||||
)
|
max_norm=args.max_grad_norm if args.max_grad_norm > 0 else float("inf"),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
critic_grad_norm = torch.tensor(0.0, device=device)
|
||||||
scaler.step(q_optimizer)
|
scaler.step(q_optimizer)
|
||||||
scaler.update()
|
scaler.update()
|
||||||
q_scheduler.step()
|
q_scheduler.step()
|
||||||
@ -494,10 +503,13 @@ def main():
|
|||||||
actor_optimizer.zero_grad(set_to_none=True)
|
actor_optimizer.zero_grad(set_to_none=True)
|
||||||
scaler.scale(actor_loss).backward()
|
scaler.scale(actor_loss).backward()
|
||||||
scaler.unscale_(actor_optimizer)
|
scaler.unscale_(actor_optimizer)
|
||||||
actor_grad_norm = torch.nn.utils.clip_grad_norm_(
|
if args.use_grad_norm_clipping:
|
||||||
actor.parameters(),
|
actor_grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||||
max_norm=args.max_grad_norm if args.max_grad_norm > 0 else float("inf"),
|
actor.parameters(),
|
||||||
)
|
max_norm=args.max_grad_norm if args.max_grad_norm > 0 else float("inf"),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
actor_grad_norm = torch.tensor(0.0, device=device)
|
||||||
scaler.step(actor_optimizer)
|
scaler.step(actor_optimizer)
|
||||||
scaler.update()
|
scaler.update()
|
||||||
actor_scheduler.step()
|
actor_scheduler.step()
|
||||||
@ -505,16 +517,28 @@ def main():
|
|||||||
logs_dict["actor_loss"] = actor_loss.detach()
|
logs_dict["actor_loss"] = actor_loss.detach()
|
||||||
return logs_dict
|
return logs_dict
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def soft_update(src, tgt, tau: float):
|
||||||
|
src_ps = [p.data for p in src.parameters()]
|
||||||
|
tgt_ps = [p.data for p in tgt.parameters()]
|
||||||
|
|
||||||
|
torch._foreach_mul_(tgt_ps, 1.0 - tau)
|
||||||
|
torch._foreach_add_(tgt_ps, src_ps, alpha=tau)
|
||||||
|
|
||||||
if args.compile:
|
if args.compile:
|
||||||
mode = None
|
compile_mode = args.compile_mode
|
||||||
update_main = torch.compile(update_main, mode=mode)
|
update_main = torch.compile(update_main, mode=compile_mode)
|
||||||
update_pol = torch.compile(update_pol, mode=mode)
|
update_pol = torch.compile(update_pol, mode=compile_mode)
|
||||||
policy = torch.compile(policy, mode=mode)
|
policy = torch.compile(policy, mode=compile_mode)
|
||||||
normalize_obs = torch.compile(obs_normalizer.forward, mode=mode)
|
normalize_obs = torch.compile(obs_normalizer.forward, mode=compile_mode)
|
||||||
normalize_critic_obs = torch.compile(critic_obs_normalizer.forward, mode=mode)
|
normalize_critic_obs = torch.compile(
|
||||||
|
critic_obs_normalizer.forward, mode=compile_mode
|
||||||
|
)
|
||||||
if args.reward_normalization:
|
if args.reward_normalization:
|
||||||
update_stats = torch.compile(reward_normalizer.update_stats, mode=mode)
|
update_stats = torch.compile(
|
||||||
normalize_reward = torch.compile(reward_normalizer.forward, mode=mode)
|
reward_normalizer.update_stats, mode=compile_mode
|
||||||
|
)
|
||||||
|
normalize_reward = torch.compile(reward_normalizer.forward, mode=compile_mode)
|
||||||
else:
|
else:
|
||||||
normalize_obs = obs_normalizer.forward
|
normalize_obs = obs_normalizer.forward
|
||||||
normalize_critic_obs = critic_obs_normalizer.forward
|
normalize_critic_obs = critic_obs_normalizer.forward
|
||||||
@ -549,6 +573,7 @@ def main():
|
|||||||
desc = ""
|
desc = ""
|
||||||
|
|
||||||
while global_step < args.total_timesteps:
|
while global_step < args.total_timesteps:
|
||||||
|
mark_step()
|
||||||
logs_dict = TensorDict()
|
logs_dict = TensorDict()
|
||||||
if (
|
if (
|
||||||
start_time is None
|
start_time is None
|
||||||
@ -576,7 +601,6 @@ def main():
|
|||||||
|
|
||||||
if envs.asymmetric_obs:
|
if envs.asymmetric_obs:
|
||||||
next_critic_obs = infos["observations"]["critic"]
|
next_critic_obs = infos["observations"]["critic"]
|
||||||
|
|
||||||
# Compute 'true' next_obs and next_critic_obs for saving
|
# Compute 'true' next_obs and next_critic_obs for saving
|
||||||
true_next_obs = torch.where(
|
true_next_obs = torch.where(
|
||||||
dones[:, None] > 0, infos["observations"]["raw"]["obs"], next_obs
|
dones[:, None] > 0, infos["observations"]["raw"]["obs"], next_obs
|
||||||
@ -587,6 +611,7 @@ def main():
|
|||||||
infos["observations"]["raw"]["critic_obs"],
|
infos["observations"]["raw"]["critic_obs"],
|
||||||
next_critic_obs,
|
next_critic_obs,
|
||||||
)
|
)
|
||||||
|
|
||||||
transition = TensorDict(
|
transition = TensorDict(
|
||||||
{
|
{
|
||||||
"observations": obs,
|
"observations": obs,
|
||||||
@ -606,13 +631,12 @@ def main():
|
|||||||
if envs.asymmetric_obs:
|
if envs.asymmetric_obs:
|
||||||
transition["critic_observations"] = critic_obs
|
transition["critic_observations"] = critic_obs
|
||||||
transition["next"]["critic_observations"] = true_next_critic_obs
|
transition["next"]["critic_observations"] = true_next_critic_obs
|
||||||
|
rb.extend(transition)
|
||||||
|
|
||||||
obs = next_obs
|
obs = next_obs
|
||||||
if envs.asymmetric_obs:
|
if envs.asymmetric_obs:
|
||||||
critic_obs = next_critic_obs
|
critic_obs = next_critic_obs
|
||||||
|
|
||||||
rb.extend(transition)
|
|
||||||
|
|
||||||
batch_size = args.batch_size // args.num_envs
|
batch_size = args.batch_size // args.num_envs
|
||||||
if global_step > args.learning_starts:
|
if global_step > args.learning_starts:
|
||||||
for i in range(args.num_updates):
|
for i in range(args.num_updates):
|
||||||
@ -621,6 +645,13 @@ def main():
|
|||||||
data["next"]["observations"] = normalize_obs(
|
data["next"]["observations"] = normalize_obs(
|
||||||
data["next"]["observations"]
|
data["next"]["observations"]
|
||||||
)
|
)
|
||||||
|
if envs.asymmetric_obs:
|
||||||
|
data["critic_observations"] = normalize_critic_obs(
|
||||||
|
data["critic_observations"]
|
||||||
|
)
|
||||||
|
data["next"]["critic_observations"] = normalize_critic_obs(
|
||||||
|
data["next"]["critic_observations"]
|
||||||
|
)
|
||||||
raw_rewards = data["next"]["rewards"]
|
raw_rewards = data["next"]["rewards"]
|
||||||
if env_type in ["mtbench"] and args.reward_normalization:
|
if env_type in ["mtbench"] and args.reward_normalization:
|
||||||
# Multi-task reward normalization
|
# Multi-task reward normalization
|
||||||
@ -631,13 +662,7 @@ def main():
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
data["next"]["rewards"] = normalize_reward(raw_rewards)
|
data["next"]["rewards"] = normalize_reward(raw_rewards)
|
||||||
if envs.asymmetric_obs:
|
|
||||||
data["critic_observations"] = normalize_critic_obs(
|
|
||||||
data["critic_observations"]
|
|
||||||
)
|
|
||||||
data["next"]["critic_observations"] = normalize_critic_obs(
|
|
||||||
data["next"]["critic_observations"]
|
|
||||||
)
|
|
||||||
logs_dict = update_main(data, logs_dict)
|
logs_dict = update_main(data, logs_dict)
|
||||||
if args.num_updates > 1:
|
if args.num_updates > 1:
|
||||||
if i % args.policy_frequency == 1:
|
if i % args.policy_frequency == 1:
|
||||||
@ -646,12 +671,7 @@ def main():
|
|||||||
if global_step % args.policy_frequency == 0:
|
if global_step % args.policy_frequency == 0:
|
||||||
logs_dict = update_pol(data, logs_dict)
|
logs_dict = update_pol(data, logs_dict)
|
||||||
|
|
||||||
for param, target_param in zip(
|
soft_update(qnet, qnet_target, args.tau)
|
||||||
qnet.parameters(), qnet_target.parameters()
|
|
||||||
):
|
|
||||||
target_param.data.copy_(
|
|
||||||
args.tau * param.data + (1 - args.tau) * target_param.data
|
|
||||||
)
|
|
||||||
|
|
||||||
if global_step % 100 == 0 and start_time is not None:
|
if global_step % 100 == 0 and start_time is not None:
|
||||||
speed = (global_step - measure_burnin) / (time.time() - start_time)
|
speed = (global_step - measure_burnin) / (time.time() - start_time)
|
||||||
|
Loading…
Reference in New Issue
Block a user