From 85cb1c65c7d079f8319c7cf5c4612e0b87909c1b Mon Sep 17 00:00:00 2001 From: Younggyo Seo Date: Sat, 7 Jun 2025 01:20:48 -0400 Subject: [PATCH] Fix replay buffer issues when n_steps > 1 (#7) - Fix an issue where the n-step reward is not properly computed for end-of-episode transitions when using n_steps > 1. - Fix an issue where observations and next_observations are sampled across different episodes when using n_steps > 1 and the buffer is full. - Fix an issue where the discount is not properly computed when n_steps > 1. --- README.md | 10 ++++--- fast_td3/fast_td3.py | 13 +++++---- fast_td3/fast_td3_utils.py | 55 +++++++++++++++++++++++++++++--------- fast_td3/hyperparams.py | 12 ++++++--- fast_td3/train.py | 3 ++- 5 files changed, 67 insertions(+), 26 deletions(-) diff --git a/README.md b/README.md index b2cab18..0a75b8f 100644 --- a/README.md +++ b/README.md @@ -8,6 +8,13 @@ FastTD3 is a high-performance variant of the Twin Delayed Deep Deterministic Pol For more information, please see our [project webpage](https://younggyo.me/fast_td3) + +## ❗ Updates +- **[Jun/6/2025]** Thanks to [Antonin Raffin](https://araffin.github.io/) ([@araffin](https://github.com/araffin)), we fixed the issues when using `n_steps` > 1, which stabilizes training with n-step return quite a lot! + +- **[Jun/1/2025]** Updated the figures in the technical report to report deterministic evaluation for IsaacLab tasks. + + ## ✨ Features FastTD3 offers researchers a significant speedup in training complex humanoid agents. @@ -151,9 +158,6 @@ We would like to thank people who have helped throughout the project: - We thank [Kevin Zakka](https://kzakka.com/) for the help in setting up MuJoCo Playground. - We thank [Changyeon Kim](https://changyeon.site/) for testing the early version of this codebase -## ❗ Updates -- **[Jun/1/2025]** We updated the figures in the technical report to report deterministic evaluation for IsaacLab tasks. 
- ## Citations ### FastTD3 diff --git a/fast_td3/fast_td3.py b/fast_td3/fast_td3.py index 2a149bb..4805ccc 100644 --- a/fast_td3/fast_td3.py +++ b/fast_td3/fast_td3.py @@ -39,14 +39,17 @@ class DistributionalQNetwork(nn.Module): actions: torch.Tensor, rewards: torch.Tensor, bootstrap: torch.Tensor, - gamma: float, + discount: float, q_support: torch.Tensor, device: torch.device, ) -> torch.Tensor: delta_z = (self.v_max - self.v_min) / (self.num_atoms - 1) batch_size = rewards.shape[0] - target_z = rewards.unsqueeze(1) + bootstrap.unsqueeze(1) * gamma * q_support + target_z = ( + rewards.unsqueeze(1) + + bootstrap.unsqueeze(1) * discount.unsqueeze(1) * q_support + ) target_z = target_z.clamp(self.v_min, self.v_max) b = (target_z - self.v_min) / delta_z l = torch.floor(b).long() @@ -121,7 +124,7 @@ class Critic(nn.Module): actions: torch.Tensor, rewards: torch.Tensor, bootstrap: torch.Tensor, - gamma: float, + discount: float, ) -> torch.Tensor: """Projection operation that includes q_support directly""" q1_proj = self.qnet1.projection( @@ -129,7 +132,7 @@ class Critic(nn.Module): actions, rewards, bootstrap, - gamma, + discount, self.q_support, self.q_support.device, ) @@ -138,7 +141,7 @@ class Critic(nn.Module): actions, rewards, bootstrap, - gamma, + discount, self.q_support, self.q_support.device, ) diff --git a/fast_td3/fast_td3_utils.py b/fast_td3/fast_td3_utils.py index 46346a2..6923440 100644 --- a/fast_td3/fast_td3_utils.py +++ b/fast_td3/fast_td3_utils.py @@ -27,6 +27,8 @@ class SimpleReplayBuffer(nn.Module): When playground_mode=True, critic_observations are treated as a concatenation of regular observations and privileged observations, and only the privileged part is stored to save memory. 
+ + TODO (Younggyo): Refactor to split this into SimpleReplayBuffer and NStepReplayBuffer """ super().__init__() @@ -146,6 +148,7 @@ class SimpleReplayBuffer(nn.Module): truncations = torch.gather(self.truncations, 1, indices).reshape( self.n_env * batch_size ) + effective_n_steps = torch.ones_like(dones) if self.asymmetric_obs: if self.playground_mode: # Gather privileged observations @@ -179,12 +182,31 @@ class SimpleReplayBuffer(nn.Module): ).reshape(self.n_env * batch_size, self.n_critic_obs) else: # Sample base indices - indices = torch.randint( - 0, - min(self.buffer_size, self.ptr), - (self.n_env, batch_size), - device=self.device, - ) + if self.ptr >= self.buffer_size: + # When the buffer is full, there is no protection against sampling across different episodes + # We avoid this by temporarily setting self.pos - 1 to truncated = True if not done + # https://github.com/DLR-RM/stable-baselines3/blob/b91050ca94f8bce7a0285c91f85da518d5a26223/stable_baselines3/common/buffers.py#L857-L860 + # TODO (Younggyo): Change the reference when this SB3 branch is merged + current_pos = self.ptr % self.buffer_size + curr_truncations = self.truncations[:, current_pos - 1].clone() + self.truncations[:, current_pos - 1] = torch.logical_not( + self.dones[:, current_pos - 1] + ) + indices = torch.randint( + 0, + self.buffer_size, + (self.n_env, batch_size), + device=self.device, + ) + else: + # Buffer not full - ensure n-step sequence doesn't exceed valid data + max_start_idx = max(1, self.ptr - self.n_steps + 1) + indices = torch.randint( + 0, + max_start_idx, + (self.n_env, batch_size), + device=self.device, + ) obs_indices = indices.unsqueeze(-1).expand(-1, -1, self.n_obs) act_indices = indices.unsqueeze(-1).expand(-1, -1, self.n_act) @@ -239,11 +261,15 @@ class SimpleReplayBuffer(nn.Module): all_indices, ) - # Create masks for rewards after first done + # Create masks for rewards *after* first done # This creates a cumulative product that zeroes out rewards after the first 
done + all_dones_shifted = torch.cat( + [torch.zeros_like(all_dones[:, :, :1]), all_dones[:, :, :-1]], dim=2 + ) # First reward should not be masked done_masks = torch.cumprod( - 1.0 - all_dones, dim=2 + 1.0 - all_dones_shifted, dim=2 ) # [n_env, batch_size, n_step] + effective_n_steps = done_masks.sum(2) # Create discount factors discounts = torch.pow( @@ -339,6 +365,7 @@ class SimpleReplayBuffer(nn.Module): rewards = n_step_rewards.reshape(self.n_env * batch_size) dones = final_dones.reshape(self.n_env * batch_size) truncations = final_truncations.reshape(self.n_env * batch_size) + effective_n_steps = effective_n_steps.reshape(self.n_env * batch_size) next_observations = final_next_observations.reshape( self.n_env * batch_size, self.n_obs ) @@ -352,6 +379,7 @@ class SimpleReplayBuffer(nn.Module): "dones": dones, "truncations": truncations, "observations": next_observations, + "effective_n_steps": effective_n_steps, }, }, batch_size=self.n_env * batch_size, @@ -359,6 +387,10 @@ class SimpleReplayBuffer(nn.Module): if self.asymmetric_obs: out["critic_observations"] = critic_observations out["next"]["critic_observations"] = next_critic_observations + + if self.ptr >= self.buffer_size: + # Roll back the truncation flags introduced for safe sampling + self.truncations[:, current_pos - 1] = curr_truncations return out @@ -406,7 +438,6 @@ class EmpiricalNormalization(nn.Module): @torch.jit.unused def update(self, x): - """Learn input values using Welford's online algorithm""" if self.until is not None and self.count >= self.until: return @@ -420,12 +451,10 @@ class EmpiricalNormalization(nn.Module): delta = batch_mean - self._mean self._mean += (batch_size / new_count) * delta - # Update variance using Welford's parallel algorithm + # Update variance using Chan's parallel algorithm + # https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm if self.count > 0: # Ensure we're not dividing by zero - # Compute batch variance batch_var = 
torch.mean((x - batch_mean) ** 2, dim=0, keepdim=True) - - # Combine variances using parallel algorithm delta2 = batch_mean - self._mean m_a = self._var * self.count m_b = batch_var * batch_size diff --git a/fast_td3/hyperparams.py b/fast_td3/hyperparams.py index 3f1a8ed..c1fa4df 100644 --- a/fast_td3/hyperparams.py +++ b/fast_td3/hyperparams.py @@ -411,21 +411,24 @@ class IsaacOpenDrawerFrankaArgs(IsaacLabArgs): @dataclass class IsaacVelocityFlatH1Args(IsaacLabArgs): env_name: str = "Isaac-Velocity-Flat-H1-v0" - num_steps: int = 3 + num_steps: int = 8 + num_updates: int = 4 total_timesteps: int = 75000 @dataclass class IsaacVelocityFlatG1Args(IsaacLabArgs): env_name: str = "Isaac-Velocity-Flat-G1-v0" - num_steps: int = 3 + num_steps: int = 8 + num_updates: int = 4 total_timesteps: int = 50000 @dataclass class IsaacVelocityRoughH1Args(IsaacLabArgs): env_name: str = "Isaac-Velocity-Rough-H1-v0" - num_steps: int = 3 + num_steps: int = 8 + num_updates: int = 4 buffer_size: int = 1024 * 5 # To reduce memory usage total_timesteps: int = 50000 @@ -433,7 +436,8 @@ class IsaacVelocityRoughH1Args(IsaacLabArgs): @dataclass class IsaacVelocityRoughG1Args(IsaacLabArgs): env_name: str = "Isaac-Velocity-Rough-G1-v0" - num_steps: int = 3 + num_steps: int = 8 + num_updates: int = 4 buffer_size: int = 1024 * 5 # To reduce memory usage total_timesteps: int = 50000 diff --git a/fast_td3/train.py b/fast_td3/train.py index c480403..672ab79 100644 --- a/fast_td3/train.py +++ b/fast_td3/train.py @@ -305,6 +305,7 @@ def main(): next_state_actions = (actor(next_observations) + clipped_noise).clamp( action_low, action_high ) + discount = args.gamma ** data["next"]["effective_n_steps"] with torch.no_grad(): qf1_next_target_projected, qf2_next_target_projected = ( @@ -313,7 +314,7 @@ def main(): next_state_actions, rewards, bootstrap, - args.gamma, + discount, ) ) qf1_next_target_value = qnet_target.get_value(qf1_next_target_projected)