Updated Ball in a cup example for new wrappers

ottofabian 2021-06-24 16:37:41 +02:00
parent dffa3e3682
commit c0e036b2e5
4 changed files with 67 additions and 40 deletions

View File

@@ -7,10 +7,10 @@ from alr_envs.mujoco import alr_mujoco_env
 class ALRBallInACupEnv(alr_mujoco_env.AlrMujocoEnv, utils.EzPickle):
     def __init__(self, n_substeps=4, apply_gravity_comp=True, simplified: bool = False,
                  reward_type: str = None, context: np.ndarray = None):
+        utils.EzPickle.__init__(**locals())
         self._steps = 0
-        self.xml_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets",
-                                     "biac_base" + ".xml")
+        self.xml_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "biac_base.xml")

         self._q_pos = []
         self._q_vel = []
@@ -22,7 +22,6 @@ class ALRBallInACupEnv(alr_mujoco_env.AlrMujocoEnv, utils.EzPickle):
         self.context = context

-        utils.EzPickle.__init__(self)
         alr_mujoco_env.AlrMujocoEnv.__init__(self,
                                              self.xml_path,
                                              apply_gravity_comp=apply_gravity_comp,
@@ -194,4 +193,3 @@ if __name__ == "__main__":
             break

     env.close()
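Note: the move from utils.EzPickle.__init__(self) to utils.EzPickle.__init__(**locals()) matters for serialization. gym's EzPickle re-runs the constructor with the stored arguments on unpickling, so passing no arguments would rebuild the env with defaults and silently drop the configured ones. A minimal sketch of the mechanism (DummyEnv is a made-up stand-in, not part of this repo):

import pickle

from gym import utils


class DummyEnv(utils.EzPickle):
    def __init__(self, n_substeps=4, apply_gravity_comp=True):
        # locals() here contains self plus every constructor argument;
        # EzPickle stores them and replays __init__ when unpickling.
        utils.EzPickle.__init__(**locals())
        self.n_substeps = n_substeps
        self.apply_gravity_comp = apply_gravity_comp


env = DummyEnv(n_substeps=8)
clone = pickle.loads(pickle.dumps(env))
assert clone.n_substeps == 8  # constructor args survive the round trip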

View File

@@ -0,0 +1,34 @@
+from typing import Union
+
+import numpy as np
+
+from mp_env_api.envs.mp_env_wrapper import MPEnvWrapper
+
+
+class ALRBallInACupMPWrapper(MPEnvWrapper):
+
+    @property
+    def active_obs(self):
+        # TODO: @Max Filter observations correctly
+        return np.hstack([
+            [False] * 7,  # cos
+            [False] * 7,  # sin
+            # [True] * 2,  # x-y coordinates of target distance
+            [False]  # env steps
+        ])
+
+    @property
+    def start_pos(self):
+        if self.simplified:
+            return self._start_pos[1::2]
+        else:
+            return self._start_pos
+
+    @property
+    def goal_pos(self):
+        # TODO: @Max I think the default value of returning to the start is reasonable here
+        raise ValueError("Goal position is not available and has to be learnt based on the environment.")
+
+    @property
+    def dt(self) -> Union[float, int]:
+        return self.env.dt
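Note: how active_obs is consumed is defined inside mp_env_api, but the intent is readable from the mask: one boolean per observation entry, where True entries are kept as context for the movement primitive. A hypothetical illustration of that filtering (the mask layout mirrors the wrapper above; the obs vector is made up):

import numpy as np

active_obs = np.hstack([
    [False] * 7,  # cos of the joint angles
    [False] * 7,  # sin of the joint angles
    [False]       # env steps
])
obs = np.arange(15.0)       # stand-in for a raw observation vector
context = obs[active_obs]   # boolean indexing keeps only the True entries
print(context.shape)        # (0,) -- nothing is exposed as context yet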

View File

@@ -0,0 +1,15 @@
+from typing import Tuple, Union
+
+import numpy as np
+
+from mp_env_api.envs.positional_env_wrapper import PositionalEnvWrapper
+
+
+class BallInACupPositionalEnvWrapper(PositionalEnvWrapper):
+
+    @property
+    def current_pos(self) -> Union[float, int, np.ndarray, Tuple]:
+        return self.sim.data.qpos[0:7].copy()
+
+    @property
+    def current_vel(self) -> Union[float, int, np.ndarray, Tuple]:
+        return self.sim.data.qvel[0:7].copy()
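Note: current_pos and current_vel are what a tracking controller such as PDControllerExtend (used by make_detpmp_env below) reads each step. Its exact interface lives in mp_env_api; a generic PD law over these two properties would look like this sketch (gains are illustrative, not the repo's values):

import numpy as np

p_gains = np.array([200., 300., 100., 100., 10., 10., 2.5])  # illustrative
d_gains = np.array([7., 15., 5., 2.5, 0.3, 0.3, 0.05])       # illustrative

def pd_torques(des_pos, des_vel, cur_pos, cur_vel):
    # Classic PD tracking: torque proportional to position and velocity error.
    return p_gains * (des_pos - cur_pos) + d_gains * (des_vel - cur_vel)

tau = pd_torques(np.zeros(7), np.zeros(7), np.ones(7) * 0.1, np.zeros(7))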

View File

@@ -1,11 +1,11 @@
-import gym.envs.mujoco
-import mujoco_py.builder
-from gym import utils, spaces
 import os
-import numpy as np
-import gym.envs.mujoco
 import gym.envs.mujoco as mujoco_env
-from mp_env_api.envs.mp_env import MpEnv
-from mp_env_api.envs.positional_env import PositionalEnv
+import mujoco_py.builder
+import numpy as np
+from gym import utils
 from mp_env_api.mp_wrappers.detpmp_wrapper import DetPMPWrapper
 from mp_env_api.utils.policies import PDControllerExtend
@@ -18,13 +18,13 @@ def make_detpmp_env(**kwargs):
     return DetPMPWrapper(_env, **kwargs)

-class ALRBallInACupPDEnv(mujoco_env.MujocoEnv, PositionalEnv, MpEnv, utils.EzPickle):
+class ALRBallInACupPDEnv(mujoco_env.MujocoEnv, utils.EzPickle):
     def __init__(self, frame_skip=4, apply_gravity_comp=True, simplified: bool = False,
                  reward_type: str = None, context: np.ndarray = None):
+        utils.EzPickle.__init__(**locals())
         self._steps = 0
-        self.xml_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets",
-                                     "biac_base" + ".xml")
+        self.xml_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "biac_base.xml")

         self.max_ctrl = np.array([150., 125., 40., 60., 5., 5., 2.])
@@ -39,6 +39,7 @@ class ALRBallInACupPDEnv(mujoco_env.MujocoEnv, PositionalEnv, MpEnv, utils.EzPickle):
         self._start_vel = np.zeros(7)

         self.sim_time = 8  # seconds
+        self._dt = 0.02
         self.ep_length = 4000  # based on 8 seconds with dt = 0.02 int(self.sim_time / self.dt)
         if reward_type == "no_context":
             from alr_envs.mujoco.ball_in_a_cup.ball_in_a_cup_reward_simple import BallInACupReward
@@ -51,15 +52,12 @@ class ALRBallInACupPDEnv(mujoco_env.MujocoEnv, PositionalEnv, MpEnv, utils.EzPickle):
         self.reward_function = reward_function(self)

         mujoco_env.MujocoEnv.__init__(self, self.xml_path, frame_skip)
-        utils.EzPickle.__init__(self)

     @property
-    def start_pos(self):
-        if self.simplified:
-            return self._start_pos[1::2]
-        else:
-            return self._start_pos
+    def dt(self):
+        return self._dt

-    # TODO: @Max is this even needed?
     @property
     def start_vel(self):
         if self.simplified:
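Note: the new dt property shadows the one gym's MujocoEnv derives from the model, roughly model.opt.timestep * frame_skip, so the hard-coded 0.02 s decouples the wrapper interface from the loaded XML. For reference, the timestep below is an assumption consistent with frame_skip=4, not a value taken from biac_base.xml:

timestep = 0.005   # assumed MuJoCo model timestep
frame_skip = 4
assert timestep * frame_skip == 0.02  # matches the hard-coded self._dt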
@@ -67,14 +65,6 @@ class ALRBallInACupPDEnv(mujoco_env.MujocoEnv, PositionalEnv, MpEnv, utils.EzPickle):
         else:
             return self._start_vel

-    @property
-    def current_pos(self):
-        return self.sim.data.qpos[0:7].copy()
-
-    @property
-    def current_vel(self):
-        return self.sim.data.qvel[0:7].copy()
-
     # def _set_action_space(self):
     #     if self.simplified:
     #         bounds = self.model.actuator_ctrlrange.copy().astype(np.float32)[1::2]
@@ -115,10 +105,11 @@ class ALRBallInACupPDEnv(mujoco_env.MujocoEnv, PositionalEnv, MpEnv, utils.EzPickle):
         #     a = tmp

         if self.apply_gravity_comp:
-            a = a + self.sim.data.qfrc_bias[:len(a)].copy() / self.model.actuator_gear[:, 0]
+            a += self.sim.data.qfrc_bias[:len(a)].copy() / self.model.actuator_gear[:, 0]

+        crash = False
         try:
             self.do_simulation(a, self.frame_skip)
-            crash = False
         except mujoco_py.builder.MujocoException:
             crash = True

         # joint_cons_viol = self.check_traj_in_joint_limits()
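Note: the gravity-compensation line reads MuJoCo's qfrc_bias, which holds the joint-space Coriolis, centrifugal and gravity forces, and rescales it by the first gear column of each actuator to convert those forces into control units before adding them to the policy's action. A standalone sketch of that arithmetic (all values are stand-ins):

import numpy as np

qfrc_bias = np.array([12.0, -3.4, 0.5])      # stand-in bias forces per joint
actuator_gear = np.array([150., 125., 40.])  # stand-in gear ratios (column 0)
a = np.zeros(3)                              # raw action from the policy
a += qfrc_bias / actuator_gear               # compensated action for do_simulation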
@@ -158,16 +149,6 @@ class ALRBallInACupPDEnv(mujoco_env.MujocoEnv, PositionalEnv, MpEnv, utils.EzPickle):
             [self._steps],
         ])

-    # TODO
-    @property
-    def active_obs(self):
-        return np.hstack([
-            [False] * 7,  # cos
-            [False] * 7,  # sin
-            # [True] * 2,  # x-y coordinates of target distance
-            [False]  # env steps
-        ])
-
     # These functions are for the task with 3 joint actuations
     def extend_des_pos(self, des_pos):
         des_pos_full = self._start_pos.copy()
@@ -222,4 +203,3 @@ if __name__ == "__main__":
             break

     env.close()