Merge pull request #3 from younggyoseo/minor_updates_dev1

Update tuned_reward for T1
This commit is contained in:
Younggyo Seo 2025-05-29 01:30:33 -07:00 committed by GitHub
commit 3f22046fa8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 60 additions and 38 deletions

View File

@ -77,9 +77,17 @@ def make_env(
): ):
# Make training environment # Make training environment
train_env_cfg = registry.get_default_config(env_name) train_env_cfg = registry.get_default_config(env_name)
if use_tuned_reward: is_humanoid_task = env_name in [
"G1JoystickRoughTerrain",
"G1JoystickFlatTerrain",
"T1JoystickRoughTerrain",
"T1JoystickFlatTerrain",
]
if use_tuned_reward and is_humanoid_task:
# NOTE: Tuned reward for G1. Used for producing Figure 7 in the paper. # NOTE: Tuned reward for G1. Used for producing Figure 7 in the paper.
assert env_name in ["G1JoystickRoughTerrain", "G1JoystickFlatTerrain"] # Somehow it works reasonably for T1 as well.
# However, see `sim2real.md` for sim-to-real RL with Booster T1
train_env_cfg.reward_config.scales.energy = -5e-5 train_env_cfg.reward_config.scales.energy = -5e-5
train_env_cfg.reward_config.scales.action_rate = -1e-1 train_env_cfg.reward_config.scales.action_rate = -1e-1
train_env_cfg.reward_config.scales.torques = -1e-3 train_env_cfg.reward_config.scales.torques = -1e-3
@ -90,13 +98,6 @@ def make_env(
train_env_cfg.reward_config.scales.ang_vel_xy = -0.3 train_env_cfg.reward_config.scales.ang_vel_xy = -0.3
train_env_cfg.reward_config.scales.orientation = -5.0 train_env_cfg.reward_config.scales.orientation = -5.0
is_humanoid_task = env_name in [
"G1JoystickRoughTerrain",
"G1JoystickFlatTerrain",
"T1JoystickRoughTerrain",
"T1JoystickFlatTerrain",
]
if is_humanoid_task and not use_push_randomization: if is_humanoid_task and not use_push_randomization:
train_env_cfg.push_config.enable = False train_env_cfg.push_config.enable = False
train_env_cfg.push_config.magnitude_range = [0.0, 0.0] train_env_cfg.push_config.magnitude_range = [0.0, 0.0]

View File

@ -62,10 +62,14 @@ class SimpleReplayBuffer(nn.Module):
# Only store the privileged part of observations (n_critic_obs - n_obs) # Only store the privileged part of observations (n_critic_obs - n_obs)
self.privileged_obs_size = n_critic_obs - n_obs self.privileged_obs_size = n_critic_obs - n_obs
self.privileged_observations = torch.zeros( self.privileged_observations = torch.zeros(
(n_env, buffer_size, self.privileged_obs_size), device=device, dtype=torch.float (n_env, buffer_size, self.privileged_obs_size),
device=device,
dtype=torch.float,
) )
self.next_privileged_observations = torch.zeros( self.next_privileged_observations = torch.zeros(
(n_env, buffer_size, self.privileged_obs_size), device=device, dtype=torch.float (n_env, buffer_size, self.privileged_obs_size),
device=device,
dtype=torch.float,
) )
else: else:
# Store full critic observations # Store full critic observations
@ -101,8 +105,8 @@ class SimpleReplayBuffer(nn.Module):
if self.playground_mode: if self.playground_mode:
# Extract and store only the privileged part # Extract and store only the privileged part
privileged_observations = critic_observations[:, self.n_obs:] privileged_observations = critic_observations[:, self.n_obs :]
next_privileged_observations = next_critic_observations[:, self.n_obs:] next_privileged_observations = next_critic_observations[:, self.n_obs :]
self.privileged_observations[:, ptr] = privileged_observations self.privileged_observations[:, ptr] = privileged_observations
self.next_privileged_observations[:, ptr] = next_privileged_observations self.next_privileged_observations[:, ptr] = next_privileged_observations
else: else:
@ -145,7 +149,9 @@ class SimpleReplayBuffer(nn.Module):
if self.asymmetric_obs: if self.asymmetric_obs:
if self.playground_mode: if self.playground_mode:
# Gather privileged observations # Gather privileged observations
priv_obs_indices = indices.unsqueeze(-1).expand(-1, -1, self.privileged_obs_size) priv_obs_indices = indices.unsqueeze(-1).expand(
-1, -1, self.privileged_obs_size
)
privileged_observations = torch.gather( privileged_observations = torch.gather(
self.privileged_observations, 1, priv_obs_indices self.privileged_observations, 1, priv_obs_indices
).reshape(self.n_env * batch_size, self.privileged_obs_size) ).reshape(self.n_env * batch_size, self.privileged_obs_size)
@ -154,8 +160,12 @@ class SimpleReplayBuffer(nn.Module):
).reshape(self.n_env * batch_size, self.privileged_obs_size) ).reshape(self.n_env * batch_size, self.privileged_obs_size)
# Concatenate with regular observations to form full critic observations # Concatenate with regular observations to form full critic observations
critic_observations = torch.cat([observations, privileged_observations], dim=1) critic_observations = torch.cat(
next_critic_observations = torch.cat([next_observations, next_privileged_observations], dim=1) [observations, privileged_observations], dim=1
)
next_critic_observations = torch.cat(
[next_observations, next_privileged_observations], dim=1
)
else: else:
# Gather full critic observations # Gather full critic observations
critic_obs_indices = indices.unsqueeze(-1).expand( critic_obs_indices = indices.unsqueeze(-1).expand(
@ -188,13 +198,17 @@ class SimpleReplayBuffer(nn.Module):
if self.asymmetric_obs: if self.asymmetric_obs:
if self.playground_mode: if self.playground_mode:
# Gather privileged observations # Gather privileged observations
priv_obs_indices = indices.unsqueeze(-1).expand(-1, -1, self.privileged_obs_size) priv_obs_indices = indices.unsqueeze(-1).expand(
-1, -1, self.privileged_obs_size
)
privileged_observations = torch.gather( privileged_observations = torch.gather(
self.privileged_observations, 1, priv_obs_indices self.privileged_observations, 1, priv_obs_indices
).reshape(self.n_env * batch_size, self.privileged_obs_size) ).reshape(self.n_env * batch_size, self.privileged_obs_size)
# Concatenate with regular observations to form full critic observations # Concatenate with regular observations to form full critic observations
critic_observations = torch.cat([observations, privileged_observations], dim=1) critic_observations = torch.cat(
[observations, privileged_observations], dim=1
)
else: else:
# Gather full critic observations # Gather full critic observations
critic_obs_indices = indices.unsqueeze(-1).expand( critic_obs_indices = indices.unsqueeze(-1).expand(
@ -283,33 +297,40 @@ class SimpleReplayBuffer(nn.Module):
if self.asymmetric_obs: if self.asymmetric_obs:
if self.playground_mode: if self.playground_mode:
# Gather final privileged observations # Gather final privileged observations
final_next_privileged_observations = self.next_privileged_observations.gather( final_next_privileged_observations = (
self.next_privileged_observations.gather(
1, 1,
final_next_obs_indices.unsqueeze(-1).expand( final_next_obs_indices.unsqueeze(-1).expand(
-1, -1, self.privileged_obs_size -1, -1, self.privileged_obs_size
), ),
) )
)
# Reshape for output # Reshape for output
next_privileged_observations = final_next_privileged_observations.reshape( next_privileged_observations = (
final_next_privileged_observations.reshape(
self.n_env * batch_size, self.privileged_obs_size self.n_env * batch_size, self.privileged_obs_size
) )
)
# Concatenate with next observations to form full next critic observations # Concatenate with next observations to form full next critic observations
next_observations_reshaped = final_next_observations.reshape( next_observations_reshaped = final_next_observations.reshape(
self.n_env * batch_size, self.n_obs self.n_env * batch_size, self.n_obs
) )
next_critic_observations = torch.cat( next_critic_observations = torch.cat(
[next_observations_reshaped, next_privileged_observations], dim=1 [next_observations_reshaped, next_privileged_observations],
dim=1,
) )
else: else:
# Gather final next critic observations directly # Gather final next critic observations directly
final_next_critic_observations = self.next_critic_observations.gather( final_next_critic_observations = (
self.next_critic_observations.gather(
1, 1,
final_next_obs_indices.unsqueeze(-1).expand( final_next_obs_indices.unsqueeze(-1).expand(
-1, -1, self.n_critic_obs -1, -1, self.n_critic_obs
), ),
) )
)
next_critic_observations = final_next_critic_observations.reshape( next_critic_observations = final_next_critic_observations.reshape(
self.n_env * batch_size, self.n_critic_obs self.n_env * batch_size, self.n_critic_obs
) )