Fix replay buffer issues when n_steps > 1 (#7)

- Fix an issue where the n-step reward is not properly computed for end-of-episode transitions when using n_step > 1.
- Fix an issue where the observation and next_observations are sampled across different episodes when using n_step > 1 and the buffer is full
- Fix an issue where the discount is not properly computed when n_step > 1
This commit is contained in:
Younggyo Seo 2025-06-07 01:20:48 -04:00 committed by GitHub
parent fe028b578f
commit 85cb1c65c7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 67 additions and 26 deletions

View File

@ -8,6 +8,13 @@ FastTD3 is a high-performance variant of the Twin Delayed Deep Deterministic Pol
For more information, please see our [project webpage](https://younggyo.me/fast_td3)
## ❗ Updates
- **[Jun/6/2025]** Thanks to [Antonin Raffin](https://araffin.github.io/) ([@araffin](https://github.com/araffin)), we fixed the issues when using `n_steps` > 1, which stabilizes training with n-step return quite a lot!
- **[Jun/1/2025]** Updated the figures in the technical report to report deterministic evaluation for IsaacLab tasks.
## ✨ Features
FastTD3 offers researchers a significant speedup in training complex humanoid agents.
@ -151,9 +158,6 @@ We would like to thank people who have helped throughout the project:
- We thank [Kevin Zakka](https://kzakka.com/) for the help in setting up MuJoCo Playground.
- We thank [Changyeon Kim](https://changyeon.site/) for testing the early version of this codebase
## ❗ Updates
- **[Jun/1/2025]** We updated the figures in the technical report to report deterministic evaluation for IsaacLab tasks.
## Citations
### FastTD3

View File

@ -39,14 +39,17 @@ class DistributionalQNetwork(nn.Module):
actions: torch.Tensor,
rewards: torch.Tensor,
bootstrap: torch.Tensor,
gamma: float,
discount: float,
q_support: torch.Tensor,
device: torch.device,
) -> torch.Tensor:
delta_z = (self.v_max - self.v_min) / (self.num_atoms - 1)
batch_size = rewards.shape[0]
target_z = rewards.unsqueeze(1) + bootstrap.unsqueeze(1) * gamma * q_support
target_z = (
rewards.unsqueeze(1)
+ bootstrap.unsqueeze(1) * discount.unsqueeze(1) * q_support
)
target_z = target_z.clamp(self.v_min, self.v_max)
b = (target_z - self.v_min) / delta_z
l = torch.floor(b).long()
@ -121,7 +124,7 @@ class Critic(nn.Module):
actions: torch.Tensor,
rewards: torch.Tensor,
bootstrap: torch.Tensor,
gamma: float,
discount: float,
) -> torch.Tensor:
"""Projection operation that includes q_support directly"""
q1_proj = self.qnet1.projection(
@ -129,7 +132,7 @@ class Critic(nn.Module):
actions,
rewards,
bootstrap,
gamma,
discount,
self.q_support,
self.q_support.device,
)
@ -138,7 +141,7 @@ class Critic(nn.Module):
actions,
rewards,
bootstrap,
gamma,
discount,
self.q_support,
self.q_support.device,
)

View File

@ -27,6 +27,8 @@ class SimpleReplayBuffer(nn.Module):
When playground_mode=True, critic_observations are treated as a concatenation of
regular observations and privileged observations, and only the privileged part is stored
to save memory.
TODO (Younggyo): Refactor to split this into SimpleReplayBuffer and NStepReplayBuffer
"""
super().__init__()
@ -146,6 +148,7 @@ class SimpleReplayBuffer(nn.Module):
truncations = torch.gather(self.truncations, 1, indices).reshape(
self.n_env * batch_size
)
effective_n_steps = torch.ones_like(dones)
if self.asymmetric_obs:
if self.playground_mode:
# Gather privileged observations
@ -179,12 +182,31 @@ class SimpleReplayBuffer(nn.Module):
).reshape(self.n_env * batch_size, self.n_critic_obs)
else:
# Sample base indices
indices = torch.randint(
0,
min(self.buffer_size, self.ptr),
(self.n_env, batch_size),
device=self.device,
)
if self.ptr >= self.buffer_size:
# When the buffer is full, there is no protection against sampling across different episodes
# We avoid this by temporarily setting self.pos - 1 to truncated = True if not done
# https://github.com/DLR-RM/stable-baselines3/blob/b91050ca94f8bce7a0285c91f85da518d5a26223/stable_baselines3/common/buffers.py#L857-L860
# TODO (Younggyo): Change the reference when this SB3 branch is merged
current_pos = self.ptr % self.buffer_size
curr_truncations = self.truncations[:, current_pos - 1].clone()
self.truncations[:, current_pos - 1] = torch.logical_not(
self.dones[:, current_pos - 1]
)
indices = torch.randint(
0,
self.buffer_size,
(self.n_env, batch_size),
device=self.device,
)
else:
# Buffer not full - ensure n-step sequence doesn't exceed valid data
max_start_idx = max(1, self.ptr - self.n_steps + 1)
indices = torch.randint(
0,
max_start_idx,
(self.n_env, batch_size),
device=self.device,
)
obs_indices = indices.unsqueeze(-1).expand(-1, -1, self.n_obs)
act_indices = indices.unsqueeze(-1).expand(-1, -1, self.n_act)
@ -239,11 +261,15 @@ class SimpleReplayBuffer(nn.Module):
all_indices,
)
# Create masks for rewards after first done
# Create masks for rewards *after* first done
# This creates a cumulative product that zeroes out rewards after the first done
all_dones_shifted = torch.cat(
[torch.zeros_like(all_dones[:, :, :1]), all_dones[:, :, :-1]], dim=2
) # First reward should not be masked
done_masks = torch.cumprod(
1.0 - all_dones, dim=2
1.0 - all_dones_shifted, dim=2
) # [n_env, batch_size, n_step]
effective_n_steps = done_masks.sum(2)
# Create discount factors
discounts = torch.pow(
@ -339,6 +365,7 @@ class SimpleReplayBuffer(nn.Module):
rewards = n_step_rewards.reshape(self.n_env * batch_size)
dones = final_dones.reshape(self.n_env * batch_size)
truncations = final_truncations.reshape(self.n_env * batch_size)
effective_n_steps = effective_n_steps.reshape(self.n_env * batch_size)
next_observations = final_next_observations.reshape(
self.n_env * batch_size, self.n_obs
)
@ -352,6 +379,7 @@ class SimpleReplayBuffer(nn.Module):
"dones": dones,
"truncations": truncations,
"observations": next_observations,
"effective_n_steps": effective_n_steps,
},
},
batch_size=self.n_env * batch_size,
@ -359,6 +387,10 @@ class SimpleReplayBuffer(nn.Module):
if self.asymmetric_obs:
out["critic_observations"] = critic_observations
out["next"]["critic_observations"] = next_critic_observations
if self.ptr >= self.buffer_size:
# Roll back the truncation flags introduced for safe sampling
self.truncations[:, current_pos - 1] = curr_truncations
return out
@ -406,7 +438,6 @@ class EmpiricalNormalization(nn.Module):
@torch.jit.unused
def update(self, x):
"""Learn input values using Welford's online algorithm"""
if self.until is not None and self.count >= self.until:
return
@ -420,12 +451,10 @@ class EmpiricalNormalization(nn.Module):
delta = batch_mean - self._mean
self._mean += (batch_size / new_count) * delta
# Update variance using Welford's parallel algorithm
# Update variance using Chan's parallel algorithm
# https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
if self.count > 0: # Ensure we're not dividing by zero
# Compute batch variance
batch_var = torch.mean((x - batch_mean) ** 2, dim=0, keepdim=True)
# Combine variances using parallel algorithm
delta2 = batch_mean - self._mean
m_a = self._var * self.count
m_b = batch_var * batch_size

View File

@ -411,21 +411,24 @@ class IsaacOpenDrawerFrankaArgs(IsaacLabArgs):
@dataclass
class IsaacVelocityFlatH1Args(IsaacLabArgs):
env_name: str = "Isaac-Velocity-Flat-H1-v0"
num_steps: int = 3
num_steps: int = 8
num_updates: int = 4
total_timesteps: int = 75000
@dataclass
class IsaacVelocityFlatG1Args(IsaacLabArgs):
env_name: str = "Isaac-Velocity-Flat-G1-v0"
num_steps: int = 3
num_steps: int = 8
num_updates: int = 4
total_timesteps: int = 50000
@dataclass
class IsaacVelocityRoughH1Args(IsaacLabArgs):
env_name: str = "Isaac-Velocity-Rough-H1-v0"
num_steps: int = 3
num_steps: int = 8
num_updates: int = 4
buffer_size: int = 1024 * 5 # To reduce memory usage
total_timesteps: int = 50000
@ -433,7 +436,8 @@ class IsaacVelocityRoughH1Args(IsaacLabArgs):
@dataclass
class IsaacVelocityRoughG1Args(IsaacLabArgs):
env_name: str = "Isaac-Velocity-Rough-G1-v0"
num_steps: int = 3
num_steps: int = 8
num_updates: int = 4
buffer_size: int = 1024 * 5 # To reduce memory usage
total_timesteps: int = 50000

View File

@ -305,6 +305,7 @@ def main():
next_state_actions = (actor(next_observations) + clipped_noise).clamp(
action_low, action_high
)
discount = args.gamma ** data["next"]["effective_n_steps"]
with torch.no_grad():
qf1_next_target_projected, qf2_next_target_projected = (
@ -313,7 +314,7 @@ def main():
next_state_actions,
rewards,
bootstrap,
args.gamma,
discount,
)
)
qf1_next_target_value = qnet_target.get_value(qf1_next_target_projected)