88 lines
2.6 KiB
Python
88 lines
2.6 KiB
Python
"""
|
|
Environment wrapper for D3IL environments with state observations.
|
|
|
|
"""
|
|
|
|
import numpy as np
|
|
import gym
|
|
|
|
|
|
class D3ilLowdimWrapper(gym.Env):
|
|
def __init__(
|
|
self,
|
|
env,
|
|
normalization_path,
|
|
# init_state=None,
|
|
# render_hw=(256, 256),
|
|
# render_camera_name="agentview",
|
|
):
|
|
self.env = env
|
|
# self.init_state = init_state
|
|
# self.render_hw = render_hw
|
|
# self.render_camera_name = render_camera_name
|
|
|
|
# setup spaces
|
|
self.action_space = env.action_space
|
|
self.observation_space = env.observation_space
|
|
normalization = np.load(normalization_path)
|
|
self.obs_min = normalization["obs_min"]
|
|
self.obs_max = normalization["obs_max"]
|
|
self.action_min = normalization["action_min"]
|
|
self.action_max = normalization["action_max"]
|
|
|
|
# def get_observation(self):
|
|
# raw_obs = self.env.get_observation()
|
|
# obs = np.concatenate([raw_obs[key] for key in self.obs_keys], axis=0)
|
|
# return obs
|
|
|
|
def seed(self, seed=None):
|
|
if seed is not None:
|
|
np.random.seed(seed=seed)
|
|
else:
|
|
np.random.seed()
|
|
|
|
def reset(self, **kwargs):
|
|
"""Ignore passed-in arguments like seed"""
|
|
options = kwargs.get("options", {})
|
|
|
|
new_seed = options.get(
|
|
"seed", None
|
|
) # used to set all environments to specified seeds
|
|
# if self.init_state is not None:
|
|
# # always reset to the same state to be compatible with gym
|
|
# self.env.reset_to({"states": self.init_state})
|
|
if new_seed is not None:
|
|
self.seed(seed=new_seed)
|
|
obs = self.env.reset()
|
|
else:
|
|
# random reset
|
|
obs = self.env.reset()
|
|
|
|
# normalize
|
|
obs = self.normalize_obs(obs)
|
|
return obs
|
|
|
|
def normalize_obs(self, obs):
|
|
return 2 * ((obs - self.obs_min) / (self.obs_max - self.obs_min + 1e-6) - 0.5)
|
|
|
|
def unnormaliza_action(self, action):
|
|
action = (action + 1) / 2 # [-1, 1] -> [0, 1]
|
|
return action * (self.action_max - self.action_min) + self.action_min
|
|
|
|
def step(self, action):
|
|
action = self.unnormaliza_action(action)
|
|
obs, reward, done, info = self.env.step(action)
|
|
|
|
# normalize
|
|
obs = self.normalize_obs(obs)
|
|
return obs, reward, done, info
|
|
|
|
def render(self, mode="rgb_array"):
|
|
h, w = self.render_hw
|
|
return self.env.render(
|
|
mode=mode,
|
|
height=h,
|
|
width=w,
|
|
camera_name=self.render_camera_name,
|
|
)
|