allow history observation

allenzren 2024-09-11 21:44:47 -04:00
parent 2ddf63b8f5
commit f13eb203e1
26 changed files with 64 additions and 23 deletions

View File

@ -162,7 +162,10 @@ To use DDIM fine-tuning, set `denoising_steps=100` in pre-training and set `mode
Pre-training script is at [`agent/pretrain/train_diffusion_agent.py`](agent/pretrain/train_diffusion_agent.py). The pre-training dataset [loader](agent/dataset/sequence.py) assumes a npz file containing numpy arrays `states`, `actions`, `images` (if using pixel) and `traj_length`, where `states` and `actions` have the shape of num_total_steps x obs_dim/act_dim, `images` num_total_steps x C (concatenated if multiple images) x H x W, and `traj_length` is a 1-D array for indexing across num_total_steps. Pre-training script is at [`agent/pretrain/train_diffusion_agent.py`](agent/pretrain/train_diffusion_agent.py). The pre-training dataset [loader](agent/dataset/sequence.py) assumes a npz file containing numpy arrays `states`, `actions`, `images` (if using pixel) and `traj_length`, where `states` and `actions` have the shape of num_total_steps x obs_dim/act_dim, `images` num_total_steps x C (concatenated if multiple images) x H x W, and `traj_length` is a 1-D array for indexing across num_total_steps.
<!-- One pre-processing example can be found at [`script/process_robomimic_dataset.py`](script/process_robomimic_dataset.py). --> <!-- One pre-processing example can be found at [`script/process_robomimic_dataset.py`](script/process_robomimic_dataset.py). -->
**Note:** The current implementation does not support loading history observations (only using observation at the current timestep). If needed, you can modify [here](agent/dataset/sequence.py#L130-L131). <!-- **Note:** The current implementation does not support loading history observations (only using observation at the current timestep). If needed, you can modify [here](agent/dataset/sequence.py#L130-L131). -->
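As a rough illustration of the dataset layout described above, the sketch below writes and reads a toy npz file with the `states`, `actions`, and `traj_length` arrays. It assumes `traj_length` stores the number of steps in each trajectory, so that cumulative sums give the boundaries within `num_total_steps`; the dimensions and the in-memory buffer are made up for the example and are not taken from the repository.

```python
import io
import numpy as np

# Two toy trajectories of lengths 3 and 2, so num_total_steps = 5.
# Assumption: traj_length holds per-trajectory step counts.
traj_length = np.array([3, 2])
num_total_steps = int(traj_length.sum())
obs_dim, act_dim = 4, 2  # hypothetical dimensions

states = np.random.randn(num_total_steps, obs_dim).astype(np.float32)
actions = np.random.randn(num_total_steps, act_dim).astype(np.float32)

# Write to an in-memory buffer here; a real dataset would use a file path.
buffer = io.BytesIO()
np.savez_compressed(
    buffer, states=states, actions=actions, traj_length=traj_length
)
buffer.seek(0)

data = np.load(buffer)
# Recover per-trajectory slices from the cumulative lengths.
ends = np.cumsum(data["traj_length"])
starts = ends - data["traj_length"]
trajectories = [data["states"][s:e] for s, e in zip(starts, ends)]
print([t.shape for t in trajectories])
```

For pixel observations, an `images` array of shape num_total_steps x C x H x W would be stored under the same indexing.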
#### Observation history
Our experiments did not use observations from previous timesteps (state or pixel), but observation history is implemented. To enable it, set `cond_steps=<num_state_obs_step>` (and `img_cond_steps=<num_img_obs_step>`, which must be no larger than `cond_steps`) in pre-training, and use the same values when fine-tuning the newly pre-trained policy.
### Fine-tuning environment ### Fine-tuning environment
We follow the Gym format for interacting with the environments. The vectorized environments are initialized at [make_async](env/gym_utils/__init__.py#L10) (called in the parent fine-tuning agent class [here](agent/finetune/train_agent.py#L38-L39)). The current implementation is not the cleanest as we tried to make it compatible with Gym, Robomimic, Furniture-Bench, and D3IL environments, but it should be easy to modify and allow using other environments. We use [multi_step](env/gym_utils/wrapper/multi_step.py) wrapper for history observations (not used currently) and multi-environment-step action execution. We also use environment-specific wrappers such as [robomimic_lowdim](env/gym_utils/wrapper/robomimic_lowdim.py) and [furniture](env/gym_utils/wrapper/furniture.py) for observation/action normalization, etc. You can implement a new environment wrapper if needed. We follow the Gym format for interacting with the environments. The vectorized environments are initialized at [make_async](env/gym_utils/__init__.py#L10) (called in the parent fine-tuning agent class [here](agent/finetune/train_agent.py#L38-L39)). The current implementation is not the cleanest as we tried to make it compatible with Gym, Robomimic, Furniture-Bench, and D3IL environments, but it should be easy to modify and allow using other environments. We use [multi_step](env/gym_utils/wrapper/multi_step.py) wrapper for history observations (not used currently) and multi-environment-step action execution. We also use environment-specific wrappers such as [robomimic_lowdim](env/gym_utils/wrapper/robomimic_lowdim.py) and [furniture](env/gym_utils/wrapper/furniture.py) for observation/action normalization, etc. You can implement a new environment wrapper if needed.
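To make the idea of history observations concrete, the sketch below stacks the last `n_steps` observations along a new step axis and, when fewer have been collected (for example right after a reset), pads by repeating the earliest available one. It uses NumPy and a hypothetical helper `stack_last_n`; it is a simplified illustration under those assumptions, not the repository's wrapper code.

```python
from collections import deque

import numpy as np


def stack_last_n(history, n_steps):
    """Stack the last n_steps observations along a new axis 1.

    history: sequence of arrays shaped (n_envs, obs_dim), oldest first.
    Returns an array shaped (n_envs, n_steps, obs_dim).
    """
    recent = list(history)[-n_steps:]
    # Pad at the front by repeating the earliest available observation.
    pad = [recent[0]] * (n_steps - len(recent))
    return np.stack(pad + recent, axis=1)


history = deque(maxlen=3)
history.append(np.zeros((2, 4)))  # observation right after reset
first = stack_last_n(history, n_steps=3)

history.append(np.ones((2, 4)))  # observation after one environment step
second = stack_last_n(history, n_steps=3)
print(first.shape, second.shape)
```

With `n_steps=1` the result reduces to the current observation with a singleton step axis, which matches the single-step setting used in the experiments.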

View File

@ -1,8 +1,6 @@
""" """
Pre-training data loader. Modified from https://github.com/jannerm/diffuser/blob/main/diffuser/datasets/sequence.py Pre-training data loader. Modified from https://github.com/jannerm/diffuser/blob/main/diffuser/datasets/sequence.py
TODO: implement history observation
No normalization is applied here --- we always normalize the data when pre-processing it with a different script, and the normalization info is also used in RL fine-tuning. No normalization is applied here --- we always normalize the data when pre-processing it with a different script, and the normalization info is also used in RL fine-tuning.
""" """

View File

@ -34,6 +34,7 @@ env:
furniture: lamp furniture: lamp
randomness: low randomness: low
normalization_path: ${normalization_path} normalization_path: ${normalization_path}
obs_steps: ${cond_steps}
act_steps: ${act_steps} act_steps: ${act_steps}
sparse_reward: True sparse_reward: True

View File

@ -34,6 +34,7 @@ env:
furniture: lamp furniture: lamp
randomness: low randomness: low
normalization_path: ${normalization_path} normalization_path: ${normalization_path}
obs_steps: ${cond_steps}
act_steps: ${act_steps} act_steps: ${act_steps}
sparse_reward: True sparse_reward: True

View File

@ -31,6 +31,7 @@ env:
furniture: lamp furniture: lamp
randomness: low randomness: low
normalization_path: ${normalization_path} normalization_path: ${normalization_path}
obs_steps: ${cond_steps}
act_steps: ${act_steps} act_steps: ${act_steps}
sparse_reward: True sparse_reward: True

View File

@ -34,6 +34,7 @@ env:
furniture: lamp furniture: lamp
randomness: med randomness: med
normalization_path: ${normalization_path} normalization_path: ${normalization_path}
obs_steps: ${cond_steps}
act_steps: ${act_steps} act_steps: ${act_steps}
sparse_reward: True sparse_reward: True

View File

@ -34,6 +34,7 @@ env:
furniture: lamp furniture: lamp
randomness: med randomness: med
normalization_path: ${normalization_path} normalization_path: ${normalization_path}
obs_steps: ${cond_steps}
act_steps: ${act_steps} act_steps: ${act_steps}
sparse_reward: True sparse_reward: True

View File

@ -31,6 +31,7 @@ env:
furniture: lamp furniture: lamp
randomness: med randomness: med
normalization_path: ${normalization_path} normalization_path: ${normalization_path}
obs_steps: ${cond_steps}
act_steps: ${act_steps} act_steps: ${act_steps}
sparse_reward: True sparse_reward: True

View File

@ -34,6 +34,7 @@ env:
furniture: one_leg furniture: one_leg
randomness: low randomness: low
normalization_path: ${normalization_path} normalization_path: ${normalization_path}
obs_steps: ${cond_steps}
act_steps: ${act_steps} act_steps: ${act_steps}
sparse_reward: True sparse_reward: True

View File

@ -34,6 +34,7 @@ env:
furniture: one_leg furniture: one_leg
randomness: low randomness: low
normalization_path: ${normalization_path} normalization_path: ${normalization_path}
obs_steps: ${cond_steps}
act_steps: ${act_steps} act_steps: ${act_steps}
sparse_reward: True sparse_reward: True

View File

@ -31,6 +31,7 @@ env:
furniture: one_leg furniture: one_leg
randomness: low randomness: low
normalization_path: ${normalization_path} normalization_path: ${normalization_path}
obs_steps: ${cond_steps}
act_steps: ${act_steps} act_steps: ${act_steps}
sparse_reward: True sparse_reward: True

View File

@ -34,6 +34,7 @@ env:
furniture: one_leg furniture: one_leg
randomness: med randomness: med
normalization_path: ${normalization_path} normalization_path: ${normalization_path}
obs_steps: ${cond_steps}
act_steps: ${act_steps} act_steps: ${act_steps}
sparse_reward: True sparse_reward: True

View File

@ -34,6 +34,7 @@ env:
furniture: one_leg furniture: one_leg
randomness: med randomness: med
normalization_path: ${normalization_path} normalization_path: ${normalization_path}
obs_steps: ${cond_steps}
act_steps: ${act_steps} act_steps: ${act_steps}
sparse_reward: True sparse_reward: True

View File

@ -31,6 +31,7 @@ env:
furniture: one_leg furniture: one_leg
randomness: med randomness: med
normalization_path: ${normalization_path} normalization_path: ${normalization_path}
obs_steps: ${cond_steps}
act_steps: ${act_steps} act_steps: ${act_steps}
sparse_reward: True sparse_reward: True

View File

@ -34,6 +34,7 @@ env:
furniture: round_table furniture: round_table
randomness: low randomness: low
normalization_path: ${normalization_path} normalization_path: ${normalization_path}
obs_steps: ${cond_steps}
act_steps: ${act_steps} act_steps: ${act_steps}
sparse_reward: True sparse_reward: True

View File

@ -34,6 +34,7 @@ env:
furniture: round_table furniture: round_table
randomness: low randomness: low
normalization_path: ${normalization_path} normalization_path: ${normalization_path}
obs_steps: ${cond_steps}
act_steps: ${act_steps} act_steps: ${act_steps}
sparse_reward: True sparse_reward: True

View File

@ -31,6 +31,7 @@ env:
furniture: round_table furniture: round_table
randomness: low randomness: low
normalization_path: ${normalization_path} normalization_path: ${normalization_path}
obs_steps: ${cond_steps}
act_steps: ${act_steps} act_steps: ${act_steps}
sparse_reward: True sparse_reward: True

View File

@ -34,6 +34,7 @@ env:
furniture: round_table furniture: round_table
randomness: med randomness: med
normalization_path: ${normalization_path} normalization_path: ${normalization_path}
obs_steps: ${cond_steps}
act_steps: ${act_steps} act_steps: ${act_steps}
sparse_reward: True sparse_reward: True

View File

@ -34,6 +34,7 @@ env:
furniture: round_table furniture: round_table
randomness: med randomness: med
normalization_path: ${normalization_path} normalization_path: ${normalization_path}
obs_steps: ${cond_steps}
act_steps: ${act_steps} act_steps: ${act_steps}
sparse_reward: True sparse_reward: True

View File

@ -31,6 +31,7 @@ env:
furniture: round_table furniture: round_table
randomness: med randomness: med
normalization_path: ${normalization_path} normalization_path: ${normalization_path}
obs_steps: ${cond_steps}
act_steps: ${act_steps} act_steps: ${act_steps}
sparse_reward: True sparse_reward: True

View File

@ -24,6 +24,7 @@ def make_async(
normalization_path=None, normalization_path=None,
furniture="one_leg", furniture="one_leg",
randomness="low", randomness="low",
obs_steps=1,
act_steps=8, act_steps=8,
sparse_reward=False, sparse_reward=False,
# below for robomimic only # below for robomimic only
@ -93,19 +94,16 @@ def make_async(
stiffness=1_000, stiffness=1_000,
damping=200, damping=200,
) )
env = FurnitureRLSimEnvMultiStepWrapper( env = FurnitureRLSimEnvMultiStepWrapper(
env, env,
n_obs_steps=1, n_obs_steps=obs_steps,
n_action_steps=act_steps, n_action_steps=act_steps,
reward_agg_method="sum",
prev_action=False, prev_action=False,
reset_within_step=False, reset_within_step=False,
pass_full_observations=False, pass_full_observations=False,
normalization_path=normalization_path, normalization_path=normalization_path,
sparse_reward=sparse_reward, sparse_reward=sparse_reward,
) )
return env return env
# avoid import error due to incompatible gym versions # avoid import error due to incompatible gym versions

View File

@ -5,8 +5,10 @@ Environment wrapper for Furniture-Bench environments.
import gym import gym
import numpy as np import numpy as np
from furniture_bench.envs.furniture_rl_sim_env import FurnitureRLSimEnv
import torch import torch
from collections import deque
from furniture_bench.envs.furniture_rl_sim_env import FurnitureRLSimEnv
from furniture_bench.controllers.control_utils import proprioceptive_quat_to_6d_rotation from furniture_bench.controllers.control_utils import proprioceptive_quat_to_6d_rotation
from ..furniture_normalizer import LinearNormalizer from ..furniture_normalizer import LinearNormalizer
from .multi_step import repeated_space from .multi_step import repeated_space
@ -16,6 +18,32 @@ import logging
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
def stack_last_n_obs_dict(all_obs, n_steps):
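"""Stack the last n_steps observations along a new step dimension, padding with the earliest one if fewer are available."""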
assert len(all_obs) > 0
all_obs = list(all_obs)
result = {
key: torch.zeros(
list(all_obs[-1][key].shape)[0:1]
+ [n_steps]
+ list(all_obs[-1][key].shape)[1:],
dtype=all_obs[-1][key].dtype,
).to(
all_obs[-1][key].device
) # add step dimension
for key in all_obs[-1]
}
start_idx = -min(n_steps, len(all_obs))
for key in all_obs[-1]:
result[key][:, start_idx:] = torch.concatenate(
[obs[key][:, None] for obs in all_obs[start_idx:]], dim=1
) # add step dimension
if n_steps > len(all_obs):
# pad the missing leading steps (dim 1) with the earliest available observation
result[key][:, :start_idx] = result[key][:, start_idx].unsqueeze(1)
return result
class FurnitureRLSimEnvMultiStepWrapper(gym.Wrapper): class FurnitureRLSimEnvMultiStepWrapper(gym.Wrapper):
env: FurnitureRLSimEnv env: FurnitureRLSimEnv
@ -26,7 +54,6 @@ class FurnitureRLSimEnvMultiStepWrapper(gym.Wrapper):
n_action_steps=1, n_action_steps=1,
max_episode_steps=None, max_episode_steps=None,
sparse_reward=False, sparse_reward=False,
reward_agg_method="sum", # never use other types
reset_within_step=False, reset_within_step=False,
pass_full_observations=False, pass_full_observations=False,
normalization_path=None, normalization_path=None,
@ -35,8 +62,6 @@ class FurnitureRLSimEnvMultiStepWrapper(gym.Wrapper):
assert ( assert (
not reset_within_step not reset_within_step
), "reset_within_step must be False for furniture envs" ), "reset_within_step must be False for furniture envs"
assert n_obs_steps == 1, "n_obs_steps must be 1"
assert reward_agg_method == "sum", "reward_agg_method must be sum"
assert ( assert (
not pass_full_observations not pass_full_observations
), "pass_full_observations is not implemented yet" ), "pass_full_observations is not implemented yet"
@ -68,6 +93,8 @@ class FurnitureRLSimEnvMultiStepWrapper(gym.Wrapper):
): ):
"""Resets the environment.""" """Resets the environment."""
obs = self.env.reset() obs = self.env.reset()
self.obs = deque([obs], maxlen=max(self.n_obs_steps + 1, self.n_action_steps))
obs = stack_last_n_obs_dict(self.obs, self.n_obs_steps)
nobs = self.process_obs(obs) nobs = self.process_obs(obs)
self.best_reward = torch.zeros(self.env.num_envs).to(self.device) self.best_reward = torch.zeros(self.env.num_envs).to(self.device)
self.done = list() self.done = list()
@ -118,6 +145,7 @@ class FurnitureRLSimEnvMultiStepWrapper(gym.Wrapper):
for i in range(self.n_action_steps): for i in range(self.n_action_steps):
# The dimensions of the action_chunk are (num_envs, chunk_size, action_dim) # The dimensions of the action_chunk are (num_envs, chunk_size, action_dim)
obs, reward, done, info = self.env.step(action_chunk[:, i, :]) obs, reward, done, info = self.env.step(action_chunk[:, i, :])
self.obs.append(obs)
# track raw reward # track raw reward
sparse_reward += reward.squeeze() sparse_reward += reward.squeeze()
@ -130,21 +158,17 @@ class FurnitureRLSimEnvMultiStepWrapper(gym.Wrapper):
dones = dones | done.squeeze() dones = dones | done.squeeze()
obs = stack_last_n_obs_dict(self.obs, self.n_obs_steps)
return obs, sparse_reward, dense_reward, dones, info return obs, sparse_reward, dense_reward, dones, info
def process_obs(self, obs: torch.Tensor) -> np.ndarray: def process_obs(self, obs: torch.Tensor) -> np.ndarray:
robot_state = obs["robot_state"]
# Convert the robot state to have 6D pose # Convert the robot state to have 6D pose
robot_state = obs["robot_state"]
robot_state = proprioceptive_quat_to_6d_rotation(robot_state) robot_state = proprioceptive_quat_to_6d_rotation(robot_state)
parts_poses = obs["parts_poses"] parts_poses = obs["parts_poses"]
obs = torch.cat([robot_state, parts_poses], dim=-1) obs = torch.cat([robot_state, parts_poses], dim=-1)
nobs = self.normalizer(obs, "observations", forward=True) nobs = self.normalizer(obs, "observations", forward=True)
nobs = torch.clamp(nobs, -5, 5) nobs = torch.clamp(nobs, -5, 5)
return nobs.cpu().numpy() # (n_envs, n_obs_steps, obs_dim)
# Insert a dummy dimension for the n_obs_steps (n_envs, obs_dim) -> (n_envs, n_obs_steps, obs_dim)
nobs = nobs.unsqueeze(1).cpu().numpy()
return nobs

View File

@ -246,7 +246,7 @@ def make_dataset(
# Save to np file # Save to np file
save_train_path = os.path.join(save_dir, save_name_prefix + "train.npz") save_train_path = os.path.join(save_dir, save_name_prefix + "train.npz")
save_val_path = os.path.join(save_dir, save_name_prefix + "val.pkl") save_val_path = os.path.join(save_dir, save_name_prefix + "val.npz")
with open(save_train_path, "wb") as f: with open(save_train_path, "wb") as f:
pickle.dump(out_train, f) pickle.dump(out_train, f)
with open(save_val_path, "wb") as f: with open(save_val_path, "wb") as f:

View File

@ -143,7 +143,7 @@ def make_dataset(env_name, save_dir, save_name_prefix, val_split, logger):
# Save to np file # Save to np file
save_train_path = os.path.join(save_dir, save_name_prefix + "train.npz") save_train_path = os.path.join(save_dir, save_name_prefix + "train.npz")
save_val_path = os.path.join(save_dir, save_name_prefix + "val.pkl") save_val_path = os.path.join(save_dir, save_name_prefix + "val.npz")
with open(save_train_path, "wb") as f: with open(save_train_path, "wb") as f:
pickle.dump(out_train, f) pickle.dump(out_train, f)
with open(save_val_path, "wb") as f: with open(save_val_path, "wb") as f:

View File

@ -193,7 +193,7 @@ def make_dataset(load_path, save_dir, save_name_prefix, env_type, val_split):
# Save to np file # Save to np file
save_train_path = os.path.join(save_dir, save_name_prefix + "train.npz") save_train_path = os.path.join(save_dir, save_name_prefix + "train.npz")
save_val_path = os.path.join(save_dir, save_name_prefix + "val.pkl") save_val_path = os.path.join(save_dir, save_name_prefix + "val.npz")
with open(save_train_path, "wb") as f: with open(save_train_path, "wb") as f:
pickle.dump(out_train, f) pickle.dump(out_train, f)
with open(save_val_path, "wb") as f: with open(save_val_path, "wb") as f:

View File

@ -304,7 +304,7 @@ def make_dataset(
# Save to np file # Save to np file
save_train_path = os.path.join(save_dir, save_name_prefix + "train.npz") save_train_path = os.path.join(save_dir, save_name_prefix + "train.npz")
save_val_path = os.path.join(save_dir, save_name_prefix + "val.pkl") save_val_path = os.path.join(save_dir, save_name_prefix + "val.npz")
with open(save_train_path, "wb") as f: with open(save_train_path, "wb") as f:
pickle.dump(out_train, f) pickle.dump(out_train, f)
with open(save_val_path, "wb") as f: with open(save_val_path, "wb") as f: