* v0.5 (#9) * update idql configs * update awr configs * update dipo configs * update qsm configs * update dqm configs * update project version to 0.5.0
282 lines
8.7 KiB
Python
282 lines
8.7 KiB
Python
"""
|
|
Multi-step wrapper. Allows executing multiple environment steps. Returns stacked observations and optionally stacked previous actions.
|
|
|
|
Modified from https://github.com/real-stanford/diffusion_policy/blob/main/diffusion_policy/gym_util/multistep_wrapper.py
|
|
|
|
TODO: allow cond_steps != img_cond_steps (should be implemented in training scripts, not here)
|
|
"""
|
|
|
|
import gym
|
|
from typing import Optional
|
|
from gym import spaces
|
|
import numpy as np
|
|
from collections import defaultdict, deque
|
|
|
|
|
|
def stack_repeated(x, n):
    """Return ``x`` copied ``n`` times along a newly inserted leading axis.

    The result has shape ``(n,) + np.shape(x)``.
    """
    with_leading_axis = np.expand_dims(x, axis=0)
    return np.repeat(with_leading_axis, n, axis=0)
|
|
|
|
|
|
def repeated_box(box_space, n):
    """Build a ``spaces.Box`` whose bounds are ``box_space``'s bounds stacked ``n`` times.

    The new space has shape ``(n,) + box_space.shape`` and the same dtype as
    ``box_space``.
    """
    stacked_low = stack_repeated(box_space.low, n)
    stacked_high = stack_repeated(box_space.high, n)
    stacked_shape = (n,) + box_space.shape
    return spaces.Box(
        low=stacked_low,
        high=stacked_high,
        shape=stacked_shape,
        dtype=box_space.dtype,
    )
|
|
|
|
|
|
def repeated_space(space, n):
    """Return a copy of ``space`` with every Box repeated ``n`` times on a new leading axis.

    ``spaces.Box`` is handled by ``repeated_box``; ``spaces.Dict`` is handled
    by recursing into each entry.

    Raises:
        RuntimeError: if ``space`` is neither a ``spaces.Box`` nor a ``spaces.Dict``.
    """
    if isinstance(space, spaces.Box):
        return repeated_box(space, n)

    if not isinstance(space, spaces.Dict):
        raise RuntimeError(f"Unsupported space type {type(space)}")

    # Fill an initially empty Dict space entry by entry, in iteration order.
    stacked = spaces.Dict()
    for name, subspace in space.items():
        stacked[name] = repeated_space(subspace, n)
    return stacked
|
|
|
|
|
|
def take_last_n(x, n):
    """Return the last ``n`` elements of the iterable ``x`` as a NumPy array.

    If ``x`` holds fewer than ``n`` elements, all of them are returned.
    A non-positive ``n`` yields an empty array.
    """
    items = list(x)
    # Clamp the request to what is actually available (and never below zero).
    count = max(0, min(len(items), n))
    # Slice from an explicit non-negative start index: ``items[-count:]`` with
    # ``count == 0`` is ``items[-0:]``, i.e. the whole list rather than nothing.
    return np.array(items[len(items) - count:])
|
|
|
|
|
|
def dict_take_last_n(x, n):
    """Apply ``take_last_n`` to every value of the mapping ``x``.

    Returns a new plain ``dict`` with the same keys as ``x``.
    """
    return {key: take_last_n(values, n) for key, values in x.items()}
|
|
|
|
|
|
def aggregate(data, method="max"):
    """Reduce ``data`` to a single value with the NumPy reduction named by ``method``.

    Supported methods are ``"max"``, ``"min"``, ``"mean"`` and ``"sum"``.

    Raises:
        NotImplementedError: if ``method`` is not one of the supported names.
    """
    reducers = (
        ("max", np.max),  # equivalent to any
        ("min", np.min),  # equivalent to all
        ("mean", np.mean),
        ("sum", np.sum),
    )
    for name, reduce_fn in reducers:
        if method == name:
            return reduce_fn(data)
    raise NotImplementedError()
|
|
|
|
|
|
def stack_last_n_obs(all_obs, n_steps):
    """Stack the last ``n_steps`` entries of ``all_obs`` into one array, padding if needed.

    The output has shape ``(n_steps,) + all_obs[-1].shape`` and the dtype of
    the newest entry. When fewer than ``n_steps`` entries are available, the
    leading rows are filled with the oldest available entry. ``n_steps == 0``
    returns an empty array of shape ``(0,) + all_obs[-1].shape``.

    ``all_obs`` must be non-empty and support ``len``; its entries must expose
    ``.shape`` and ``.dtype``.
    """
    assert len(all_obs) > 0
    history = list(all_obs)
    newest = history[-1]
    result = np.zeros((n_steps,) + newest.shape, dtype=newest.dtype)

    n_avail = min(n_steps, len(history))
    if n_avail == 0:
        # Nothing requested. Slicing with ``-0`` would select the whole
        # history and fail to fit into the zero-length result.
        return result

    result[-n_avail:] = np.array(history[-n_avail:])
    if n_steps > len(history):
        # pad the front with the oldest available entry
        result[:-n_avail] = result[-n_avail]
    return result
|
|
|
|
|
|
class MultiStep(gym.Wrapper):
    """Gym wrapper that executes a chunk of actions per ``step`` and returns stacked observations.

    The wrapper's action space is the wrapped env's action space repeated
    ``n_action_steps`` times along a new leading axis, and its observation
    space is the wrapped env's observation space repeated ``n_obs_steps``
    times. The wrapped env's ``step`` is expected to return the 4-tuple
    ``(observation, reward, done, info)``; this wrapper's ``step`` returns
    the 5-tuple ``(observation, reward, terminated, truncated, info)``.
    """

    def __init__(
        self,
        env,
        n_obs_steps=1,
        n_action_steps=1,
        max_episode_steps=None,
        reward_agg_method="sum",  # never use other types
        prev_action=True,
        reset_within_step=False,
        pass_full_observations=False,
        verbose=False,
        **kwargs,
    ):
        """Wrap ``env``.

        Args:
            env: environment to wrap.
            n_obs_steps: number of most recent observations stacked into each
                returned observation.
            n_action_steps: length of the leading axis of the action space.
            max_episode_steps: if not None, an episode is marked truncated once
                the internal step counter reaches this value (only consulted
                when ``info`` lacks ``"TimeLimit.truncated"`` and the env did
                not report done).
            reward_agg_method: reduction name passed to ``aggregate`` for the
                rewards collected within one ``step`` call.
            prev_action: whether to keep a buffer of executed actions for
                ``get_prev_action``.
            reset_within_step: if True, ``step`` resets the env itself when the
                last executed sub-step finished the episode.
            pass_full_observations: if True, ``step`` adds ``info["full_obs"]``.
            verbose: if True, print a message when resetting inside ``step``.
            **kwargs: accepted and ignored.
        """
        super().__init__(env)
        self._single_action_space = env.action_space
        self._action_space = repeated_space(env.action_space, n_action_steps)
        self._observation_space = repeated_space(env.observation_space, n_obs_steps)
        self.max_episode_steps = max_episode_steps
        self.n_obs_steps = n_obs_steps
        self.n_action_steps = n_action_steps
        self.reward_agg_method = reward_agg_method
        self.prev_action = prev_action
        self.reset_within_step = reset_within_step
        self.pass_full_observations = pass_full_observations
        self.verbose = verbose

    def reset(
        self,
        seed: Optional[int] = None,
        return_info: bool = False,
        options: Optional[dict] = None,
    ):
        """Reset the wrapped environment and all internal buffers.

        ``seed``, ``options`` and ``return_info`` are forwarded to the wrapped
        env's ``reset``. ``options=None`` is forwarded as an empty dict.

        Returns:
            The stacked observation from ``_get_obs(self.n_obs_steps)``.
        """
        # A fresh dict per call; a ``{}`` default would be shared across calls.
        if options is None:
            options = {}
        obs = self.env.reset(
            seed=seed,
            options=options,
            return_info=return_info,
        )
        self.obs = deque([obs], maxlen=max(self.n_obs_steps + 1, self.n_action_steps))
        if self.prev_action:
            # Seed the action buffer with a random sample from the single-step
            # action space so it is never empty.
            self.action = deque(
                [self._single_action_space.sample()], maxlen=self.n_obs_steps
            )
        self.reward = list()
        self.done = list()
        self.info = defaultdict(lambda: deque(maxlen=self.n_obs_steps + 1))
        obs = self._get_obs(self.n_obs_steps)

        self.cnt = 0
        return obs

    def step(self, action):
        """
        actions: (n_action_steps,) + action_shape

        Executes the actions one at a time on the wrapped env, stopping early
        once a sub-step ends the episode.

        Returns:
            ``(observation, reward, terminated, truncated, info)`` where
            ``observation`` is the stacked last ``n_obs_steps`` observations,
            ``reward`` is the aggregate of this call's sub-step rewards, and
            ``info`` holds, per key, the last ``n_obs_steps`` buffered values.
            With ``pass_full_observations``, ``info["full_obs"]`` is added.
            With ``reset_within_step`` and a finished episode, the returned
            observation comes from ``reset()`` and, on truncation, the
            pre-reset observation is kept in ``info["final_obs"]``.
        """
        if action.ndim == 1:  # in case action_steps = 1
            action = action[None]
        truncated = False
        terminated = False
        for act_step, act in enumerate(action):
            # The counter advances before the early-exit check below.
            self.cnt += 1
            if terminated or truncated:
                break

            # done does not differentiate terminal and truncation
            observation, reward, done, info = self.env.step(act)

            self.obs.append(observation)
            # The action buffer only exists when prev_action is enabled
            # (it is created in reset()).
            if self.prev_action:
                self.action.append(act)
            self.reward.append(reward)

            # in gym, timelimit wrapper is automatically used given env._spec.max_episode_steps
            if "TimeLimit.truncated" not in info:
                if done:
                    terminated = True
                elif (
                    self.max_episode_steps is not None
                ) and self.cnt >= self.max_episode_steps:
                    truncated = True
            else:
                truncated = info["TimeLimit.truncated"]
                # NOTE(review): this sets terminated whenever done is set, even
                # if the info flag reports a truncation — confirm this is intended.
                terminated = done
            done = truncated or terminated
            self.done.append(done)
            self._add_info(info)
        observation = self._get_obs(self.n_obs_steps)
        reward = aggregate(self.reward, self.reward_agg_method)
        done = aggregate(self.done, "max")
        info = dict_take_last_n(self.info, self.n_obs_steps)
        if self.pass_full_observations:
            info["full_obs"] = self._get_obs(act_step + 1)

        # In mujoco case, done can happen within the loop above
        if self.reset_within_step and self.done[-1]:

            # need to save old observation in the case of truncation only, for bootstrapping
            if truncated:
                info["final_obs"] = observation

            # reset
            observation = (
                self.reset()
            )  # TODO: arguments? this cannot handle video recording right now since needs to pass in options
            if self.verbose:
                print("Reset env within wrapper.")

        # reset reward and done for next step
        self.reward = list()
        self.done = list()
        return observation, reward, terminated, truncated, info

    def _get_obs(self, n_steps=1):
        """
        Output (n_steps,) + obs_shape

        For a Dict observation space, returns a dict with one stacked array
        per key of the observation space.

        Raises:
            RuntimeError: if the observation space is neither Box nor Dict.
        """
        assert len(self.obs) > 0
        if isinstance(self.observation_space, spaces.Box):
            return stack_last_n_obs(self.obs, n_steps)
        elif isinstance(self.observation_space, spaces.Dict):
            result = dict()
            for key in self.observation_space.keys():
                result[key] = stack_last_n_obs([obs[key] for obs in self.obs], n_steps)
            return result
        else:
            raise RuntimeError("Unsupported space type")

    def get_prev_action(self, n_steps=None):
        """Return the last ``n_steps`` buffered actions, stacked.

        ``n_steps`` defaults to ``n_obs_steps - 1``. The action buffer is only
        created by ``reset`` when ``prev_action`` is enabled.
        """
        if n_steps is None:
            n_steps = self.n_obs_steps - 1  # exclude current step
        assert len(self.action) > 0
        return stack_last_n_obs(self.action, n_steps)

    def _add_info(self, info):
        """Append each value of ``info`` to the per-key buffer in ``self.info``."""
        for key, value in info.items():
            self.info[key].append(value)

    def render(self, **kwargs):
        """Not the best design: forwards all keyword arguments to the wrapped env's render."""
        return self.env.render(**kwargs)
|
|
|
|
|
|
if __name__ == "__main__":
    # Manual smoke test: build a robomimic image env from a config file, wrap
    # it in MultiStep, reset once, render a frame and save it to test.png.
    import os
    from omegaconf import OmegaConf
    import json

    # Set before the robomimic imports below.
    os.environ["MUJOCO_GL"] = "egl"

    cfg = OmegaConf.load("cfg/robomimic/finetune/can/ft_ppo_diffusion_mlp_img.yaml")
    shape_meta = cfg["shape_meta"]

    import robomimic.utils.env_utils as EnvUtils
    import robomimic.utils.obs_utils as ObsUtils
    import matplotlib.pyplot as plt
    from env.gym_utils.wrapper.robomimic_image import RobomimicImageWrapper

    # Pick the observation keys from whichever robomimic wrapper the config uses.
    wrappers = cfg.env.wrappers
    if "robomimic_image" in wrappers:
        low_dim_keys = wrappers.robomimic_image.low_dim_keys
    else:
        low_dim_keys = wrappers.robomimic_lowdim.low_dim_keys
    if "robomimic_image" in wrappers:
        rgb_keys = wrappers.robomimic_image.image_keys
    else:
        rgb_keys = None

    # The "rgb" modality is registered only when image keys are available.
    obs_modality_dict = {"low_dim": low_dim_keys}
    if rgb_keys is not None:
        obs_modality_dict["rgb"] = rgb_keys
    ObsUtils.initialize_obs_modality_mapping_from_dict(obs_modality_dict)

    with open(cfg.robomimic_env_cfg_path, "r") as f:
        env_meta = json.load(f)
    env = EnvUtils.create_env_from_metadata(
        env_meta=env_meta,
        render=False,
        render_offscreen=False,
        use_image_obs=True,
    )
    env.env.hard_reset = False

    image_env = RobomimicImageWrapper(
        env=env,
        shape_meta=shape_meta,
        image_keys=["robot0_eye_in_hand_image"],
    )
    wrapper = MultiStep(
        env=image_env,
        n_obs_steps=1,
        n_action_steps=1,
    )
    wrapper.seed(0)
    obs = wrapper.reset()
    print(obs.keys())
    img = wrapper.render()
    wrapper.close()
    plt.imshow(img)
    plt.savefig("test.png")
|