diff --git a/README.md b/README.md
index 41a56b4..61d64bf 100644
--- a/README.md
+++ b/README.md
@@ -98,7 +98,8 @@ python fast_td3/train.py --env_name h1hand-hurdle-v0 --exp_name FastTD3 --render
 ### MuJoCo Playground Experiments
 ```bash
 conda activate fasttd3_playground
-python fast_td3/train.py --env_name G1JoystickRoughTerrain --exp_name FastTD3 --render_interval 5000 --seed 1
+python fast_td3/train.py --env_name T1JoystickFlatTerrain --exp_name FastTD3 --render_interval 5000 --seed 1
+python fast_td3/train.py --env_name G1JoystickFlatTerrain --exp_name FastTD3 --render_interval 5000 --seed 1
 ```
 
 ### IsaacLab Experiments
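The `fast_td3_utils.py` change below relies on MuJoCo Playground's convention that a critic observation is the regular (actor) observation with a privileged tail appended, i.e. `critic_obs = [obs | privileged]`. A minimal sketch of the slice-and-reassemble round trip the buffer performs (sizes are illustrative, not taken from any task):

```python
import torch

n_env, n_obs, n_critic_obs = 4, 10, 16

# Playground-style critic observation: regular part first, privileged tail last
critic_obs = torch.randn(n_env, n_critic_obs)
obs = critic_obs[:, :n_obs]

# What extend() stores in playground_mode: only the privileged tail
privileged = critic_obs[:, n_obs:]

# What sample() reconstructs: concatenation recovers the original tensor
assert torch.equal(torch.cat([obs, privileged], dim=1), critic_obs)
```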
diff --git a/fast_td3/fast_td3_utils.py b/fast_td3/fast_td3_utils.py
index ac0b44e..299110b 100644
--- a/fast_td3/fast_td3_utils.py
+++ b/fast_td3/fast_td3_utils.py
@@ -15,6 +15,7 @@ class SimpleReplayBuffer(nn.Module):
         n_act: int,
         n_critic_obs: int,
         asymmetric_obs: bool = False,
+        playground_mode: bool = False,
         n_steps: int = 1,
         gamma: float = 0.99,
         device=None,
@@ -22,6 +23,10 @@ class SimpleReplayBuffer(nn.Module):
         """
         A simple replay buffer that stores transitions in a circular buffer.
         Supports n-step returns and asymmetric observations.
+
+        When playground_mode=True, critic_observations are treated as a concatenation of
+        regular observations and privileged observations, and only the privileged part is stored
+        to save memory.
         """
         super().__init__()
 
@@ -31,6 +36,7 @@ class SimpleReplayBuffer(nn.Module):
         self.n_act = n_act
         self.n_critic_obs = n_critic_obs
         self.asymmetric_obs = asymmetric_obs
+        self.playground_mode = playground_mode and asymmetric_obs
         self.gamma = gamma
         self.n_steps = n_steps
         self.device = device
@@ -52,12 +58,23 @@ class SimpleReplayBuffer(nn.Module):
             (n_env, buffer_size, n_obs), device=device, dtype=torch.float
         )
         if asymmetric_obs:
-            self.critic_observations = torch.zeros(
-                (n_env, buffer_size, n_critic_obs), device=device, dtype=torch.float
-            )
-            self.next_critic_observations = torch.zeros(
-                (n_env, buffer_size, n_critic_obs), device=device, dtype=torch.float
-            )
+            if self.playground_mode:
+                # Only store the privileged part of observations (n_critic_obs - n_obs)
+                self.privileged_obs_size = n_critic_obs - n_obs
+                self.privileged_observations = torch.zeros(
+                    (n_env, buffer_size, self.privileged_obs_size), device=device, dtype=torch.float
+                )
+                self.next_privileged_observations = torch.zeros(
+                    (n_env, buffer_size, self.privileged_obs_size), device=device, dtype=torch.float
+                )
+            else:
+                # Store full critic observations
+                self.critic_observations = torch.zeros(
+                    (n_env, buffer_size, n_critic_obs), device=device, dtype=torch.float
+                )
+                self.next_critic_observations = torch.zeros(
+                    (n_env, buffer_size, n_critic_obs), device=device, dtype=torch.float
+                )
         self.ptr = 0
 
     def extend(
         self,
         tensor_dict: TensorDict,
     ):
@@ -80,9 +97,18 @@ class SimpleReplayBuffer(nn.Module):
         self.next_observations[:, ptr] = next_observations
         if self.asymmetric_obs:
             critic_observations = tensor_dict["critic_observations"]
-            self.critic_observations[:, ptr] = critic_observations
             next_critic_observations = tensor_dict["next"]["critic_observations"]
-            self.next_critic_observations[:, ptr] = next_critic_observations
+
+            if self.playground_mode:
+                # Extract and store only the privileged part
+                privileged_observations = critic_observations[:, self.n_obs:]
+                next_privileged_observations = next_critic_observations[:, self.n_obs:]
+                self.privileged_observations[:, ptr] = privileged_observations
+                self.next_privileged_observations[:, ptr] = next_privileged_observations
+            else:
+                # Store full critic observations
+                self.critic_observations[:, ptr] = critic_observations
+                self.next_critic_observations[:, ptr] = next_critic_observations
         self.ptr += 1
 
     def sample(self, batch_size: int):
@@ -117,15 +143,30 @@ class SimpleReplayBuffer(nn.Module):
                 self.n_env * batch_size
             )
             if self.asymmetric_obs:
-                critic_obs_indices = indices.unsqueeze(-1).expand(
-                    -1, -1, self.n_critic_obs
-                )
-                critic_observations = torch.gather(
-                    self.critic_observations, 1, critic_obs_indices
-                ).reshape(self.n_env * batch_size, self.n_critic_obs)
-                next_critic_observations = torch.gather(
-                    self.next_critic_observations, 1, critic_obs_indices
-                ).reshape(self.n_env * batch_size, self.n_critic_obs)
+                if self.playground_mode:
+                    # Gather privileged observations
+                    priv_obs_indices = indices.unsqueeze(-1).expand(-1, -1, self.privileged_obs_size)
+                    privileged_observations = torch.gather(
+                        self.privileged_observations, 1, priv_obs_indices
+                    ).reshape(self.n_env * batch_size, self.privileged_obs_size)
+                    next_privileged_observations = torch.gather(
+                        self.next_privileged_observations, 1, priv_obs_indices
+                    ).reshape(self.n_env * batch_size, self.privileged_obs_size)
+
+                    # Concatenate with regular observations to form full critic observations
+                    critic_observations = torch.cat([observations, privileged_observations], dim=1)
+                    next_critic_observations = torch.cat([next_observations, next_privileged_observations], dim=1)
+                else:
+                    # Gather full critic observations
+                    critic_obs_indices = indices.unsqueeze(-1).expand(
+                        -1, -1, self.n_critic_obs
+                    )
+                    critic_observations = torch.gather(
+                        self.critic_observations, 1, critic_obs_indices
+                    ).reshape(self.n_env * batch_size, self.n_critic_obs)
+                    next_critic_observations = torch.gather(
+                        self.next_critic_observations, 1, critic_obs_indices
+                    ).reshape(self.n_env * batch_size, self.n_critic_obs)
         else:
             # Sample base indices
             indices = torch.randint(
@@ -145,12 +186,23 @@ class SimpleReplayBuffer(nn.Module):
                 self.n_env * batch_size, self.n_act
             )
             if self.asymmetric_obs:
-                critic_obs_indices = indices.unsqueeze(-1).expand(
-                    -1, -1, self.n_critic_obs
-                )
-                critic_observations = torch.gather(
-                    self.critic_observations, 1, critic_obs_indices
-                ).reshape(self.n_env * batch_size, self.n_critic_obs)
+                if self.playground_mode:
+                    # Gather privileged observations
+                    priv_obs_indices = indices.unsqueeze(-1).expand(-1, -1, self.privileged_obs_size)
+                    privileged_observations = torch.gather(
+                        self.privileged_observations, 1, priv_obs_indices
+                    ).reshape(self.n_env * batch_size, self.privileged_obs_size)
+
+                    # Concatenate with regular observations to form full critic observations
+                    critic_observations = torch.cat([observations, privileged_observations], dim=1)
+                else:
+                    # Gather full critic observations
+                    critic_obs_indices = indices.unsqueeze(-1).expand(
+                        -1, -1, self.n_critic_obs
+                    )
+                    critic_observations = torch.gather(
+                        self.critic_observations, 1, critic_obs_indices
+                    ).reshape(self.n_env * batch_size, self.n_critic_obs)
 
             # Create sequential indices for each sample
             # This creates a [n_env, batch_size, n_step] tensor of indices
@@ -229,14 +281,40 @@ class SimpleReplayBuffer(nn.Module):
             final_truncations = self.truncations.gather(1, final_next_obs_indices)
 
             if self.asymmetric_obs:
-                final_next_critic_observations = self.next_critic_observations.gather(
-                    1,
-                    final_next_obs_indices.unsqueeze(-1).expand(
-                        -1, -1, self.n_critic_obs
-                    ),
-                )
+                if self.playground_mode:
+                    # Gather final privileged observations
+                    final_next_privileged_observations = self.next_privileged_observations.gather(
+                        1,
+                        final_next_obs_indices.unsqueeze(-1).expand(
+                            -1, -1, self.privileged_obs_size
+                        ),
+                    )
+
+                    # Reshape for output
+                    next_privileged_observations = final_next_privileged_observations.reshape(
+                        self.n_env * batch_size, self.privileged_obs_size
+                    )
+
+                    # Concatenate with next observations to form full next critic observations
+                    next_observations_reshaped = final_next_observations.reshape(
+                        self.n_env * batch_size, self.n_obs
+                    )
+                    next_critic_observations = torch.cat(
+                        [next_observations_reshaped, next_privileged_observations], dim=1
+                    )
+                else:
+                    # Gather final next critic observations directly
+                    final_next_critic_observations = self.next_critic_observations.gather(
+                        1,
+                        final_next_obs_indices.unsqueeze(-1).expand(
+                            -1, -1, self.n_critic_obs
+                        ),
+                    )
+                    next_critic_observations = final_next_critic_observations.reshape(
+                        self.n_env * batch_size, self.n_critic_obs
+                    )
 
             # Reshape everything to batch dimension
             rewards = n_step_rewards.reshape(self.n_env * batch_size)
             dones = final_dones.reshape(self.n_env * batch_size)
             truncations = final_truncations.reshape(self.n_env * batch_size)
@@ -244,11 +322,6 @@
             next_observations = final_next_observations.reshape(
                 self.n_env * batch_size, self.n_obs
             )
-            if self.asymmetric_obs:
-                next_critic_observations = final_next_critic_observations.reshape(
-                    self.n_env * batch_size, self.n_critic_obs
-                )
-
         out = TensorDict(
             {
                 "observations": observations,
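To see what the change buys, note that the buffer keeps two critic-observation tensors (current and next), so the saving is twice the size of the duplicated regular-observation block. A back-of-the-envelope estimate, with made-up but plausible sizes for a Playground humanoid task:

```python
# Hypothetical sizes, for illustration only
n_env, buffer_size = 1024, 10240
n_obs, n_critic_obs = 91, 139  # critic obs = regular obs + 48-dim privileged tail
bytes_per_float = 4

# Before: two full (n_env, buffer_size, n_critic_obs) float32 tensors
full = 2 * n_env * buffer_size * n_critic_obs * bytes_per_float

# After (playground_mode): two tensors holding only the privileged tail;
# the regular part is already stored once in observations/next_observations
tail = 2 * n_env * buffer_size * (n_critic_obs - n_obs) * bytes_per_float

print(f"full: {full / 2**30:.2f} GiB, playground_mode: {tail / 2**30:.2f} GiB")
# full: 10.86 GiB, playground_mode: 3.75 GiB
```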
diff --git a/fast_td3/train.py b/fast_td3/train.py
index 98a52f2..c480403 100644
--- a/fast_td3/train.py
+++ b/fast_td3/train.py
@@ -193,6 +193,7 @@ def main():
         n_act=n_act,
         n_critic_obs=n_critic_obs,
         asymmetric_obs=envs.asymmetric_obs,
+        playground_mode=env_type == "mujoco_playground",
         n_steps=args.num_steps,
         gamma=args.gamma,
         device=device,
diff --git a/fast_td3/training_notebook.ipynb b/fast_td3/training_notebook.ipynb
index 1653660..ec256ab 100644
--- a/fast_td3/training_notebook.ipynb
+++ b/fast_td3/training_notebook.ipynb
@@ -272,6 +272,7 @@
     "    n_act=n_act,\n",
     "    n_critic_obs=n_critic_obs,\n",
     "    asymmetric_obs=envs.asymmetric_obs,\n",
+    "    playground_mode=env_type == \"mujoco_playground\",\n",
     "    n_steps=args.num_steps,\n",
     "    gamma=args.gamma,\n",
     "    device=device,\n",
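For reference, a sketch of how the new flag is exercised end to end, mirroring the train.py wiring above. The sizes are placeholders, and the leading constructor arguments not visible in this diff (`n_env`, `buffer_size`, `n_obs`) are inferred from the buffer's attribute names:

```python
import torch
from fast_td3.fast_td3_utils import SimpleReplayBuffer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
env_type = "mujoco_playground"  # as set by train.py for Playground tasks

rb = SimpleReplayBuffer(
    n_env=1024,
    buffer_size=10240,
    n_obs=91,
    n_act=23,
    n_critic_obs=139,
    asymmetric_obs=True,
    playground_mode=env_type == "mujoco_playground",
    n_steps=3,
    gamma=0.99,
    device=device,
)
# With playground_mode=True, only the 48-dim privileged tail is buffered;
# sampled batches still expose full 139-dim critic observations.
```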