Fix replay buffer issues when n_steps > 1 (#7)

- Fix an issue where the n-step reward is not properly computed for end-of-episode transitions when using n_steps > 1.
- Fix an issue where the observations and next_observations are sampled across different episodes when using n_steps > 1 and the buffer is full.
- Fix an issue where the discount is not properly computed when n_steps > 1.
This commit is contained in:
Younggyo Seo 2025-06-07 01:20:48 -04:00 committed by GitHub
parent fe028b578f
commit 85cb1c65c7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 67 additions and 26 deletions

View File

@ -8,6 +8,13 @@ FastTD3 is a high-performance variant of the Twin Delayed Deep Deterministic Pol
For more information, please see our [project webpage](https://younggyo.me/fast_td3) For more information, please see our [project webpage](https://younggyo.me/fast_td3)
## ❗ Updates
- **[Jun/6/2025]** Thanks to [Antonin Raffin](https://araffin.github.io/) ([@araffin](https://github.com/araffin)), we fixed several issues that occurred when using `n_steps` > 1, which makes training with n-step returns considerably more stable!
- **[Jun/1/2025]** Updated the figures in the technical report to report deterministic evaluation for IsaacLab tasks.
## ✨ Features ## ✨ Features
FastTD3 offers researchers a significant speedup in training complex humanoid agents. FastTD3 offers researchers a significant speedup in training complex humanoid agents.
@ -151,9 +158,6 @@ We would like to thank people who have helped throughout the project:
- We thank [Kevin Zakka](https://kzakka.com/) for the help in setting up MuJoCo Playground. - We thank [Kevin Zakka](https://kzakka.com/) for the help in setting up MuJoCo Playground.
- We thank [Changyeon Kim](https://changyeon.site/) for testing the early version of this codebase - We thank [Changyeon Kim](https://changyeon.site/) for testing the early version of this codebase
## ❗ Updates
- **[Jun/1/2025]** We updated the figures in the technical report to report deterministic evaluation for IsaacLab tasks.
## Citations ## Citations
### FastTD3 ### FastTD3

View File

@ -39,14 +39,17 @@ class DistributionalQNetwork(nn.Module):
actions: torch.Tensor, actions: torch.Tensor,
rewards: torch.Tensor, rewards: torch.Tensor,
bootstrap: torch.Tensor, bootstrap: torch.Tensor,
gamma: float, discount: float,
q_support: torch.Tensor, q_support: torch.Tensor,
device: torch.device, device: torch.device,
) -> torch.Tensor: ) -> torch.Tensor:
delta_z = (self.v_max - self.v_min) / (self.num_atoms - 1) delta_z = (self.v_max - self.v_min) / (self.num_atoms - 1)
batch_size = rewards.shape[0] batch_size = rewards.shape[0]
target_z = rewards.unsqueeze(1) + bootstrap.unsqueeze(1) * gamma * q_support target_z = (
rewards.unsqueeze(1)
+ bootstrap.unsqueeze(1) * discount.unsqueeze(1) * q_support
)
target_z = target_z.clamp(self.v_min, self.v_max) target_z = target_z.clamp(self.v_min, self.v_max)
b = (target_z - self.v_min) / delta_z b = (target_z - self.v_min) / delta_z
l = torch.floor(b).long() l = torch.floor(b).long()
@ -121,7 +124,7 @@ class Critic(nn.Module):
actions: torch.Tensor, actions: torch.Tensor,
rewards: torch.Tensor, rewards: torch.Tensor,
bootstrap: torch.Tensor, bootstrap: torch.Tensor,
gamma: float, discount: float,
) -> torch.Tensor: ) -> torch.Tensor:
"""Projection operation that includes q_support directly""" """Projection operation that includes q_support directly"""
q1_proj = self.qnet1.projection( q1_proj = self.qnet1.projection(
@ -129,7 +132,7 @@ class Critic(nn.Module):
actions, actions,
rewards, rewards,
bootstrap, bootstrap,
gamma, discount,
self.q_support, self.q_support,
self.q_support.device, self.q_support.device,
) )
@ -138,7 +141,7 @@ class Critic(nn.Module):
actions, actions,
rewards, rewards,
bootstrap, bootstrap,
gamma, discount,
self.q_support, self.q_support,
self.q_support.device, self.q_support.device,
) )

View File

@ -27,6 +27,8 @@ class SimpleReplayBuffer(nn.Module):
When playground_mode=True, critic_observations are treated as a concatenation of When playground_mode=True, critic_observations are treated as a concatenation of
regular observations and privileged observations, and only the privileged part is stored regular observations and privileged observations, and only the privileged part is stored
to save memory. to save memory.
TODO (Younggyo): Refactor to split this into SimpleReplayBuffer and NStepReplayBuffer
""" """
super().__init__() super().__init__()
@ -146,6 +148,7 @@ class SimpleReplayBuffer(nn.Module):
truncations = torch.gather(self.truncations, 1, indices).reshape( truncations = torch.gather(self.truncations, 1, indices).reshape(
self.n_env * batch_size self.n_env * batch_size
) )
effective_n_steps = torch.ones_like(dones)
if self.asymmetric_obs: if self.asymmetric_obs:
if self.playground_mode: if self.playground_mode:
# Gather privileged observations # Gather privileged observations
@ -179,12 +182,31 @@ class SimpleReplayBuffer(nn.Module):
).reshape(self.n_env * batch_size, self.n_critic_obs) ).reshape(self.n_env * batch_size, self.n_critic_obs)
else: else:
# Sample base indices # Sample base indices
indices = torch.randint( if self.ptr >= self.buffer_size:
0, # When the buffer is full, there is no protection against sampling across different episodes
min(self.buffer_size, self.ptr), # We avoid this by temporarily marking the transition at current_pos - 1 as truncated when it is not done
(self.n_env, batch_size), # https://github.com/DLR-RM/stable-baselines3/blob/b91050ca94f8bce7a0285c91f85da518d5a26223/stable_baselines3/common/buffers.py#L857-L860
device=self.device, # TODO (Younggyo): Change the reference when this SB3 branch is merged
) current_pos = self.ptr % self.buffer_size
curr_truncations = self.truncations[:, current_pos - 1].clone()
self.truncations[:, current_pos - 1] = torch.logical_not(
self.dones[:, current_pos - 1]
)
indices = torch.randint(
0,
self.buffer_size,
(self.n_env, batch_size),
device=self.device,
)
else:
# Buffer not full - ensure n-step sequence doesn't exceed valid data
max_start_idx = max(1, self.ptr - self.n_steps + 1)
indices = torch.randint(
0,
max_start_idx,
(self.n_env, batch_size),
device=self.device,
)
obs_indices = indices.unsqueeze(-1).expand(-1, -1, self.n_obs) obs_indices = indices.unsqueeze(-1).expand(-1, -1, self.n_obs)
act_indices = indices.unsqueeze(-1).expand(-1, -1, self.n_act) act_indices = indices.unsqueeze(-1).expand(-1, -1, self.n_act)
@ -239,11 +261,15 @@ class SimpleReplayBuffer(nn.Module):
all_indices, all_indices,
) )
# Create masks for rewards after first done # Create masks for rewards *after* first done
# This creates a cumulative product that zeroes out rewards after the first done # This creates a cumulative product that zeroes out rewards after the first done
all_dones_shifted = torch.cat(
[torch.zeros_like(all_dones[:, :, :1]), all_dones[:, :, :-1]], dim=2
) # First reward should not be masked
done_masks = torch.cumprod( done_masks = torch.cumprod(
1.0 - all_dones, dim=2 1.0 - all_dones_shifted, dim=2
) # [n_env, batch_size, n_step] ) # [n_env, batch_size, n_step]
effective_n_steps = done_masks.sum(2)
# Create discount factors # Create discount factors
discounts = torch.pow( discounts = torch.pow(
@ -339,6 +365,7 @@ class SimpleReplayBuffer(nn.Module):
rewards = n_step_rewards.reshape(self.n_env * batch_size) rewards = n_step_rewards.reshape(self.n_env * batch_size)
dones = final_dones.reshape(self.n_env * batch_size) dones = final_dones.reshape(self.n_env * batch_size)
truncations = final_truncations.reshape(self.n_env * batch_size) truncations = final_truncations.reshape(self.n_env * batch_size)
effective_n_steps = effective_n_steps.reshape(self.n_env * batch_size)
next_observations = final_next_observations.reshape( next_observations = final_next_observations.reshape(
self.n_env * batch_size, self.n_obs self.n_env * batch_size, self.n_obs
) )
@ -352,6 +379,7 @@ class SimpleReplayBuffer(nn.Module):
"dones": dones, "dones": dones,
"truncations": truncations, "truncations": truncations,
"observations": next_observations, "observations": next_observations,
"effective_n_steps": effective_n_steps,
}, },
}, },
batch_size=self.n_env * batch_size, batch_size=self.n_env * batch_size,
@ -359,6 +387,10 @@ class SimpleReplayBuffer(nn.Module):
if self.asymmetric_obs: if self.asymmetric_obs:
out["critic_observations"] = critic_observations out["critic_observations"] = critic_observations
out["next"]["critic_observations"] = next_critic_observations out["next"]["critic_observations"] = next_critic_observations
if self.ptr >= self.buffer_size:
# Roll back the truncation flags introduced for safe sampling
self.truncations[:, current_pos - 1] = curr_truncations
return out return out
@ -406,7 +438,6 @@ class EmpiricalNormalization(nn.Module):
@torch.jit.unused @torch.jit.unused
def update(self, x): def update(self, x):
"""Learn input values using Welford's online algorithm"""
if self.until is not None and self.count >= self.until: if self.until is not None and self.count >= self.until:
return return
@ -420,12 +451,10 @@ class EmpiricalNormalization(nn.Module):
delta = batch_mean - self._mean delta = batch_mean - self._mean
self._mean += (batch_size / new_count) * delta self._mean += (batch_size / new_count) * delta
# Update variance using Welford's parallel algorithm # Update variance using Chan's parallel algorithm
# https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
if self.count > 0: # Ensure we're not dividing by zero if self.count > 0: # Ensure we're not dividing by zero
# Compute batch variance
batch_var = torch.mean((x - batch_mean) ** 2, dim=0, keepdim=True) batch_var = torch.mean((x - batch_mean) ** 2, dim=0, keepdim=True)
# Combine variances using parallel algorithm
delta2 = batch_mean - self._mean delta2 = batch_mean - self._mean
m_a = self._var * self.count m_a = self._var * self.count
m_b = batch_var * batch_size m_b = batch_var * batch_size

View File

@ -411,21 +411,24 @@ class IsaacOpenDrawerFrankaArgs(IsaacLabArgs):
@dataclass @dataclass
class IsaacVelocityFlatH1Args(IsaacLabArgs): class IsaacVelocityFlatH1Args(IsaacLabArgs):
env_name: str = "Isaac-Velocity-Flat-H1-v0" env_name: str = "Isaac-Velocity-Flat-H1-v0"
num_steps: int = 3 num_steps: int = 8
num_updates: int = 4
total_timesteps: int = 75000 total_timesteps: int = 75000
@dataclass @dataclass
class IsaacVelocityFlatG1Args(IsaacLabArgs): class IsaacVelocityFlatG1Args(IsaacLabArgs):
env_name: str = "Isaac-Velocity-Flat-G1-v0" env_name: str = "Isaac-Velocity-Flat-G1-v0"
num_steps: int = 3 num_steps: int = 8
num_updates: int = 4
total_timesteps: int = 50000 total_timesteps: int = 50000
@dataclass @dataclass
class IsaacVelocityRoughH1Args(IsaacLabArgs): class IsaacVelocityRoughH1Args(IsaacLabArgs):
env_name: str = "Isaac-Velocity-Rough-H1-v0" env_name: str = "Isaac-Velocity-Rough-H1-v0"
num_steps: int = 3 num_steps: int = 8
num_updates: int = 4
buffer_size: int = 1024 * 5 # To reduce memory usage buffer_size: int = 1024 * 5 # To reduce memory usage
total_timesteps: int = 50000 total_timesteps: int = 50000
@ -433,7 +436,8 @@ class IsaacVelocityRoughH1Args(IsaacLabArgs):
@dataclass @dataclass
class IsaacVelocityRoughG1Args(IsaacLabArgs): class IsaacVelocityRoughG1Args(IsaacLabArgs):
env_name: str = "Isaac-Velocity-Rough-G1-v0" env_name: str = "Isaac-Velocity-Rough-G1-v0"
num_steps: int = 3 num_steps: int = 8
num_updates: int = 4
buffer_size: int = 1024 * 5 # To reduce memory usage buffer_size: int = 1024 * 5 # To reduce memory usage
total_timesteps: int = 50000 total_timesteps: int = 50000

View File

@ -305,6 +305,7 @@ def main():
next_state_actions = (actor(next_observations) + clipped_noise).clamp( next_state_actions = (actor(next_observations) + clipped_noise).clamp(
action_low, action_high action_low, action_high
) )
discount = args.gamma ** data["next"]["effective_n_steps"]
with torch.no_grad(): with torch.no_grad():
qf1_next_target_projected, qf2_next_target_projected = ( qf1_next_target_projected, qf2_next_target_projected = (
@ -313,7 +314,7 @@ def main():
next_state_actions, next_state_actions,
rewards, rewards,
bootstrap, bootstrap,
args.gamma, discount,
) )
) )
qf1_next_target_value = qnet_target.get_value(qf1_next_target_projected) qf1_next_target_value = qnet_target.get_value(qf1_next_target_projected)