Fix replay buffer issues when n_steps > 1 (#7)
- Fix an issue where the n-step reward is not properly computed for end-of-episode transitions when using n_step > 1. - Fix an issue where the observation and next_observations are sampled across different episodes when using n_step > 1 and the buffer is full - Fix an issue where the discount is not properly computed when n_step > 1
This commit is contained in:
parent
fe028b578f
commit
85cb1c65c7
10
README.md
10
README.md
@ -8,6 +8,13 @@ FastTD3 is a high-performance variant of the Twin Delayed Deep Deterministic Pol
|
||||
|
||||
For more information, please see our [project webpage](https://younggyo.me/fast_td3)
|
||||
|
||||
|
||||
## ❗ Updates
|
||||
- **[Jun/6/2025]** Thanks to [Antonin Raffin](https://araffin.github.io/) ([@araffin](https://github.com/araffin)), we fixed the issues when using `n_steps` > 1, which stabilizes training with n-step return quite a lot!
|
||||
|
||||
- **[Jun/1/2025]** Updated the figures in the technical report to report deterministic evaluation for IsaacLab tasks.
|
||||
|
||||
|
||||
## ✨ Features
|
||||
|
||||
FastTD3 offers researchers a significant speedup in training complex humanoid agents.
|
||||
@ -151,9 +158,6 @@ We would like to thank people who have helped throughout the project:
|
||||
- We thank [Kevin Zakka](https://kzakka.com/) for the help in setting up MuJoCo Playground.
|
||||
- We thank [Changyeon Kim](https://changyeon.site/) for testing the early version of this codebase
|
||||
|
||||
## ❗ Updates
|
||||
- **[Jun/1/2025]** We updated the figures in the technical report to report deterministic evaluation for IsaacLab tasks.
|
||||
|
||||
## Citations
|
||||
|
||||
### FastTD3
|
||||
|
@ -39,14 +39,17 @@ class DistributionalQNetwork(nn.Module):
|
||||
actions: torch.Tensor,
|
||||
rewards: torch.Tensor,
|
||||
bootstrap: torch.Tensor,
|
||||
gamma: float,
|
||||
discount: float,
|
||||
q_support: torch.Tensor,
|
||||
device: torch.device,
|
||||
) -> torch.Tensor:
|
||||
delta_z = (self.v_max - self.v_min) / (self.num_atoms - 1)
|
||||
batch_size = rewards.shape[0]
|
||||
|
||||
target_z = rewards.unsqueeze(1) + bootstrap.unsqueeze(1) * gamma * q_support
|
||||
target_z = (
|
||||
rewards.unsqueeze(1)
|
||||
+ bootstrap.unsqueeze(1) * discount.unsqueeze(1) * q_support
|
||||
)
|
||||
target_z = target_z.clamp(self.v_min, self.v_max)
|
||||
b = (target_z - self.v_min) / delta_z
|
||||
l = torch.floor(b).long()
|
||||
@ -121,7 +124,7 @@ class Critic(nn.Module):
|
||||
actions: torch.Tensor,
|
||||
rewards: torch.Tensor,
|
||||
bootstrap: torch.Tensor,
|
||||
gamma: float,
|
||||
discount: float,
|
||||
) -> torch.Tensor:
|
||||
"""Projection operation that includes q_support directly"""
|
||||
q1_proj = self.qnet1.projection(
|
||||
@ -129,7 +132,7 @@ class Critic(nn.Module):
|
||||
actions,
|
||||
rewards,
|
||||
bootstrap,
|
||||
gamma,
|
||||
discount,
|
||||
self.q_support,
|
||||
self.q_support.device,
|
||||
)
|
||||
@ -138,7 +141,7 @@ class Critic(nn.Module):
|
||||
actions,
|
||||
rewards,
|
||||
bootstrap,
|
||||
gamma,
|
||||
discount,
|
||||
self.q_support,
|
||||
self.q_support.device,
|
||||
)
|
||||
|
@ -27,6 +27,8 @@ class SimpleReplayBuffer(nn.Module):
|
||||
When playground_mode=True, critic_observations are treated as a concatenation of
|
||||
regular observations and privileged observations, and only the privileged part is stored
|
||||
to save memory.
|
||||
|
||||
TODO (Younggyo): Refactor to split this into SimpleReplayBuffer and NStepReplayBuffer
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
@ -146,6 +148,7 @@ class SimpleReplayBuffer(nn.Module):
|
||||
truncations = torch.gather(self.truncations, 1, indices).reshape(
|
||||
self.n_env * batch_size
|
||||
)
|
||||
effective_n_steps = torch.ones_like(dones)
|
||||
if self.asymmetric_obs:
|
||||
if self.playground_mode:
|
||||
# Gather privileged observations
|
||||
@ -179,12 +182,31 @@ class SimpleReplayBuffer(nn.Module):
|
||||
).reshape(self.n_env * batch_size, self.n_critic_obs)
|
||||
else:
|
||||
# Sample base indices
|
||||
indices = torch.randint(
|
||||
0,
|
||||
min(self.buffer_size, self.ptr),
|
||||
(self.n_env, batch_size),
|
||||
device=self.device,
|
||||
)
|
||||
if self.ptr >= self.buffer_size:
|
||||
# When the buffer is full, there is no protection against sampling across different episodes
|
||||
# We avoid this by temporarily setting self.pos - 1 to truncated = True if not done
|
||||
# https://github.com/DLR-RM/stable-baselines3/blob/b91050ca94f8bce7a0285c91f85da518d5a26223/stable_baselines3/common/buffers.py#L857-L860
|
||||
# TODO (Younggyo): Change the reference when this SB3 branch is merged
|
||||
current_pos = self.ptr % self.buffer_size
|
||||
curr_truncations = self.truncations[:, current_pos - 1].clone()
|
||||
self.truncations[:, current_pos - 1] = torch.logical_not(
|
||||
self.dones[:, current_pos - 1]
|
||||
)
|
||||
indices = torch.randint(
|
||||
0,
|
||||
self.buffer_size,
|
||||
(self.n_env, batch_size),
|
||||
device=self.device,
|
||||
)
|
||||
else:
|
||||
# Buffer not full - ensure n-step sequence doesn't exceed valid data
|
||||
max_start_idx = max(1, self.ptr - self.n_steps + 1)
|
||||
indices = torch.randint(
|
||||
0,
|
||||
max_start_idx,
|
||||
(self.n_env, batch_size),
|
||||
device=self.device,
|
||||
)
|
||||
obs_indices = indices.unsqueeze(-1).expand(-1, -1, self.n_obs)
|
||||
act_indices = indices.unsqueeze(-1).expand(-1, -1, self.n_act)
|
||||
|
||||
@ -239,11 +261,15 @@ class SimpleReplayBuffer(nn.Module):
|
||||
all_indices,
|
||||
)
|
||||
|
||||
# Create masks for rewards after first done
|
||||
# Create masks for rewards *after* first done
|
||||
# This creates a cumulative product that zeroes out rewards after the first done
|
||||
all_dones_shifted = torch.cat(
|
||||
[torch.zeros_like(all_dones[:, :, :1]), all_dones[:, :, :-1]], dim=2
|
||||
) # First reward should not be masked
|
||||
done_masks = torch.cumprod(
|
||||
1.0 - all_dones, dim=2
|
||||
1.0 - all_dones_shifted, dim=2
|
||||
) # [n_env, batch_size, n_step]
|
||||
effective_n_steps = done_masks.sum(2)
|
||||
|
||||
# Create discount factors
|
||||
discounts = torch.pow(
|
||||
@ -339,6 +365,7 @@ class SimpleReplayBuffer(nn.Module):
|
||||
rewards = n_step_rewards.reshape(self.n_env * batch_size)
|
||||
dones = final_dones.reshape(self.n_env * batch_size)
|
||||
truncations = final_truncations.reshape(self.n_env * batch_size)
|
||||
effective_n_steps = effective_n_steps.reshape(self.n_env * batch_size)
|
||||
next_observations = final_next_observations.reshape(
|
||||
self.n_env * batch_size, self.n_obs
|
||||
)
|
||||
@ -352,6 +379,7 @@ class SimpleReplayBuffer(nn.Module):
|
||||
"dones": dones,
|
||||
"truncations": truncations,
|
||||
"observations": next_observations,
|
||||
"effective_n_steps": effective_n_steps,
|
||||
},
|
||||
},
|
||||
batch_size=self.n_env * batch_size,
|
||||
@ -359,6 +387,10 @@ class SimpleReplayBuffer(nn.Module):
|
||||
if self.asymmetric_obs:
|
||||
out["critic_observations"] = critic_observations
|
||||
out["next"]["critic_observations"] = next_critic_observations
|
||||
|
||||
if self.ptr >= self.buffer_size:
|
||||
# Roll back the truncation flags introduced for safe sampling
|
||||
self.truncations[:, current_pos - 1] = curr_truncations
|
||||
return out
|
||||
|
||||
|
||||
@ -406,7 +438,6 @@ class EmpiricalNormalization(nn.Module):
|
||||
|
||||
@torch.jit.unused
|
||||
def update(self, x):
|
||||
"""Learn input values using Welford's online algorithm"""
|
||||
if self.until is not None and self.count >= self.until:
|
||||
return
|
||||
|
||||
@ -420,12 +451,10 @@ class EmpiricalNormalization(nn.Module):
|
||||
delta = batch_mean - self._mean
|
||||
self._mean += (batch_size / new_count) * delta
|
||||
|
||||
# Update variance using Welford's parallel algorithm
|
||||
# Update variance using Chan's parallel algorithm
|
||||
# https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
|
||||
if self.count > 0: # Ensure we're not dividing by zero
|
||||
# Compute batch variance
|
||||
batch_var = torch.mean((x - batch_mean) ** 2, dim=0, keepdim=True)
|
||||
|
||||
# Combine variances using parallel algorithm
|
||||
delta2 = batch_mean - self._mean
|
||||
m_a = self._var * self.count
|
||||
m_b = batch_var * batch_size
|
||||
|
@ -411,21 +411,24 @@ class IsaacOpenDrawerFrankaArgs(IsaacLabArgs):
|
||||
@dataclass
|
||||
class IsaacVelocityFlatH1Args(IsaacLabArgs):
|
||||
env_name: str = "Isaac-Velocity-Flat-H1-v0"
|
||||
num_steps: int = 3
|
||||
num_steps: int = 8
|
||||
num_updates: int = 4
|
||||
total_timesteps: int = 75000
|
||||
|
||||
|
||||
@dataclass
|
||||
class IsaacVelocityFlatG1Args(IsaacLabArgs):
|
||||
env_name: str = "Isaac-Velocity-Flat-G1-v0"
|
||||
num_steps: int = 3
|
||||
num_steps: int = 8
|
||||
num_updates: int = 4
|
||||
total_timesteps: int = 50000
|
||||
|
||||
|
||||
@dataclass
|
||||
class IsaacVelocityRoughH1Args(IsaacLabArgs):
|
||||
env_name: str = "Isaac-Velocity-Rough-H1-v0"
|
||||
num_steps: int = 3
|
||||
num_steps: int = 8
|
||||
num_updates: int = 4
|
||||
buffer_size: int = 1024 * 5 # To reduce memory usage
|
||||
total_timesteps: int = 50000
|
||||
|
||||
@ -433,7 +436,8 @@ class IsaacVelocityRoughH1Args(IsaacLabArgs):
|
||||
@dataclass
|
||||
class IsaacVelocityRoughG1Args(IsaacLabArgs):
|
||||
env_name: str = "Isaac-Velocity-Rough-G1-v0"
|
||||
num_steps: int = 3
|
||||
num_steps: int = 8
|
||||
num_updates: int = 4
|
||||
buffer_size: int = 1024 * 5 # To reduce memory usage
|
||||
total_timesteps: int = 50000
|
||||
|
||||
|
@ -305,6 +305,7 @@ def main():
|
||||
next_state_actions = (actor(next_observations) + clipped_noise).clamp(
|
||||
action_low, action_high
|
||||
)
|
||||
discount = args.gamma ** data["next"]["effective_n_steps"]
|
||||
|
||||
with torch.no_grad():
|
||||
qf1_next_target_projected, qf2_next_target_projected = (
|
||||
@ -313,7 +314,7 @@ def main():
|
||||
next_state_actions,
|
||||
rewards,
|
||||
bootstrap,
|
||||
args.gamma,
|
||||
discount,
|
||||
)
|
||||
)
|
||||
qf1_next_target_value = qnet_target.get_value(qf1_next_target_projected)
|
||||
|
Loading…
Reference in New Issue
Block a user