memory optimization for playground
This commit is contained in:
parent
5db18c2de2
commit
5725eba3b8
@ -98,7 +98,8 @@ python fast_td3/train.py --env_name h1hand-hurdle-v0 --exp_name FastTD3 --render
|
|||||||
### MuJoCo Playground Experiments
|
### MuJoCo Playground Experiments
|
||||||
```bash
|
```bash
|
||||||
conda activate fasttd3_playground
|
conda activate fasttd3_playground
|
||||||
python fast_td3/train.py --env_name G1JoystickRoughTerrain --exp_name FastTD3 --render_interval 5000 --seed 1
|
python fast_td3/train.py --env_name T1JoystickFlatTerrain --exp_name FastTD3 --render_interval 5000 --seed 1
|
||||||
|
python fast_td3/train.py --env_name G1JoystickFlatTerrain --exp_name FastTD3 --render_interval 5000 --seed 1
|
||||||
```
|
```
|
||||||
|
|
||||||
### IsaacLab Experiments
|
### IsaacLab Experiments
|
||||||
|
@ -15,6 +15,7 @@ class SimpleReplayBuffer(nn.Module):
|
|||||||
n_act: int,
|
n_act: int,
|
||||||
n_critic_obs: int,
|
n_critic_obs: int,
|
||||||
asymmetric_obs: bool = False,
|
asymmetric_obs: bool = False,
|
||||||
|
playground_mode: bool = False,
|
||||||
n_steps: int = 1,
|
n_steps: int = 1,
|
||||||
gamma: float = 0.99,
|
gamma: float = 0.99,
|
||||||
device=None,
|
device=None,
|
||||||
@ -22,6 +23,10 @@ class SimpleReplayBuffer(nn.Module):
|
|||||||
"""
|
"""
|
||||||
A simple replay buffer that stores transitions in a circular buffer.
|
A simple replay buffer that stores transitions in a circular buffer.
|
||||||
Supports n-step returns and asymmetric observations.
|
Supports n-step returns and asymmetric observations.
|
||||||
|
|
||||||
|
When playground_mode=True, critic_observations are treated as a concatenation of
|
||||||
|
regular observations and privileged observations, and only the privileged part is stored
|
||||||
|
to save memory.
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@ -31,6 +36,7 @@ class SimpleReplayBuffer(nn.Module):
|
|||||||
self.n_act = n_act
|
self.n_act = n_act
|
||||||
self.n_critic_obs = n_critic_obs
|
self.n_critic_obs = n_critic_obs
|
||||||
self.asymmetric_obs = asymmetric_obs
|
self.asymmetric_obs = asymmetric_obs
|
||||||
|
self.playground_mode = playground_mode and asymmetric_obs
|
||||||
self.gamma = gamma
|
self.gamma = gamma
|
||||||
self.n_steps = n_steps
|
self.n_steps = n_steps
|
||||||
self.device = device
|
self.device = device
|
||||||
@ -52,12 +58,23 @@ class SimpleReplayBuffer(nn.Module):
|
|||||||
(n_env, buffer_size, n_obs), device=device, dtype=torch.float
|
(n_env, buffer_size, n_obs), device=device, dtype=torch.float
|
||||||
)
|
)
|
||||||
if asymmetric_obs:
|
if asymmetric_obs:
|
||||||
self.critic_observations = torch.zeros(
|
if self.playground_mode:
|
||||||
(n_env, buffer_size, n_critic_obs), device=device, dtype=torch.float
|
# Only store the privileged part of observations (n_critic_obs - n_obs)
|
||||||
)
|
self.privileged_obs_size = n_critic_obs - n_obs
|
||||||
self.next_critic_observations = torch.zeros(
|
self.privileged_observations = torch.zeros(
|
||||||
(n_env, buffer_size, n_critic_obs), device=device, dtype=torch.float
|
(n_env, buffer_size, self.privileged_obs_size), device=device, dtype=torch.float
|
||||||
)
|
)
|
||||||
|
self.next_privileged_observations = torch.zeros(
|
||||||
|
(n_env, buffer_size, self.privileged_obs_size), device=device, dtype=torch.float
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Store full critic observations
|
||||||
|
self.critic_observations = torch.zeros(
|
||||||
|
(n_env, buffer_size, n_critic_obs), device=device, dtype=torch.float
|
||||||
|
)
|
||||||
|
self.next_critic_observations = torch.zeros(
|
||||||
|
(n_env, buffer_size, n_critic_obs), device=device, dtype=torch.float
|
||||||
|
)
|
||||||
self.ptr = 0
|
self.ptr = 0
|
||||||
|
|
||||||
def extend(
|
def extend(
|
||||||
@ -80,9 +97,18 @@ class SimpleReplayBuffer(nn.Module):
|
|||||||
self.next_observations[:, ptr] = next_observations
|
self.next_observations[:, ptr] = next_observations
|
||||||
if self.asymmetric_obs:
|
if self.asymmetric_obs:
|
||||||
critic_observations = tensor_dict["critic_observations"]
|
critic_observations = tensor_dict["critic_observations"]
|
||||||
self.critic_observations[:, ptr] = critic_observations
|
|
||||||
next_critic_observations = tensor_dict["next"]["critic_observations"]
|
next_critic_observations = tensor_dict["next"]["critic_observations"]
|
||||||
self.next_critic_observations[:, ptr] = next_critic_observations
|
|
||||||
|
if self.playground_mode:
|
||||||
|
# Extract and store only the privileged part
|
||||||
|
privileged_observations = critic_observations[:, self.n_obs:]
|
||||||
|
next_privileged_observations = next_critic_observations[:, self.n_obs:]
|
||||||
|
self.privileged_observations[:, ptr] = privileged_observations
|
||||||
|
self.next_privileged_observations[:, ptr] = next_privileged_observations
|
||||||
|
else:
|
||||||
|
# Store full critic observations
|
||||||
|
self.critic_observations[:, ptr] = critic_observations
|
||||||
|
self.next_critic_observations[:, ptr] = next_critic_observations
|
||||||
self.ptr += 1
|
self.ptr += 1
|
||||||
|
|
||||||
def sample(self, batch_size: int):
|
def sample(self, batch_size: int):
|
||||||
@ -117,15 +143,30 @@ class SimpleReplayBuffer(nn.Module):
|
|||||||
self.n_env * batch_size
|
self.n_env * batch_size
|
||||||
)
|
)
|
||||||
if self.asymmetric_obs:
|
if self.asymmetric_obs:
|
||||||
critic_obs_indices = indices.unsqueeze(-1).expand(
|
if self.playground_mode:
|
||||||
-1, -1, self.n_critic_obs
|
# Gather privileged observations
|
||||||
)
|
priv_obs_indices = indices.unsqueeze(-1).expand(-1, -1, self.privileged_obs_size)
|
||||||
critic_observations = torch.gather(
|
privileged_observations = torch.gather(
|
||||||
self.critic_observations, 1, critic_obs_indices
|
self.privileged_observations, 1, priv_obs_indices
|
||||||
).reshape(self.n_env * batch_size, self.n_critic_obs)
|
).reshape(self.n_env * batch_size, self.privileged_obs_size)
|
||||||
next_critic_observations = torch.gather(
|
next_privileged_observations = torch.gather(
|
||||||
self.next_critic_observations, 1, critic_obs_indices
|
self.next_privileged_observations, 1, priv_obs_indices
|
||||||
).reshape(self.n_env * batch_size, self.n_critic_obs)
|
).reshape(self.n_env * batch_size, self.privileged_obs_size)
|
||||||
|
|
||||||
|
# Concatenate with regular observations to form full critic observations
|
||||||
|
critic_observations = torch.cat([observations, privileged_observations], dim=1)
|
||||||
|
next_critic_observations = torch.cat([next_observations, next_privileged_observations], dim=1)
|
||||||
|
else:
|
||||||
|
# Gather full critic observations
|
||||||
|
critic_obs_indices = indices.unsqueeze(-1).expand(
|
||||||
|
-1, -1, self.n_critic_obs
|
||||||
|
)
|
||||||
|
critic_observations = torch.gather(
|
||||||
|
self.critic_observations, 1, critic_obs_indices
|
||||||
|
).reshape(self.n_env * batch_size, self.n_critic_obs)
|
||||||
|
next_critic_observations = torch.gather(
|
||||||
|
self.next_critic_observations, 1, critic_obs_indices
|
||||||
|
).reshape(self.n_env * batch_size, self.n_critic_obs)
|
||||||
else:
|
else:
|
||||||
# Sample base indices
|
# Sample base indices
|
||||||
indices = torch.randint(
|
indices = torch.randint(
|
||||||
@ -145,12 +186,23 @@ class SimpleReplayBuffer(nn.Module):
|
|||||||
self.n_env * batch_size, self.n_act
|
self.n_env * batch_size, self.n_act
|
||||||
)
|
)
|
||||||
if self.asymmetric_obs:
|
if self.asymmetric_obs:
|
||||||
critic_obs_indices = indices.unsqueeze(-1).expand(
|
if self.playground_mode:
|
||||||
-1, -1, self.n_critic_obs
|
# Gather privileged observations
|
||||||
)
|
priv_obs_indices = indices.unsqueeze(-1).expand(-1, -1, self.privileged_obs_size)
|
||||||
critic_observations = torch.gather(
|
privileged_observations = torch.gather(
|
||||||
self.critic_observations, 1, critic_obs_indices
|
self.privileged_observations, 1, priv_obs_indices
|
||||||
).reshape(self.n_env * batch_size, self.n_critic_obs)
|
).reshape(self.n_env * batch_size, self.privileged_obs_size)
|
||||||
|
|
||||||
|
# Concatenate with regular observations to form full critic observations
|
||||||
|
critic_observations = torch.cat([observations, privileged_observations], dim=1)
|
||||||
|
else:
|
||||||
|
# Gather full critic observations
|
||||||
|
critic_obs_indices = indices.unsqueeze(-1).expand(
|
||||||
|
-1, -1, self.n_critic_obs
|
||||||
|
)
|
||||||
|
critic_observations = torch.gather(
|
||||||
|
self.critic_observations, 1, critic_obs_indices
|
||||||
|
).reshape(self.n_env * batch_size, self.n_critic_obs)
|
||||||
|
|
||||||
# Create sequential indices for each sample
|
# Create sequential indices for each sample
|
||||||
# This creates a [n_env, batch_size, n_step] tensor of indices
|
# This creates a [n_env, batch_size, n_step] tensor of indices
|
||||||
@ -229,14 +281,40 @@ class SimpleReplayBuffer(nn.Module):
|
|||||||
final_truncations = self.truncations.gather(1, final_next_obs_indices)
|
final_truncations = self.truncations.gather(1, final_next_obs_indices)
|
||||||
|
|
||||||
if self.asymmetric_obs:
|
if self.asymmetric_obs:
|
||||||
final_next_critic_observations = self.next_critic_observations.gather(
|
if self.playground_mode:
|
||||||
1,
|
# Gather final privileged observations
|
||||||
final_next_obs_indices.unsqueeze(-1).expand(
|
final_next_privileged_observations = self.next_privileged_observations.gather(
|
||||||
-1, -1, self.n_critic_obs
|
1,
|
||||||
),
|
final_next_obs_indices.unsqueeze(-1).expand(
|
||||||
)
|
-1, -1, self.privileged_obs_size
|
||||||
# Reshape everything to batch dimension
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Reshape for output
|
||||||
|
next_privileged_observations = final_next_privileged_observations.reshape(
|
||||||
|
self.n_env * batch_size, self.privileged_obs_size
|
||||||
|
)
|
||||||
|
|
||||||
|
# Concatenate with next observations to form full next critic observations
|
||||||
|
next_observations_reshaped = final_next_observations.reshape(
|
||||||
|
self.n_env * batch_size, self.n_obs
|
||||||
|
)
|
||||||
|
next_critic_observations = torch.cat(
|
||||||
|
[next_observations_reshaped, next_privileged_observations], dim=1
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Gather final next critic observations directly
|
||||||
|
final_next_critic_observations = self.next_critic_observations.gather(
|
||||||
|
1,
|
||||||
|
final_next_obs_indices.unsqueeze(-1).expand(
|
||||||
|
-1, -1, self.n_critic_obs
|
||||||
|
),
|
||||||
|
)
|
||||||
|
next_critic_observations = final_next_critic_observations.reshape(
|
||||||
|
self.n_env * batch_size, self.n_critic_obs
|
||||||
|
)
|
||||||
|
|
||||||
|
# Reshape everything to batch dimension
|
||||||
rewards = n_step_rewards.reshape(self.n_env * batch_size)
|
rewards = n_step_rewards.reshape(self.n_env * batch_size)
|
||||||
dones = final_dones.reshape(self.n_env * batch_size)
|
dones = final_dones.reshape(self.n_env * batch_size)
|
||||||
truncations = final_truncations.reshape(self.n_env * batch_size)
|
truncations = final_truncations.reshape(self.n_env * batch_size)
|
||||||
@ -244,11 +322,6 @@ class SimpleReplayBuffer(nn.Module):
|
|||||||
self.n_env * batch_size, self.n_obs
|
self.n_env * batch_size, self.n_obs
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.asymmetric_obs:
|
|
||||||
next_critic_observations = final_next_critic_observations.reshape(
|
|
||||||
self.n_env * batch_size, self.n_critic_obs
|
|
||||||
)
|
|
||||||
|
|
||||||
out = TensorDict(
|
out = TensorDict(
|
||||||
{
|
{
|
||||||
"observations": observations,
|
"observations": observations,
|
||||||
|
@ -193,6 +193,7 @@ def main():
|
|||||||
n_act=n_act,
|
n_act=n_act,
|
||||||
n_critic_obs=n_critic_obs,
|
n_critic_obs=n_critic_obs,
|
||||||
asymmetric_obs=envs.asymmetric_obs,
|
asymmetric_obs=envs.asymmetric_obs,
|
||||||
|
playground_mode=env_type == "mujoco_playground",
|
||||||
n_steps=args.num_steps,
|
n_steps=args.num_steps,
|
||||||
gamma=args.gamma,
|
gamma=args.gamma,
|
||||||
device=device,
|
device=device,
|
||||||
|
@ -272,6 +272,7 @@
|
|||||||
" n_act=n_act,\n",
|
" n_act=n_act,\n",
|
||||||
" n_critic_obs=n_critic_obs,\n",
|
" n_critic_obs=n_critic_obs,\n",
|
||||||
" asymmetric_obs=envs.asymmetric_obs,\n",
|
" asymmetric_obs=envs.asymmetric_obs,\n",
|
||||||
|
" playground_mode=env_type == \"mujoco_playground\",\n",
|
||||||
" n_steps=args.num_steps,\n",
|
" n_steps=args.num_steps,\n",
|
||||||
" gamma=args.gamma,\n",
|
" gamma=args.gamma,\n",
|
||||||
" device=device,\n",
|
" device=device,\n",
|
||||||
|
Loading…
Reference in New Issue
Block a user