""" Multi-step wrapper. Allow executing multiple environmnt steps. Returns stacked observation and optionally stacked previous action. Modified from https://github.com/real-stanford/diffusion_policy/blob/main/diffusion_policy/gym_util/multistep_wrapper.py """ import gym from typing import Optional from gym import spaces import numpy as np from collections import defaultdict, deque # import dill def stack_repeated(x, n): return np.repeat(np.expand_dims(x, axis=0), n, axis=0) def repeated_box(box_space, n): return spaces.Box( low=stack_repeated(box_space.low, n), high=stack_repeated(box_space.high, n), shape=(n,) + box_space.shape, dtype=box_space.dtype, ) def repeated_space(space, n): if isinstance(space, spaces.Box): return repeated_box(space, n) elif isinstance(space, spaces.Dict): result_space = spaces.Dict() for key, value in space.items(): result_space[key] = repeated_space(value, n) return result_space else: raise RuntimeError(f"Unsupported space type {type(space)}") def take_last_n(x, n): x = list(x) n = min(len(x), n) return np.array(x[-n:]) def dict_take_last_n(x, n): result = dict() for key, value in x.items(): result[key] = take_last_n(value, n) return result def aggregate(data, method="max"): if method == "max": # equivalent to any return np.max(data) elif method == "min": # equivalent to all return np.min(data) elif method == "mean": return np.mean(data) elif method == "sum": return np.sum(data) else: raise NotImplementedError() def stack_last_n_obs(all_obs, n_steps): """Apply padding""" assert len(all_obs) > 0 all_obs = list(all_obs) result = np.zeros((n_steps,) + all_obs[-1].shape, dtype=all_obs[-1].dtype) start_idx = -min(n_steps, len(all_obs)) result[start_idx:] = np.array(all_obs[start_idx:]) if n_steps > len(all_obs): # pad result[:start_idx] = result[start_idx] return result class MultiStep(gym.Wrapper): def __init__( self, env, n_obs_steps=1, n_action_steps=1, max_episode_steps=None, reward_agg_method="sum", # never use other types prev_action=True, reset_within_step=False, pass_full_observations=False, verbose=False, **kwargs, ): super().__init__(env) self._single_action_space = env.action_space self._action_space = repeated_space(env.action_space, n_action_steps) self._observation_space = repeated_space(env.observation_space, n_obs_steps) self.max_episode_steps = max_episode_steps self.n_obs_steps = n_obs_steps self.n_action_steps = n_action_steps self.reward_agg_method = reward_agg_method self.prev_action = prev_action self.reset_within_step = reset_within_step self.pass_full_observations = pass_full_observations self.verbose = verbose def reset( self, seed: Optional[int] = None, return_info: bool = False, options: dict = {}, ): """Resets the environment.""" obs = self.env.reset( seed=seed, options=options, return_info=return_info, ) self.obs = deque([obs], maxlen=max(self.n_obs_steps + 1, self.n_action_steps)) if self.prev_action: self.action = deque( [self._single_action_space.sample()], maxlen=self.n_obs_steps ) self.reward = list() self.done = list() self.info = defaultdict(lambda: deque(maxlen=self.n_obs_steps + 1)) obs = self._get_obs(self.n_obs_steps) self.cnt = 0 return obs def step(self, action): """ actions: (n_action_steps,) + action_shape """ if action.ndim == 1: # in case action_steps = 1 action = action[None] for act_step, act in enumerate(action): self.cnt += 1 if len(self.done) > 0 and self.done[-1]: # termination break observation, reward, done, info = self.env.step(act) self.obs.append(observation) self.action.append(act) self.reward.append(reward) if ( self.max_episode_steps is not None ) and self.cnt >= self.max_episode_steps: # truncation done = True self.done.append(done) self._add_info(info) observation = self._get_obs(self.n_obs_steps) reward = aggregate(self.reward, self.reward_agg_method) done = aggregate(self.done, "max") info = dict_take_last_n(self.info, self.n_obs_steps) if self.pass_full_observations: # right now this assume n_obs_steps = 1 info["full_obs"] = self._get_obs(act_step + 1) # In mujoco case, done can happen within the loop above if self.reset_within_step and self.done[-1]: observation = ( self.reset() ) # TODO: arguments? this cannot handle video recording right now since needs to pass in options self.verbose and print("Reset env within wrapper.") # reset reward and done for next step self.reward = list() self.done = list() return observation, reward, done, info def _get_obs(self, n_steps=1): """ Output (n_steps,) + obs_shape """ assert len(self.obs) > 0 if isinstance(self.observation_space, spaces.Box): return stack_last_n_obs(self.obs, n_steps) elif isinstance(self.observation_space, spaces.Dict): result = dict() for key in self.observation_space.keys(): result[key] = stack_last_n_obs([obs[key] for obs in self.obs], n_steps) return result else: raise RuntimeError("Unsupported space type") def get_prev_action(self, n_steps=None): if n_steps is None: n_steps = self.n_obs_steps - 1 # exclude current step assert len(self.action) > 0 return stack_last_n_obs(self.action, n_steps) def _add_info(self, info): for key, value in info.items(): self.info[key].append(value) def render(self, **kwargs): """Not the best design""" return self.env.render(**kwargs) # def get_rewards(self): # return self.reward # def get_attr(self, name): # return getattr(self, name) # def run_dill_function(self, dill_fn): # fn = dill.loads(dill_fn) # return fn(self) # def get_infos(self): # result = dict() # for k, v in self.info.items(): # result[k] = list(v) # return result if __name__ == "__main__": import os from omegaconf import OmegaConf import json os.environ["MUJOCO_GL"] = "egl" cfg = OmegaConf.load("cfg/robomimic/finetune/can/ft_ppo_diffusion_mlp_img.yaml") shape_meta = cfg["shape_meta"] import robomimic.utils.env_utils as EnvUtils import robomimic.utils.obs_utils as ObsUtils import matplotlib.pyplot as plt from env.gym_utils.wrapper.robomimic_image import RobomimicImageWrapper wrappers = cfg.env.wrappers obs_modality_dict = { "low_dim": ( wrappers.robomimic_image.low_dim_keys if "robomimic_image" in wrappers else wrappers.robomimic_lowdim.low_dim_keys ), "rgb": ( wrappers.robomimic_image.image_keys if "robomimic_image" in wrappers else None ), } if obs_modality_dict["rgb"] is None: obs_modality_dict.pop("rgb") ObsUtils.initialize_obs_modality_mapping_from_dict(obs_modality_dict) with open(cfg.robomimic_env_cfg_path, "r") as f: env_meta = json.load(f) env = EnvUtils.create_env_from_metadata( env_meta=env_meta, render=False, render_offscreen=False, use_image_obs=True, ) env.env.hard_reset = False wrapper = MultiStep( env=RobomimicImageWrapper( env=env, shape_meta=shape_meta, image_keys=["robot0_eye_in_hand_image"], ), n_obs_steps=1, n_action_steps=1, ) wrapper.seed(0) obs = wrapper.reset() print(obs.keys()) img = wrapper.render() wrapper.close() plt.imshow(img) plt.savefig("test.png")