dppo/env/gym_utils/wrapper/robomimic_lowdim.py
2024-09-03 21:03:27 -04:00

143 lines
4.4 KiB
Python

"""
Environment wrapper for Robomimic environments with state observations.
Modified from https://github.com/real-stanford/diffusion_policy/blob/main/diffusion_policy/env/robomimic/robomimic_lowdim_wrapper.py
"""
import numpy as np
import gym
from gym.spaces import Box
import imageio
class RobomimicLowdimWrapper(gym.Env):
    """Gym wrapper around a Robomimic environment with flat state observations.

    Concatenates the selected low-dimensional observation keys into a single
    1-D vector. If ``normalization_path`` is given, observations are mapped
    into [-1, 1] and incoming actions are mapped back from [-1, 1] to the
    raw action range, using min/max statistics loaded from that file.
    Optionally records rollout videos when ``reset`` receives a
    ``video_path`` option.
    """

    def __init__(
        self,
        env,
        normalization_path=None,
        # Tuple default avoids the shared-mutable-default-argument pitfall.
        low_dim_keys=(
            "robot0_eef_pos",
            "robot0_eef_quat",
            "robot0_gripper_qpos",
            "object",
        ),
        clamp_obs=False,
        init_state=None,
        render_hw=(256, 256),
        render_camera_name="agentview",
    ):
        """
        Args:
            env: wrapped Robomimic environment; must provide
                ``get_observation``, ``reset``, ``reset_to``, ``step``,
                ``render``, and an ``action_dimension`` attribute.
            normalization_path: optional path to an ``.npz`` file containing
                ``obs_min``, ``obs_max``, ``action_min``, ``action_max``.
            low_dim_keys: observation-dict keys concatenated, in order, into
                the flat observation vector.
            clamp_obs: if True, clip normalized observations to [-1, 1].
            init_state: if set, every reset restores this exact sim state.
            render_hw: (height, width) of rendered frames.
            render_camera_name: camera used for rendering.
        """
        self.env = env
        # Copy so external mutation of the caller's sequence cannot change us.
        self.obs_keys = list(low_dim_keys)
        self.init_state = init_state
        self.render_hw = render_hw
        self.render_camera_name = render_camera_name
        self.video_writer = None
        self.clamp_obs = clamp_obs

        # Set up normalization statistics if provided.
        self.normalize = normalization_path is not None
        if self.normalize:
            normalization = np.load(normalization_path)
            self.obs_min = normalization["obs_min"]
            self.obs_max = normalization["obs_max"]
            self.action_min = normalization["action_min"]
            self.action_max = normalization["action_max"]

        # Setup spaces over the normalized range [-1, 1].
        # Use float fill values: an integer fill would give the Box an
        # integer dtype, so sampled actions would only take values in
        # {-1, 0, 1}.
        low = np.full(env.action_dimension, fill_value=-1.0, dtype=np.float32)
        high = np.full(env.action_dimension, fill_value=1.0, dtype=np.float32)
        self.action_space = Box(
            low=low,
            high=high,
            shape=low.shape,
            dtype=low.dtype,
        )
        obs_example = self.get_observation()
        low = np.full_like(obs_example, fill_value=-1)
        high = np.full_like(obs_example, fill_value=1)
        self.observation_space = Box(
            low=low,
            high=high,
            shape=low.shape,
            dtype=low.dtype,
        )

    def normalize_obs(self, obs):
        """Map a raw observation into [-1, 1] using stored min/max stats."""
        # Epsilon guards against division by zero for constant dimensions.
        obs = 2 * (
            (obs - self.obs_min) / (self.obs_max - self.obs_min + 1e-6) - 0.5
        )  # -> [-1, 1]
        if self.clamp_obs:
            obs = np.clip(obs, -1, 1)
        return obs

    def unnormalize_action(self, action):
        """Map an action from [-1, 1] back to the raw action range."""
        action = (action + 1) / 2  # [-1, 1] -> [0, 1]
        return action * (self.action_max - self.action_min) + self.action_min

    def get_observation(self):
        """Return the flat (optionally normalized) observation vector."""
        raw_obs = self.env.get_observation()
        raw_obs = np.concatenate([raw_obs[key] for key in self.obs_keys], axis=0)
        if self.normalize:
            return self.normalize_obs(raw_obs)
        return raw_obs

    def seed(self, seed=None):
        """Seed NumPy's global RNG (affects the wrapped env's randomness)."""
        if seed is not None:
            np.random.seed(seed=seed)
        else:
            np.random.seed()

    def reset(self, options=None, **kwargs):
        """Reset the environment; ignores extra args like ``seed`` kwarg.

        Args:
            options: optional dict; may contain ``video_path`` (start
                recording a video to this path) and ``seed`` (seed the RNG
                before resetting).

        Returns:
            The initial flat observation.
        """
        # None default avoids a shared mutable default dict across calls.
        if options is None:
            options = {}

        # Close any video writer left over from a previous episode.
        if self.video_writer is not None:
            self.video_writer.close()
            self.video_writer = None

        # Start video recording if specified.
        if "video_path" in options:
            self.video_writer = imageio.get_writer(options["video_path"], fps=30)

        # Call reset
        new_seed = options.get(
            "seed", None
        )  # used to set all environments to specified seeds
        if self.init_state is not None:
            # always reset to the same state to be compatible with gym
            self.env.reset_to({"states": self.init_state})
        elif new_seed is not None:
            self.seed(seed=new_seed)
            self.env.reset()
        else:
            # random reset
            self.env.reset()
        return self.get_observation()

    def step(self, action):
        """Step the env with a (possibly normalized) action.

        Returns:
            (obs, reward, done, info) with obs flattened and optionally
            normalized.
        """
        if self.normalize:
            action = self.unnormalize_action(action)
        raw_obs, reward, done, info = self.env.step(action)
        raw_obs = np.concatenate([raw_obs[key] for key in self.obs_keys], axis=0)
        if self.normalize:
            obs = self.normalize_obs(raw_obs)
        else:
            obs = raw_obs

        # Append a frame to the video if recording.
        if self.video_writer is not None:
            video_img = self.render(mode="rgb_array")
            self.video_writer.append_data(video_img)
        return obs, reward, done, info

    def render(self, mode="rgb_array"):
        """Render a frame from the configured camera at the configured size."""
        h, w = self.render_hw
        return self.env.render(
            mode=mode,
            height=h,
            width=w,
            camera_name=self.render_camera_name,
        )

    def close(self):
        """Release the video writer (if open) and close the wrapped env.

        Fixes a resource leak: previously the writer was only closed on the
        next ``reset``, so the final episode's video was never finalized.
        """
        if self.video_writer is not None:
            self.video_writer.close()
            self.video_writer = None
        # Guard: robomimic env wrappers may not implement close().
        if hasattr(self.env, "close"):
            self.env.close()