Merge pull request #3 from younggyoseo/minor_updates_dev1
Update tuned_reward for T1
This commit is contained in:
commit
3f22046fa8
@ -77,9 +77,17 @@ def make_env(
|
|||||||
):
|
):
|
||||||
# Make training environment
|
# Make training environment
|
||||||
train_env_cfg = registry.get_default_config(env_name)
|
train_env_cfg = registry.get_default_config(env_name)
|
||||||
if use_tuned_reward:
|
is_humanoid_task = env_name in [
|
||||||
|
"G1JoystickRoughTerrain",
|
||||||
|
"G1JoystickFlatTerrain",
|
||||||
|
"T1JoystickRoughTerrain",
|
||||||
|
"T1JoystickFlatTerrain",
|
||||||
|
]
|
||||||
|
|
||||||
|
if use_tuned_reward and is_humanoid_task:
|
||||||
# NOTE: Tuned reward for G1. Used for producing Figure 7 in the paper.
|
# NOTE: Tuned reward for G1. Used for producing Figure 7 in the paper.
|
||||||
assert env_name in ["G1JoystickRoughTerrain", "G1JoystickFlatTerrain"]
|
# Somehow it works reasonably for T1 as well.
|
||||||
|
# However, see `sim2real.md` for sim-to-real RL with Booster T1
|
||||||
train_env_cfg.reward_config.scales.energy = -5e-5
|
train_env_cfg.reward_config.scales.energy = -5e-5
|
||||||
train_env_cfg.reward_config.scales.action_rate = -1e-1
|
train_env_cfg.reward_config.scales.action_rate = -1e-1
|
||||||
train_env_cfg.reward_config.scales.torques = -1e-3
|
train_env_cfg.reward_config.scales.torques = -1e-3
|
||||||
@ -90,13 +98,6 @@ def make_env(
|
|||||||
train_env_cfg.reward_config.scales.ang_vel_xy = -0.3
|
train_env_cfg.reward_config.scales.ang_vel_xy = -0.3
|
||||||
train_env_cfg.reward_config.scales.orientation = -5.0
|
train_env_cfg.reward_config.scales.orientation = -5.0
|
||||||
|
|
||||||
is_humanoid_task = env_name in [
|
|
||||||
"G1JoystickRoughTerrain",
|
|
||||||
"G1JoystickFlatTerrain",
|
|
||||||
"T1JoystickRoughTerrain",
|
|
||||||
"T1JoystickFlatTerrain",
|
|
||||||
]
|
|
||||||
|
|
||||||
if is_humanoid_task and not use_push_randomization:
|
if is_humanoid_task and not use_push_randomization:
|
||||||
train_env_cfg.push_config.enable = False
|
train_env_cfg.push_config.enable = False
|
||||||
train_env_cfg.push_config.magnitude_range = [0.0, 0.0]
|
train_env_cfg.push_config.magnitude_range = [0.0, 0.0]
|
||||||
|
@ -62,10 +62,14 @@ class SimpleReplayBuffer(nn.Module):
|
|||||||
# Only store the privileged part of observations (n_critic_obs - n_obs)
|
# Only store the privileged part of observations (n_critic_obs - n_obs)
|
||||||
self.privileged_obs_size = n_critic_obs - n_obs
|
self.privileged_obs_size = n_critic_obs - n_obs
|
||||||
self.privileged_observations = torch.zeros(
|
self.privileged_observations = torch.zeros(
|
||||||
(n_env, buffer_size, self.privileged_obs_size), device=device, dtype=torch.float
|
(n_env, buffer_size, self.privileged_obs_size),
|
||||||
|
device=device,
|
||||||
|
dtype=torch.float,
|
||||||
)
|
)
|
||||||
self.next_privileged_observations = torch.zeros(
|
self.next_privileged_observations = torch.zeros(
|
||||||
(n_env, buffer_size, self.privileged_obs_size), device=device, dtype=torch.float
|
(n_env, buffer_size, self.privileged_obs_size),
|
||||||
|
device=device,
|
||||||
|
dtype=torch.float,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Store full critic observations
|
# Store full critic observations
|
||||||
@ -101,8 +105,8 @@ class SimpleReplayBuffer(nn.Module):
|
|||||||
|
|
||||||
if self.playground_mode:
|
if self.playground_mode:
|
||||||
# Extract and store only the privileged part
|
# Extract and store only the privileged part
|
||||||
privileged_observations = critic_observations[:, self.n_obs:]
|
privileged_observations = critic_observations[:, self.n_obs :]
|
||||||
next_privileged_observations = next_critic_observations[:, self.n_obs:]
|
next_privileged_observations = next_critic_observations[:, self.n_obs :]
|
||||||
self.privileged_observations[:, ptr] = privileged_observations
|
self.privileged_observations[:, ptr] = privileged_observations
|
||||||
self.next_privileged_observations[:, ptr] = next_privileged_observations
|
self.next_privileged_observations[:, ptr] = next_privileged_observations
|
||||||
else:
|
else:
|
||||||
@ -145,7 +149,9 @@ class SimpleReplayBuffer(nn.Module):
|
|||||||
if self.asymmetric_obs:
|
if self.asymmetric_obs:
|
||||||
if self.playground_mode:
|
if self.playground_mode:
|
||||||
# Gather privileged observations
|
# Gather privileged observations
|
||||||
priv_obs_indices = indices.unsqueeze(-1).expand(-1, -1, self.privileged_obs_size)
|
priv_obs_indices = indices.unsqueeze(-1).expand(
|
||||||
|
-1, -1, self.privileged_obs_size
|
||||||
|
)
|
||||||
privileged_observations = torch.gather(
|
privileged_observations = torch.gather(
|
||||||
self.privileged_observations, 1, priv_obs_indices
|
self.privileged_observations, 1, priv_obs_indices
|
||||||
).reshape(self.n_env * batch_size, self.privileged_obs_size)
|
).reshape(self.n_env * batch_size, self.privileged_obs_size)
|
||||||
@ -154,8 +160,12 @@ class SimpleReplayBuffer(nn.Module):
|
|||||||
).reshape(self.n_env * batch_size, self.privileged_obs_size)
|
).reshape(self.n_env * batch_size, self.privileged_obs_size)
|
||||||
|
|
||||||
# Concatenate with regular observations to form full critic observations
|
# Concatenate with regular observations to form full critic observations
|
||||||
critic_observations = torch.cat([observations, privileged_observations], dim=1)
|
critic_observations = torch.cat(
|
||||||
next_critic_observations = torch.cat([next_observations, next_privileged_observations], dim=1)
|
[observations, privileged_observations], dim=1
|
||||||
|
)
|
||||||
|
next_critic_observations = torch.cat(
|
||||||
|
[next_observations, next_privileged_observations], dim=1
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# Gather full critic observations
|
# Gather full critic observations
|
||||||
critic_obs_indices = indices.unsqueeze(-1).expand(
|
critic_obs_indices = indices.unsqueeze(-1).expand(
|
||||||
@ -188,13 +198,17 @@ class SimpleReplayBuffer(nn.Module):
|
|||||||
if self.asymmetric_obs:
|
if self.asymmetric_obs:
|
||||||
if self.playground_mode:
|
if self.playground_mode:
|
||||||
# Gather privileged observations
|
# Gather privileged observations
|
||||||
priv_obs_indices = indices.unsqueeze(-1).expand(-1, -1, self.privileged_obs_size)
|
priv_obs_indices = indices.unsqueeze(-1).expand(
|
||||||
|
-1, -1, self.privileged_obs_size
|
||||||
|
)
|
||||||
privileged_observations = torch.gather(
|
privileged_observations = torch.gather(
|
||||||
self.privileged_observations, 1, priv_obs_indices
|
self.privileged_observations, 1, priv_obs_indices
|
||||||
).reshape(self.n_env * batch_size, self.privileged_obs_size)
|
).reshape(self.n_env * batch_size, self.privileged_obs_size)
|
||||||
|
|
||||||
# Concatenate with regular observations to form full critic observations
|
# Concatenate with regular observations to form full critic observations
|
||||||
critic_observations = torch.cat([observations, privileged_observations], dim=1)
|
critic_observations = torch.cat(
|
||||||
|
[observations, privileged_observations], dim=1
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# Gather full critic observations
|
# Gather full critic observations
|
||||||
critic_obs_indices = indices.unsqueeze(-1).expand(
|
critic_obs_indices = indices.unsqueeze(-1).expand(
|
||||||
@ -283,33 +297,40 @@ class SimpleReplayBuffer(nn.Module):
|
|||||||
if self.asymmetric_obs:
|
if self.asymmetric_obs:
|
||||||
if self.playground_mode:
|
if self.playground_mode:
|
||||||
# Gather final privileged observations
|
# Gather final privileged observations
|
||||||
final_next_privileged_observations = self.next_privileged_observations.gather(
|
final_next_privileged_observations = (
|
||||||
|
self.next_privileged_observations.gather(
|
||||||
1,
|
1,
|
||||||
final_next_obs_indices.unsqueeze(-1).expand(
|
final_next_obs_indices.unsqueeze(-1).expand(
|
||||||
-1, -1, self.privileged_obs_size
|
-1, -1, self.privileged_obs_size
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# Reshape for output
|
# Reshape for output
|
||||||
next_privileged_observations = final_next_privileged_observations.reshape(
|
next_privileged_observations = (
|
||||||
|
final_next_privileged_observations.reshape(
|
||||||
self.n_env * batch_size, self.privileged_obs_size
|
self.n_env * batch_size, self.privileged_obs_size
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# Concatenate with next observations to form full next critic observations
|
# Concatenate with next observations to form full next critic observations
|
||||||
next_observations_reshaped = final_next_observations.reshape(
|
next_observations_reshaped = final_next_observations.reshape(
|
||||||
self.n_env * batch_size, self.n_obs
|
self.n_env * batch_size, self.n_obs
|
||||||
)
|
)
|
||||||
next_critic_observations = torch.cat(
|
next_critic_observations = torch.cat(
|
||||||
[next_observations_reshaped, next_privileged_observations], dim=1
|
[next_observations_reshaped, next_privileged_observations],
|
||||||
|
dim=1,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Gather final next critic observations directly
|
# Gather final next critic observations directly
|
||||||
final_next_critic_observations = self.next_critic_observations.gather(
|
final_next_critic_observations = (
|
||||||
|
self.next_critic_observations.gather(
|
||||||
1,
|
1,
|
||||||
final_next_obs_indices.unsqueeze(-1).expand(
|
final_next_obs_indices.unsqueeze(-1).expand(
|
||||||
-1, -1, self.n_critic_obs
|
-1, -1, self.n_critic_obs
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
)
|
||||||
next_critic_observations = final_next_critic_observations.reshape(
|
next_critic_observations = final_next_critic_observations.reshape(
|
||||||
self.n_env * batch_size, self.n_critic_obs
|
self.n_env * batch_size, self.n_critic_obs
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user