diff --git a/fast_td3/environments/mujoco_playground_env.py b/fast_td3/environments/mujoco_playground_env.py
index 5e21999..4aa785b 100644
--- a/fast_td3/environments/mujoco_playground_env.py
+++ b/fast_td3/environments/mujoco_playground_env.py
@@ -77,9 +77,17 @@ def make_env(
 ):
     # Make training environment
     train_env_cfg = registry.get_default_config(env_name)
-    if use_tuned_reward:
+    is_humanoid_task = env_name in [
+        "G1JoystickRoughTerrain",
+        "G1JoystickFlatTerrain",
+        "T1JoystickRoughTerrain",
+        "T1JoystickFlatTerrain",
+    ]
+
+    if use_tuned_reward and is_humanoid_task:
         # NOTE: Tuned reward for G1. Used for producing Figure 7 in the paper.
-        assert env_name in ["G1JoystickRoughTerrain", "G1JoystickFlatTerrain"]
+        # Somehow it works reasonably for T1 as well.
+        # However, see `sim2real.md` for sim-to-real RL with Booster T1
         train_env_cfg.reward_config.scales.energy = -5e-5
         train_env_cfg.reward_config.scales.action_rate = -1e-1
         train_env_cfg.reward_config.scales.torques = -1e-3
@@ -90,13 +98,6 @@ def make_env(
         train_env_cfg.reward_config.scales.ang_vel_xy = -0.3
         train_env_cfg.reward_config.scales.orientation = -5.0
 
-    is_humanoid_task = env_name in [
-        "G1JoystickRoughTerrain",
-        "G1JoystickFlatTerrain",
-        "T1JoystickRoughTerrain",
-        "T1JoystickFlatTerrain",
-    ]
-
     if is_humanoid_task and not use_push_randomization:
         train_env_cfg.push_config.enable = False
         train_env_cfg.push_config.magnitude_range = [0.0, 0.0]
diff --git a/fast_td3/fast_td3_utils.py b/fast_td3/fast_td3_utils.py
index 299110b..46346a2 100644
--- a/fast_td3/fast_td3_utils.py
+++ b/fast_td3/fast_td3_utils.py
@@ -23,7 +23,7 @@ class SimpleReplayBuffer(nn.Module):
     """
     A simple replay buffer that stores transitions in a circular buffer.
     Supports n-step returns and asymmetric observations.
-    
+
     When playground_mode=True, critic_observations are treated as a concatenation of
     regular observations and privileged observations, and only the privileged part is stored
     to save memory.
@@ -62,10 +62,14 @@ class SimpleReplayBuffer(nn.Module):
                 # Only store the privileged part of observations (n_critic_obs - n_obs)
                 self.privileged_obs_size = n_critic_obs - n_obs
                 self.privileged_observations = torch.zeros(
-                    (n_env, buffer_size, self.privileged_obs_size), device=device, dtype=torch.float
+                    (n_env, buffer_size, self.privileged_obs_size),
+                    device=device,
+                    dtype=torch.float,
                 )
                 self.next_privileged_observations = torch.zeros(
-                    (n_env, buffer_size, self.privileged_obs_size), device=device, dtype=torch.float
+                    (n_env, buffer_size, self.privileged_obs_size),
+                    device=device,
+                    dtype=torch.float,
                 )
             else:
                 # Store full critic observations
@@ -98,11 +102,11 @@ class SimpleReplayBuffer(nn.Module):
         if self.asymmetric_obs:
             critic_observations = tensor_dict["critic_observations"]
             next_critic_observations = tensor_dict["next"]["critic_observations"]
-            
+
             if self.playground_mode:
                 # Extract and store only the privileged part
-                privileged_observations = critic_observations[:, self.n_obs:]
-                next_privileged_observations = next_critic_observations[:, self.n_obs:]
+                privileged_observations = critic_observations[:, self.n_obs :]
+                next_privileged_observations = next_critic_observations[:, self.n_obs :]
                 self.privileged_observations[:, ptr] = privileged_observations
                 self.next_privileged_observations[:, ptr] = next_privileged_observations
             else:
@@ -145,17 +149,23 @@ class SimpleReplayBuffer(nn.Module):
         if self.asymmetric_obs:
             if self.playground_mode:
                 # Gather privileged observations
-                priv_obs_indices = indices.unsqueeze(-1).expand(-1, -1, self.privileged_obs_size)
+                priv_obs_indices = indices.unsqueeze(-1).expand(
+                    -1, -1, self.privileged_obs_size
+                )
                 privileged_observations = torch.gather(
                     self.privileged_observations, 1, priv_obs_indices
                 ).reshape(self.n_env * batch_size, self.privileged_obs_size)
                 next_privileged_observations = torch.gather(
                     self.next_privileged_observations, 1, priv_obs_indices
                 ).reshape(self.n_env * batch_size, self.privileged_obs_size)
-                
+
                 # Concatenate with regular observations to form full critic observations
-                critic_observations = torch.cat([observations, privileged_observations], dim=1)
-                next_critic_observations = torch.cat([next_observations, next_privileged_observations], dim=1)
+                critic_observations = torch.cat(
+                    [observations, privileged_observations], dim=1
+                )
+                next_critic_observations = torch.cat(
+                    [next_observations, next_privileged_observations], dim=1
+                )
             else:
                 # Gather full critic observations
                 critic_obs_indices = indices.unsqueeze(-1).expand(
@@ -188,13 +198,17 @@ class SimpleReplayBuffer(nn.Module):
         if self.asymmetric_obs:
             if self.playground_mode:
                 # Gather privileged observations
-                priv_obs_indices = indices.unsqueeze(-1).expand(-1, -1, self.privileged_obs_size)
+                priv_obs_indices = indices.unsqueeze(-1).expand(
+                    -1, -1, self.privileged_obs_size
+                )
                 privileged_observations = torch.gather(
                     self.privileged_observations, 1, priv_obs_indices
                 ).reshape(self.n_env * batch_size, self.privileged_obs_size)
-                
+
                 # Concatenate with regular observations to form full critic observations
-                critic_observations = torch.cat([observations, privileged_observations], dim=1)
+                critic_observations = torch.cat(
+                    [observations, privileged_observations], dim=1
+                )
             else:
                 # Gather full critic observations
                 critic_obs_indices = indices.unsqueeze(-1).expand(
@@ -283,37 +297,44 @@ class SimpleReplayBuffer(nn.Module):
         if self.asymmetric_obs:
             if self.playground_mode:
                 # Gather final privileged observations
-                final_next_privileged_observations = self.next_privileged_observations.gather(
-                    1,
-                    final_next_obs_indices.unsqueeze(-1).expand(
-                        -1, -1, self.privileged_obs_size
-                    ),
+                final_next_privileged_observations = (
+                    self.next_privileged_observations.gather(
+                        1,
+                        final_next_obs_indices.unsqueeze(-1).expand(
+                            -1, -1, self.privileged_obs_size
+                        ),
+                    )
                 )
-                
+
                 # Reshape for output
-                next_privileged_observations = final_next_privileged_observations.reshape(
-                    self.n_env * batch_size, self.privileged_obs_size
+                next_privileged_observations = (
+                    final_next_privileged_observations.reshape(
+                        self.n_env * batch_size, self.privileged_obs_size
+                    )
                 )
-                
+
                 # Concatenate with next observations to form full next critic observations
                 next_observations_reshaped = final_next_observations.reshape(
                     self.n_env * batch_size, self.n_obs
                 )
                 next_critic_observations = torch.cat(
-                    [next_observations_reshaped, next_privileged_observations], dim=1
+                    [next_observations_reshaped, next_privileged_observations],
+                    dim=1,
                 )
             else:
                 # Gather final next critic observations directly
-                final_next_critic_observations = self.next_critic_observations.gather(
-                    1,
-                    final_next_obs_indices.unsqueeze(-1).expand(
-                        -1, -1, self.n_critic_obs
-                    ),
+                final_next_critic_observations = (
+                    self.next_critic_observations.gather(
+                        1,
+                        final_next_obs_indices.unsqueeze(-1).expand(
+                            -1, -1, self.n_critic_obs
+                        ),
+                    )
                 )
                 next_critic_observations = final_next_critic_observations.reshape(
                     self.n_env * batch_size, self.n_critic_obs
                 )
-        
+
         # Reshape everything to batch dimension
         rewards = n_step_rewards.reshape(self.n_env * batch_size)
         dones = final_dones.reshape(self.n_env * batch_size)
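
A minimal, self-contained sketch of the memory-saving idea behind `playground_mode` in `SimpleReplayBuffer`, assuming (as the docstring above states) that a critic observation is the regular observation with privileged entries appended. All names and sizes below are illustrative only, not the repo's API:

import torch

# Hypothetical sizes for illustration only.
n_env, buffer_size, n_obs, n_critic_obs = 4, 8, 10, 16
priv_size = n_critic_obs - n_obs  # only this tail is stored for the critic

obs_buf = torch.zeros(n_env, buffer_size, n_obs)       # kept for the actor anyway
priv_buf = torch.zeros(n_env, buffer_size, priv_size)  # instead of n_critic_obs wide

# Store: split the critic observation and keep only the privileged tail.
ptr = 0
critic_obs = torch.randn(n_env, n_critic_obs)
obs_buf[:, ptr] = critic_obs[:, :n_obs]
priv_buf[:, ptr] = critic_obs[:, n_obs:]

# Sample: gather both parts along the buffer axis, then concatenate to
# rebuild the full critic observation, mirroring the torch.cat calls above.
batch_size = 3
indices = torch.randint(0, buffer_size, (n_env, batch_size))
obs = torch.gather(obs_buf, 1, indices.unsqueeze(-1).expand(-1, -1, n_obs))
priv = torch.gather(priv_buf, 1, indices.unsqueeze(-1).expand(-1, -1, priv_size))
critic = torch.cat([obs, priv], dim=-1).reshape(n_env * batch_size, n_critic_obs)
assert critic.shape == (n_env * batch_size, n_critic_obs)

Since the regular observations are already stored for the actor, keeping only the privileged tail avoids duplicating the n_obs-wide prefix for both the current and the next critic observations.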