248 lines
8.0 KiB
Python
248 lines
8.0 KiB
Python
|
from collections import OrderedDict
|
||
|
import os
|
||
|
|
||
|
|
||
|
from gym import error, spaces
|
||
|
from gym.utils import seeding
|
||
|
import numpy as np
|
||
|
from os import path
|
||
|
import gym
|
||
|
|
||
|
# mujoco_py is required by everything below; turn a failed import into gym's
# DependencyNotInstalled error, which carries an installation hint.
try:
    import mujoco_py
except ImportError as e:
    hint = "{}. (HINT: you need to install mujoco_py, and also perform the setup instructions here: https://github.com/openai/mujoco-py/.)"
    raise error.DependencyNotInstalled(hint.format(e))
|
||
|
|
||
|
# Default value of the `width` and `height` arguments of render().
DEFAULT_SIZE = 500
|
||
|
|
||
|
|
||
|
def convert_observation_to_space(observation):
    """
    Build the gym space that describes ``observation``.

    A dict becomes a ``spaces.Dict`` whose entries are converted recursively
    (key order is kept). A numpy array becomes a ``spaces.Box`` of the same
    shape and dtype, bounded by -inf / +inf. Any other type raises
    ``NotImplementedError`` carrying the offending type and value.
    """
    if isinstance(observation, dict):
        members = OrderedDict()
        for name, item in observation.items():
            members[name] = convert_observation_to_space(item)
        return spaces.Dict(members)

    if isinstance(observation, np.ndarray):
        shape = observation.shape
        upper = np.full(shape, float('inf'), dtype=np.float32)
        lower = np.full(shape, -float('inf'), dtype=np.float32)
        return spaces.Box(lower, upper, dtype=observation.dtype)

    raise NotImplementedError(type(observation), observation)
|
||
|
|
||
|
|
||
|
class AlrMujocoEnv(gym.Env):
    """
    Superclass for all MuJoCo environments.

    Loads a MuJoCo model through mujoco_py, derives the action space from the
    model's actuator control ranges and the observation space from the
    observation returned by the first ``reset()``. Subclasses have to implement
    ``reset_model`` and may override ``configure`` and ``viewer_setup``.
    """

    def __init__(self, model_path, n_substeps, apply_gravity_comp=True):
        """
        Args:
            model_path: path to xml file; a path that does not start with "/"
                is looked up in the "assets" directory next to this module
            n_substeps: number of MuJoCo substeps per call to ``sim.step()``
                (handed to ``MjSim`` as ``nsubsteps``)
            apply_gravity_comp: if True, ``do_simulation`` adds
                ``qfrc_bias / actuator_gear`` to the controls it is given

        Raises:
            IOError: if the resolved model file does not exist
        """
        if model_path.startswith("/"):
            fullpath = model_path
        else:
            fullpath = os.path.join(os.path.dirname(__file__), "assets", model_path)
        if not path.exists(fullpath):
            raise IOError("File %s does not exist" % fullpath)

        self.n_substeps = n_substeps
        self.apply_gravity_comp = apply_gravity_comp
        self.model = mujoco_py.load_model_from_path(fullpath)
        self.sim = mujoco_py.MjSim(self.model, nsubsteps=n_substeps)
        self.data = self.sim.data
        self.viewer = None
        self._viewers = {}  # one cached viewer / render context per render mode

        self.metadata = {
            'render.modes': ['human', 'rgb_array', 'depth_array'],
            'video.frames_per_second': int(np.round(1.0 / self.dt))
        }

        # Snapshot of the state right after loading the model.
        self.init_qpos = self.sim.data.qpos.ravel().copy()
        self.init_qvel = self.sim.data.qvel.ravel().copy()

        self._set_action_space()

        # Seed before the first reset so that reset_model() implementations
        # can already draw from self.np_random.
        self.seed()

        observation = self.reset()
        self._set_observation_space(observation)

    @property
    def current_pos(self):
        """
        Returns ``sim.data.qpos``, i.e. the positions of all simulated joints.
        May be overridden in subclass.
        """
        return self.sim.data.qpos

    @property
    def current_vel(self):
        """
        Returns ``sim.data.qvel``, i.e. the velocities of all simulated joints.
        May be overridden in subclass.
        """
        return self.sim.data.qvel

    def extend_des_pos(self, des_pos):
        """
        In a simplified environment, the actions may only control a subset of all the joints.
        Extend the trajectory to match the environment's full action space.

        The base implementation does nothing and returns None; subclasses are
        expected to override it.

        Args:
            des_pos: desired positions for the controlled joints

        Returns:
            None in this base class.
        """
        pass

    def extend_des_vel(self, des_vel):
        """
        Velocity counterpart of ``extend_des_pos``.

        The base implementation does nothing and returns None.
        """
        pass

    def _set_action_space(self):
        """
        Create ``self.action_space`` as a float32 Box spanning the model's
        actuator control ranges, and return it.
        """
        bounds = self.model.actuator_ctrlrange.copy().astype(np.float32)
        # Each row of bounds is (low, high) for one actuator.
        low, high = bounds.T
        self.action_space = spaces.Box(low=low, high=high, dtype=np.float32)
        return self.action_space

    def _set_observation_space(self, observation):
        """
        Create ``self.observation_space`` from a sample observation, and return it.
        """
        self.observation_space = convert_observation_to_space(observation)
        return self.observation_space

    def seed(self, seed=None):
        """
        (Re-)create ``self.np_random`` from ``seed`` and return the seed that
        was used, wrapped in a list.
        """
        self.np_random, seed = seeding.np_random(seed)
        return [seed]

    # methods to override:
    # ----------------------------

    def configure(self, *args, **kwargs):
        """
        Helper method to set certain environment properties such as contexts in contextual environments since reset()
        doesn't take arguments. Should be called before/after reset(). TODO: before or after?
        """
        pass

    def reset_model(self):
        """
        Reset the robot degrees of freedom (qpos and qvel).
        Implement this in each subclass.

        The return value is handed back by ``reset()`` as the observation.
        """
        raise NotImplementedError

    def viewer_setup(self):
        """
        This method is called when the viewer is initialized.
        Optionally implement this method, if you need to tinker with camera position
        and so forth.
        """
        pass

    # -----------------------------

    def reset(self):
        """
        Reset the simulator, then let ``reset_model`` set up the state.

        Returns:
            Whatever ``reset_model`` returns (the initial observation).
        """
        self.sim.reset()
        ob = self.reset_model()
        return ob

    def set_state(self, qpos, qvel):
        """
        Overwrite the simulator's qpos and qvel and recompute derived quantities.

        Time, actuator activations and udd_state are carried over from the
        current state.

        Args:
            qpos: array of shape (model.nq,)
            qvel: array of shape (model.nv,)
        """
        assert qpos.shape == (self.model.nq,) and qvel.shape == (self.model.nv,)
        old_state = self.sim.get_state()
        new_state = mujoco_py.MjSimState(old_state.time, qpos, qvel,
                                         old_state.act, old_state.udd_state)
        self.sim.set_state(new_state)
        self.sim.forward()

    @property
    def dt(self):
        """Simulated time covered by one ``sim.step()``: model timestep times n_substeps."""
        return self.model.opt.timestep * self.n_substeps

    def do_simulation(self, ctrl):
        """
        Write ``ctrl`` into the first ``len(ctrl)`` actuator controls and step the simulation once.

        With ``apply_gravity_comp`` enabled, ``qfrc_bias / actuator_gear`` of the
        same leading entries is added to ``ctrl`` first.

        Args:
            ctrl: control values for the first ``len(ctrl)`` actuators

        Returns:
            True if stepping raised a MujocoException, False otherwise.
        """
        num_actuations = len(ctrl)
        if self.apply_gravity_comp:
            # Slice the gear ratios as well, so that controlling only a subset
            # of the actuators does not fail on a shape mismatch.
            gear = self.model.actuator_gear[:num_actuations, 0]
            bias = self.sim.data.qfrc_bias[:num_actuations]
            self.sim.data.ctrl[:num_actuations] = ctrl + bias / gear
        else:
            self.sim.data.ctrl[:num_actuations] = ctrl

        try:
            self.sim.step()
        except mujoco_py.builder.MujocoException:
            # Report the failure to the caller instead of propagating it.
            return True

        return False

    def render(self,
               mode='human',
               width=DEFAULT_SIZE,
               height=DEFAULT_SIZE,
               camera_id=None,
               camera_name=None):
        """
        Render the current simulation state.

        Args:
            mode: 'human', 'rgb_array' or 'depth_array'
            width: image width for the array modes
            height: image height for the array modes
            camera_id: id of the camera to render from (array modes only)
            camera_name: name of the camera to render from (array modes only);
                defaults to 'track' when neither camera argument is given

        Returns:
            The vertically flipped pixel data for 'rgb_array' and
            'depth_array', None otherwise.

        Raises:
            ValueError: if both ``camera_id`` and ``camera_name`` are given
                in one of the array modes
        """
        if mode == 'rgb_array' or mode == 'depth_array':
            if camera_id is not None and camera_name is not None:
                raise ValueError("Both `camera_id` and `camera_name` cannot be"
                                 " specified at the same time.")

            no_camera_specified = camera_name is None and camera_id is None
            if no_camera_specified:
                camera_name = 'track'

            # Resolve the name only if the model actually has such a camera;
            # otherwise camera_id stays None.
            if camera_id is None and camera_name in self.model._camera_name2id:
                camera_id = self.model.camera_name2id(camera_name)

            self._get_viewer(mode).render(width, height, camera_id=camera_id)

        if mode == 'rgb_array':
            # window size used for old mujoco-py:
            data = self._get_viewer(mode).read_pixels(width, height, depth=False)
            # original image is upside-down, so flip it
            return data[::-1, :, :]
        elif mode == 'depth_array':
            # NOTE(review): this second render call does not forward camera_id,
            # so the camera resolved above is not passed on for the depth
            # read - confirm that this is intended.
            self._get_viewer(mode).render(width, height)
            # window size used for old mujoco-py:
            # Extract depth part of the read_pixels() tuple
            data = self._get_viewer(mode).read_pixels(width, height, depth=True)[1]
            # original image is upside-down, so flip it
            return data[::-1, :]
        elif mode == 'human':
            self._get_viewer(mode).render()

    def close(self):
        """Forget the active viewer and all cached viewers, if a viewer is active."""
        if self.viewer is not None:
            self.viewer = None
            self._viewers = {}

    def _get_viewer(self, mode):
        """
        Return the viewer for ``mode``, creating and caching it on first use.

        'human' uses an ``MjViewer``; 'rgb_array' and 'depth_array' use an
        ``MjRenderContextOffscreen``. ``viewer_setup`` is called once for every
        newly created viewer. The returned viewer also becomes ``self.viewer``.
        """
        self.viewer = self._viewers.get(mode)
        if self.viewer is None:
            if mode == 'human':
                self.viewer = mujoco_py.MjViewer(self.sim)
            elif mode == 'rgb_array' or mode == 'depth_array':
                self.viewer = mujoco_py.MjRenderContextOffscreen(self.sim, -1)

            self.viewer_setup()
            self._viewers[mode] = self.viewer
        return self.viewer

    def get_body_com(self, body_name):
        """Return ``data.get_body_xpos(body_name)`` for the named body."""
        return self.data.get_body_xpos(body_name)

    def state_vector(self):
        """Return qpos followed by qvel as one flat array."""
        return np.concatenate([
            self.sim.data.qpos.flat,
            self.sim.data.qvel.flat
        ])
|