Merge pull request #3 from younggyoseo/minor_updates_dev1

Update tuned_reward for T1
This commit is contained in:
Younggyo Seo 2025-05-29 01:30:33 -07:00 committed by GitHub
commit 3f22046fa8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 60 additions and 38 deletions

View File

@ -77,9 +77,17 @@ def make_env(
): ):
# Make training environment # Make training environment
train_env_cfg = registry.get_default_config(env_name) train_env_cfg = registry.get_default_config(env_name)
if use_tuned_reward: is_humanoid_task = env_name in [
"G1JoystickRoughTerrain",
"G1JoystickFlatTerrain",
"T1JoystickRoughTerrain",
"T1JoystickFlatTerrain",
]
if use_tuned_reward and is_humanoid_task:
# NOTE: Tuned reward for G1. Used for producing Figure 7 in the paper. # NOTE: Tuned reward for G1. Used for producing Figure 7 in the paper.
assert env_name in ["G1JoystickRoughTerrain", "G1JoystickFlatTerrain"] # Somehow it works reasonably for T1 as well.
# However, see `sim2real.md` for sim-to-real RL with Booster T1
train_env_cfg.reward_config.scales.energy = -5e-5 train_env_cfg.reward_config.scales.energy = -5e-5
train_env_cfg.reward_config.scales.action_rate = -1e-1 train_env_cfg.reward_config.scales.action_rate = -1e-1
train_env_cfg.reward_config.scales.torques = -1e-3 train_env_cfg.reward_config.scales.torques = -1e-3
@ -90,13 +98,6 @@ def make_env(
train_env_cfg.reward_config.scales.ang_vel_xy = -0.3 train_env_cfg.reward_config.scales.ang_vel_xy = -0.3
train_env_cfg.reward_config.scales.orientation = -5.0 train_env_cfg.reward_config.scales.orientation = -5.0
is_humanoid_task = env_name in [
"G1JoystickRoughTerrain",
"G1JoystickFlatTerrain",
"T1JoystickRoughTerrain",
"T1JoystickFlatTerrain",
]
if is_humanoid_task and not use_push_randomization: if is_humanoid_task and not use_push_randomization:
train_env_cfg.push_config.enable = False train_env_cfg.push_config.enable = False
train_env_cfg.push_config.magnitude_range = [0.0, 0.0] train_env_cfg.push_config.magnitude_range = [0.0, 0.0]

View File

@ -23,7 +23,7 @@ class SimpleReplayBuffer(nn.Module):
""" """
A simple replay buffer that stores transitions in a circular buffer. A simple replay buffer that stores transitions in a circular buffer.
Supports n-step returns and asymmetric observations. Supports n-step returns and asymmetric observations.
When playground_mode=True, critic_observations are treated as a concatenation of When playground_mode=True, critic_observations are treated as a concatenation of
regular observations and privileged observations, and only the privileged part is stored regular observations and privileged observations, and only the privileged part is stored
to save memory. to save memory.
@ -62,10 +62,14 @@ class SimpleReplayBuffer(nn.Module):
# Only store the privileged part of observations (n_critic_obs - n_obs) # Only store the privileged part of observations (n_critic_obs - n_obs)
self.privileged_obs_size = n_critic_obs - n_obs self.privileged_obs_size = n_critic_obs - n_obs
self.privileged_observations = torch.zeros( self.privileged_observations = torch.zeros(
(n_env, buffer_size, self.privileged_obs_size), device=device, dtype=torch.float (n_env, buffer_size, self.privileged_obs_size),
device=device,
dtype=torch.float,
) )
self.next_privileged_observations = torch.zeros( self.next_privileged_observations = torch.zeros(
(n_env, buffer_size, self.privileged_obs_size), device=device, dtype=torch.float (n_env, buffer_size, self.privileged_obs_size),
device=device,
dtype=torch.float,
) )
else: else:
# Store full critic observations # Store full critic observations
@ -98,11 +102,11 @@ class SimpleReplayBuffer(nn.Module):
if self.asymmetric_obs: if self.asymmetric_obs:
critic_observations = tensor_dict["critic_observations"] critic_observations = tensor_dict["critic_observations"]
next_critic_observations = tensor_dict["next"]["critic_observations"] next_critic_observations = tensor_dict["next"]["critic_observations"]
if self.playground_mode: if self.playground_mode:
# Extract and store only the privileged part # Extract and store only the privileged part
privileged_observations = critic_observations[:, self.n_obs:] privileged_observations = critic_observations[:, self.n_obs :]
next_privileged_observations = next_critic_observations[:, self.n_obs:] next_privileged_observations = next_critic_observations[:, self.n_obs :]
self.privileged_observations[:, ptr] = privileged_observations self.privileged_observations[:, ptr] = privileged_observations
self.next_privileged_observations[:, ptr] = next_privileged_observations self.next_privileged_observations[:, ptr] = next_privileged_observations
else: else:
@ -145,17 +149,23 @@ class SimpleReplayBuffer(nn.Module):
if self.asymmetric_obs: if self.asymmetric_obs:
if self.playground_mode: if self.playground_mode:
# Gather privileged observations # Gather privileged observations
priv_obs_indices = indices.unsqueeze(-1).expand(-1, -1, self.privileged_obs_size) priv_obs_indices = indices.unsqueeze(-1).expand(
-1, -1, self.privileged_obs_size
)
privileged_observations = torch.gather( privileged_observations = torch.gather(
self.privileged_observations, 1, priv_obs_indices self.privileged_observations, 1, priv_obs_indices
).reshape(self.n_env * batch_size, self.privileged_obs_size) ).reshape(self.n_env * batch_size, self.privileged_obs_size)
next_privileged_observations = torch.gather( next_privileged_observations = torch.gather(
self.next_privileged_observations, 1, priv_obs_indices self.next_privileged_observations, 1, priv_obs_indices
).reshape(self.n_env * batch_size, self.privileged_obs_size) ).reshape(self.n_env * batch_size, self.privileged_obs_size)
# Concatenate with regular observations to form full critic observations # Concatenate with regular observations to form full critic observations
critic_observations = torch.cat([observations, privileged_observations], dim=1) critic_observations = torch.cat(
next_critic_observations = torch.cat([next_observations, next_privileged_observations], dim=1) [observations, privileged_observations], dim=1
)
next_critic_observations = torch.cat(
[next_observations, next_privileged_observations], dim=1
)
else: else:
# Gather full critic observations # Gather full critic observations
critic_obs_indices = indices.unsqueeze(-1).expand( critic_obs_indices = indices.unsqueeze(-1).expand(
@ -188,13 +198,17 @@ class SimpleReplayBuffer(nn.Module):
if self.asymmetric_obs: if self.asymmetric_obs:
if self.playground_mode: if self.playground_mode:
# Gather privileged observations # Gather privileged observations
priv_obs_indices = indices.unsqueeze(-1).expand(-1, -1, self.privileged_obs_size) priv_obs_indices = indices.unsqueeze(-1).expand(
-1, -1, self.privileged_obs_size
)
privileged_observations = torch.gather( privileged_observations = torch.gather(
self.privileged_observations, 1, priv_obs_indices self.privileged_observations, 1, priv_obs_indices
).reshape(self.n_env * batch_size, self.privileged_obs_size) ).reshape(self.n_env * batch_size, self.privileged_obs_size)
# Concatenate with regular observations to form full critic observations # Concatenate with regular observations to form full critic observations
critic_observations = torch.cat([observations, privileged_observations], dim=1) critic_observations = torch.cat(
[observations, privileged_observations], dim=1
)
else: else:
# Gather full critic observations # Gather full critic observations
critic_obs_indices = indices.unsqueeze(-1).expand( critic_obs_indices = indices.unsqueeze(-1).expand(
@ -283,37 +297,44 @@ class SimpleReplayBuffer(nn.Module):
if self.asymmetric_obs: if self.asymmetric_obs:
if self.playground_mode: if self.playground_mode:
# Gather final privileged observations # Gather final privileged observations
final_next_privileged_observations = self.next_privileged_observations.gather( final_next_privileged_observations = (
1, self.next_privileged_observations.gather(
final_next_obs_indices.unsqueeze(-1).expand( 1,
-1, -1, self.privileged_obs_size final_next_obs_indices.unsqueeze(-1).expand(
), -1, -1, self.privileged_obs_size
),
)
) )
# Reshape for output # Reshape for output
next_privileged_observations = final_next_privileged_observations.reshape( next_privileged_observations = (
self.n_env * batch_size, self.privileged_obs_size final_next_privileged_observations.reshape(
self.n_env * batch_size, self.privileged_obs_size
)
) )
# Concatenate with next observations to form full next critic observations # Concatenate with next observations to form full next critic observations
next_observations_reshaped = final_next_observations.reshape( next_observations_reshaped = final_next_observations.reshape(
self.n_env * batch_size, self.n_obs self.n_env * batch_size, self.n_obs
) )
next_critic_observations = torch.cat( next_critic_observations = torch.cat(
[next_observations_reshaped, next_privileged_observations], dim=1 [next_observations_reshaped, next_privileged_observations],
dim=1,
) )
else: else:
# Gather final next critic observations directly # Gather final next critic observations directly
final_next_critic_observations = self.next_critic_observations.gather( final_next_critic_observations = (
1, self.next_critic_observations.gather(
final_next_obs_indices.unsqueeze(-1).expand( 1,
-1, -1, self.n_critic_obs final_next_obs_indices.unsqueeze(-1).expand(
), -1, -1, self.n_critic_obs
),
)
) )
next_critic_observations = final_next_critic_observations.reshape( next_critic_observations = final_next_critic_observations.reshape(
self.n_env * batch_size, self.n_critic_obs self.n_env * batch_size, self.n_critic_obs
) )
# Reshape everything to batch dimension # Reshape everything to batch dimension
rewards = n_step_rewards.reshape(self.n_env * batch_size) rewards = n_step_rewards.reshape(self.n_env * batch_size)
dones = final_dones.reshape(self.n_env * batch_size) dones = final_dones.reshape(self.n_env * batch_size)