Fix replay buffer issues when n_steps > 1 (#7)
- Fix an issue where the n-step reward is not properly computed for end-of-episode transitions when using n_steps > 1.
- Fix an issue where observations and next_observations are sampled across different episodes when using n_steps > 1 and the buffer is full.
- Fix an issue where the discount is not properly computed when using n_steps > 1.
This commit is contained in:
parent
fe028b578f
commit
85cb1c65c7
10
README.md
10
README.md
@ -8,6 +8,13 @@ FastTD3 is a high-performance variant of the Twin Delayed Deep Deterministic Pol
|
|||||||
|
|
||||||
For more information, please see our [project webpage](https://younggyo.me/fast_td3)
|
For more information, please see our [project webpage](https://younggyo.me/fast_td3)
|
||||||
|
|
||||||
|
|
||||||
|
## ❗ Updates
|
||||||
|
- **[Jun/6/2025]** Thanks to [Antonin Raffin](https://araffin.github.io/) ([@araffin](https://github.com/araffin)), we fixed several replay buffer issues that occurred when using `n_steps` > 1, which makes training with n-step returns considerably more stable!
|
||||||
|
|
||||||
|
- **[Jun/1/2025]** Updated the figures in the technical report to report deterministic evaluation for IsaacLab tasks.
|
||||||
|
|
||||||
|
|
||||||
## ✨ Features
|
## ✨ Features
|
||||||
|
|
||||||
FastTD3 offers researchers a significant speedup in training complex humanoid agents.
|
FastTD3 offers researchers a significant speedup in training complex humanoid agents.
|
||||||
@ -151,9 +158,6 @@ We would like to thank people who have helped throughout the project:
|
|||||||
- We thank [Kevin Zakka](https://kzakka.com/) for the help in setting up MuJoCo Playground.
|
- We thank [Kevin Zakka](https://kzakka.com/) for the help in setting up MuJoCo Playground.
|
||||||
- We thank [Changyeon Kim](https://changyeon.site/) for testing the early version of this codebase
|
- We thank [Changyeon Kim](https://changyeon.site/) for testing the early version of this codebase
|
||||||
|
|
||||||
## ❗ Updates
|
|
||||||
- **[Jun/1/2025]** We updated the figures in the technical report to report deterministic evaluation for IsaacLab tasks.
|
|
||||||
|
|
||||||
## Citations
|
## Citations
|
||||||
|
|
||||||
### FastTD3
|
### FastTD3
|
||||||
|
@ -39,14 +39,17 @@ class DistributionalQNetwork(nn.Module):
|
|||||||
actions: torch.Tensor,
|
actions: torch.Tensor,
|
||||||
rewards: torch.Tensor,
|
rewards: torch.Tensor,
|
||||||
bootstrap: torch.Tensor,
|
bootstrap: torch.Tensor,
|
||||||
gamma: float,
|
discount: float,
|
||||||
q_support: torch.Tensor,
|
q_support: torch.Tensor,
|
||||||
device: torch.device,
|
device: torch.device,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
delta_z = (self.v_max - self.v_min) / (self.num_atoms - 1)
|
delta_z = (self.v_max - self.v_min) / (self.num_atoms - 1)
|
||||||
batch_size = rewards.shape[0]
|
batch_size = rewards.shape[0]
|
||||||
|
|
||||||
target_z = rewards.unsqueeze(1) + bootstrap.unsqueeze(1) * gamma * q_support
|
target_z = (
|
||||||
|
rewards.unsqueeze(1)
|
||||||
|
+ bootstrap.unsqueeze(1) * discount.unsqueeze(1) * q_support
|
||||||
|
)
|
||||||
target_z = target_z.clamp(self.v_min, self.v_max)
|
target_z = target_z.clamp(self.v_min, self.v_max)
|
||||||
b = (target_z - self.v_min) / delta_z
|
b = (target_z - self.v_min) / delta_z
|
||||||
l = torch.floor(b).long()
|
l = torch.floor(b).long()
|
||||||
@ -121,7 +124,7 @@ class Critic(nn.Module):
|
|||||||
actions: torch.Tensor,
|
actions: torch.Tensor,
|
||||||
rewards: torch.Tensor,
|
rewards: torch.Tensor,
|
||||||
bootstrap: torch.Tensor,
|
bootstrap: torch.Tensor,
|
||||||
gamma: float,
|
discount: float,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Projection operation that includes q_support directly"""
|
"""Projection operation that includes q_support directly"""
|
||||||
q1_proj = self.qnet1.projection(
|
q1_proj = self.qnet1.projection(
|
||||||
@ -129,7 +132,7 @@ class Critic(nn.Module):
|
|||||||
actions,
|
actions,
|
||||||
rewards,
|
rewards,
|
||||||
bootstrap,
|
bootstrap,
|
||||||
gamma,
|
discount,
|
||||||
self.q_support,
|
self.q_support,
|
||||||
self.q_support.device,
|
self.q_support.device,
|
||||||
)
|
)
|
||||||
@ -138,7 +141,7 @@ class Critic(nn.Module):
|
|||||||
actions,
|
actions,
|
||||||
rewards,
|
rewards,
|
||||||
bootstrap,
|
bootstrap,
|
||||||
gamma,
|
discount,
|
||||||
self.q_support,
|
self.q_support,
|
||||||
self.q_support.device,
|
self.q_support.device,
|
||||||
)
|
)
|
||||||
|
@ -27,6 +27,8 @@ class SimpleReplayBuffer(nn.Module):
|
|||||||
When playground_mode=True, critic_observations are treated as a concatenation of
|
When playground_mode=True, critic_observations are treated as a concatenation of
|
||||||
regular observations and privileged observations, and only the privileged part is stored
|
regular observations and privileged observations, and only the privileged part is stored
|
||||||
to save memory.
|
to save memory.
|
||||||
|
|
||||||
|
TODO (Younggyo): Refactor to split this into SimpleReplayBuffer and NStepReplayBuffer
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@ -146,6 +148,7 @@ class SimpleReplayBuffer(nn.Module):
|
|||||||
truncations = torch.gather(self.truncations, 1, indices).reshape(
|
truncations = torch.gather(self.truncations, 1, indices).reshape(
|
||||||
self.n_env * batch_size
|
self.n_env * batch_size
|
||||||
)
|
)
|
||||||
|
effective_n_steps = torch.ones_like(dones)
|
||||||
if self.asymmetric_obs:
|
if self.asymmetric_obs:
|
||||||
if self.playground_mode:
|
if self.playground_mode:
|
||||||
# Gather privileged observations
|
# Gather privileged observations
|
||||||
@ -179,12 +182,31 @@ class SimpleReplayBuffer(nn.Module):
|
|||||||
).reshape(self.n_env * batch_size, self.n_critic_obs)
|
).reshape(self.n_env * batch_size, self.n_critic_obs)
|
||||||
else:
|
else:
|
||||||
# Sample base indices
|
# Sample base indices
|
||||||
indices = torch.randint(
|
if self.ptr >= self.buffer_size:
|
||||||
0,
|
# When the buffer is full, there is no protection against sampling across different episodes
|
||||||
min(self.buffer_size, self.ptr),
|
# We avoid this by temporarily marking the most recently written transition (index current_pos - 1) as truncated if it is not done
|
||||||
(self.n_env, batch_size),
|
# https://github.com/DLR-RM/stable-baselines3/blob/b91050ca94f8bce7a0285c91f85da518d5a26223/stable_baselines3/common/buffers.py#L857-L860
|
||||||
device=self.device,
|
# TODO (Younggyo): Change the reference when this SB3 branch is merged
|
||||||
)
|
current_pos = self.ptr % self.buffer_size
|
||||||
|
curr_truncations = self.truncations[:, current_pos - 1].clone()
|
||||||
|
self.truncations[:, current_pos - 1] = torch.logical_not(
|
||||||
|
self.dones[:, current_pos - 1]
|
||||||
|
)
|
||||||
|
indices = torch.randint(
|
||||||
|
0,
|
||||||
|
self.buffer_size,
|
||||||
|
(self.n_env, batch_size),
|
||||||
|
device=self.device,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Buffer not full - ensure n-step sequence doesn't exceed valid data
|
||||||
|
max_start_idx = max(1, self.ptr - self.n_steps + 1)
|
||||||
|
indices = torch.randint(
|
||||||
|
0,
|
||||||
|
max_start_idx,
|
||||||
|
(self.n_env, batch_size),
|
||||||
|
device=self.device,
|
||||||
|
)
|
||||||
obs_indices = indices.unsqueeze(-1).expand(-1, -1, self.n_obs)
|
obs_indices = indices.unsqueeze(-1).expand(-1, -1, self.n_obs)
|
||||||
act_indices = indices.unsqueeze(-1).expand(-1, -1, self.n_act)
|
act_indices = indices.unsqueeze(-1).expand(-1, -1, self.n_act)
|
||||||
|
|
||||||
@ -239,11 +261,15 @@ class SimpleReplayBuffer(nn.Module):
|
|||||||
all_indices,
|
all_indices,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create masks for rewards after first done
|
# Create masks for rewards *after* first done
|
||||||
# This creates a cumulative product that zeroes out rewards after the first done
|
# This creates a cumulative product that zeroes out rewards after the first done
|
||||||
|
all_dones_shifted = torch.cat(
|
||||||
|
[torch.zeros_like(all_dones[:, :, :1]), all_dones[:, :, :-1]], dim=2
|
||||||
|
) # First reward should not be masked
|
||||||
done_masks = torch.cumprod(
|
done_masks = torch.cumprod(
|
||||||
1.0 - all_dones, dim=2
|
1.0 - all_dones_shifted, dim=2
|
||||||
) # [n_env, batch_size, n_step]
|
) # [n_env, batch_size, n_step]
|
||||||
|
effective_n_steps = done_masks.sum(2)
|
||||||
|
|
||||||
# Create discount factors
|
# Create discount factors
|
||||||
discounts = torch.pow(
|
discounts = torch.pow(
|
||||||
@ -339,6 +365,7 @@ class SimpleReplayBuffer(nn.Module):
|
|||||||
rewards = n_step_rewards.reshape(self.n_env * batch_size)
|
rewards = n_step_rewards.reshape(self.n_env * batch_size)
|
||||||
dones = final_dones.reshape(self.n_env * batch_size)
|
dones = final_dones.reshape(self.n_env * batch_size)
|
||||||
truncations = final_truncations.reshape(self.n_env * batch_size)
|
truncations = final_truncations.reshape(self.n_env * batch_size)
|
||||||
|
effective_n_steps = effective_n_steps.reshape(self.n_env * batch_size)
|
||||||
next_observations = final_next_observations.reshape(
|
next_observations = final_next_observations.reshape(
|
||||||
self.n_env * batch_size, self.n_obs
|
self.n_env * batch_size, self.n_obs
|
||||||
)
|
)
|
||||||
@ -352,6 +379,7 @@ class SimpleReplayBuffer(nn.Module):
|
|||||||
"dones": dones,
|
"dones": dones,
|
||||||
"truncations": truncations,
|
"truncations": truncations,
|
||||||
"observations": next_observations,
|
"observations": next_observations,
|
||||||
|
"effective_n_steps": effective_n_steps,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
batch_size=self.n_env * batch_size,
|
batch_size=self.n_env * batch_size,
|
||||||
@ -359,6 +387,10 @@ class SimpleReplayBuffer(nn.Module):
|
|||||||
if self.asymmetric_obs:
|
if self.asymmetric_obs:
|
||||||
out["critic_observations"] = critic_observations
|
out["critic_observations"] = critic_observations
|
||||||
out["next"]["critic_observations"] = next_critic_observations
|
out["next"]["critic_observations"] = next_critic_observations
|
||||||
|
|
||||||
|
if self.ptr >= self.buffer_size:
|
||||||
|
# Roll back the truncation flags introduced for safe sampling
|
||||||
|
self.truncations[:, current_pos - 1] = curr_truncations
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
@ -406,7 +438,6 @@ class EmpiricalNormalization(nn.Module):
|
|||||||
|
|
||||||
@torch.jit.unused
|
@torch.jit.unused
|
||||||
def update(self, x):
|
def update(self, x):
|
||||||
"""Learn input values using Welford's online algorithm"""
|
|
||||||
if self.until is not None and self.count >= self.until:
|
if self.until is not None and self.count >= self.until:
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -420,12 +451,10 @@ class EmpiricalNormalization(nn.Module):
|
|||||||
delta = batch_mean - self._mean
|
delta = batch_mean - self._mean
|
||||||
self._mean += (batch_size / new_count) * delta
|
self._mean += (batch_size / new_count) * delta
|
||||||
|
|
||||||
# Update variance using Welford's parallel algorithm
|
# Update variance using Chan's parallel algorithm
|
||||||
|
# https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
|
||||||
if self.count > 0: # Ensure we're not dividing by zero
|
if self.count > 0: # Ensure we're not dividing by zero
|
||||||
# Compute batch variance
|
|
||||||
batch_var = torch.mean((x - batch_mean) ** 2, dim=0, keepdim=True)
|
batch_var = torch.mean((x - batch_mean) ** 2, dim=0, keepdim=True)
|
||||||
|
|
||||||
# Combine variances using parallel algorithm
|
|
||||||
delta2 = batch_mean - self._mean
|
delta2 = batch_mean - self._mean
|
||||||
m_a = self._var * self.count
|
m_a = self._var * self.count
|
||||||
m_b = batch_var * batch_size
|
m_b = batch_var * batch_size
|
||||||
|
@ -411,21 +411,24 @@ class IsaacOpenDrawerFrankaArgs(IsaacLabArgs):
|
|||||||
@dataclass
|
@dataclass
|
||||||
class IsaacVelocityFlatH1Args(IsaacLabArgs):
|
class IsaacVelocityFlatH1Args(IsaacLabArgs):
|
||||||
env_name: str = "Isaac-Velocity-Flat-H1-v0"
|
env_name: str = "Isaac-Velocity-Flat-H1-v0"
|
||||||
num_steps: int = 3
|
num_steps: int = 8
|
||||||
|
num_updates: int = 4
|
||||||
total_timesteps: int = 75000
|
total_timesteps: int = 75000
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class IsaacVelocityFlatG1Args(IsaacLabArgs):
|
class IsaacVelocityFlatG1Args(IsaacLabArgs):
|
||||||
env_name: str = "Isaac-Velocity-Flat-G1-v0"
|
env_name: str = "Isaac-Velocity-Flat-G1-v0"
|
||||||
num_steps: int = 3
|
num_steps: int = 8
|
||||||
|
num_updates: int = 4
|
||||||
total_timesteps: int = 50000
|
total_timesteps: int = 50000
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class IsaacVelocityRoughH1Args(IsaacLabArgs):
|
class IsaacVelocityRoughH1Args(IsaacLabArgs):
|
||||||
env_name: str = "Isaac-Velocity-Rough-H1-v0"
|
env_name: str = "Isaac-Velocity-Rough-H1-v0"
|
||||||
num_steps: int = 3
|
num_steps: int = 8
|
||||||
|
num_updates: int = 4
|
||||||
buffer_size: int = 1024 * 5 # To reduce memory usage
|
buffer_size: int = 1024 * 5 # To reduce memory usage
|
||||||
total_timesteps: int = 50000
|
total_timesteps: int = 50000
|
||||||
|
|
||||||
@ -433,7 +436,8 @@ class IsaacVelocityRoughH1Args(IsaacLabArgs):
|
|||||||
@dataclass
|
@dataclass
|
||||||
class IsaacVelocityRoughG1Args(IsaacLabArgs):
|
class IsaacVelocityRoughG1Args(IsaacLabArgs):
|
||||||
env_name: str = "Isaac-Velocity-Rough-G1-v0"
|
env_name: str = "Isaac-Velocity-Rough-G1-v0"
|
||||||
num_steps: int = 3
|
num_steps: int = 8
|
||||||
|
num_updates: int = 4
|
||||||
buffer_size: int = 1024 * 5 # To reduce memory usage
|
buffer_size: int = 1024 * 5 # To reduce memory usage
|
||||||
total_timesteps: int = 50000
|
total_timesteps: int = 50000
|
||||||
|
|
||||||
|
@ -305,6 +305,7 @@ def main():
|
|||||||
next_state_actions = (actor(next_observations) + clipped_noise).clamp(
|
next_state_actions = (actor(next_observations) + clipped_noise).clamp(
|
||||||
action_low, action_high
|
action_low, action_high
|
||||||
)
|
)
|
||||||
|
discount = args.gamma ** data["next"]["effective_n_steps"]
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
qf1_next_target_projected, qf2_next_target_projected = (
|
qf1_next_target_projected, qf2_next_target_projected = (
|
||||||
@ -313,7 +314,7 @@ def main():
|
|||||||
next_state_actions,
|
next_state_actions,
|
||||||
rewards,
|
rewards,
|
||||||
bootstrap,
|
bootstrap,
|
||||||
args.gamma,
|
discount,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
qf1_next_target_value = qnet_target.get_value(qf1_next_target_projected)
|
qf1_next_target_value = qnet_target.get_value(qf1_next_target_projected)
|
||||||
|
Loading…
Reference in New Issue
Block a user