Updated Ball in a cup example for new wrappers

commit c0e036b2e5 (parent dffa3e3682)
@@ -7,10 +7,10 @@ from alr_envs.mujoco import alr_mujoco_env
 class ALRBallInACupEnv(alr_mujoco_env.AlrMujocoEnv, utils.EzPickle):
     def __init__(self, n_substeps=4, apply_gravity_comp=True, simplified: bool = False,
                  reward_type: str = None, context: np.ndarray = None):
+        utils.EzPickle.__init__(**locals())
         self._steps = 0

-        self.xml_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets",
-                                     "biac_base" + ".xml")
+        self.xml_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "biac_base.xml")

         self._q_pos = []
         self._q_vel = []
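The hunk above replaces a bare utils.EzPickle.__init__(self) call (removed in the next hunk) with utils.EzPickle.__init__(**locals()), so the constructor arguments themselves become part of the pickle state. A minimal sketch of the idiom, using a hypothetical ToyEnv in place of the real environment:

import pickle

import numpy as np
from gym import utils


class ToyEnv(utils.EzPickle):
    """Hypothetical stand-in for ALRBallInACupEnv's constructor signature."""

    def __init__(self, n_substeps=4, apply_gravity_comp=True, simplified: bool = False,
                 reward_type: str = None, context: np.ndarray = None):
        # locals() at this point holds exactly the constructor arguments, so
        # EzPickle records them and unpickling re-runs __init__ with the same values.
        utils.EzPickle.__init__(**locals())
        self.n_substeps = n_substeps


env = ToyEnv(n_substeps=8)
clone = pickle.loads(pickle.dumps(env))
assert clone.n_substeps == 8

With the old argument-less call, unpickling would have rebuilt the environment with default arguments only.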
@@ -22,7 +22,6 @@ class ALRBallInACupEnv(alr_mujoco_env.AlrMujocoEnv, utils.EzPickle):

         self.context = context

-        utils.EzPickle.__init__(self)
         alr_mujoco_env.AlrMujocoEnv.__init__(self,
                                              self.xml_path,
                                              apply_gravity_comp=apply_gravity_comp,
@@ -194,4 +193,3 @@ if __name__ == "__main__":
             break

     env.close()
-
alr_envs/mujoco/ball_in_a_cup/ball_in_a_cup_mp_wrapper.py (new file, 34 lines)
@@ -0,0 +1,34 @@
+from typing import Union
+
+import numpy as np
+
+from mp_env_api.envs.mp_env_wrapper import MPEnvWrapper
+
+
+class ALRBallInACupMPWrapper(MPEnvWrapper):
+
+    @property
+    def active_obs(self):
+        # TODO: @Max Filter observations correctly
+        return np.hstack([
+            [False] * 7,  # cos
+            [False] * 7,  # sin
+            # [True] * 2,  # x-y coordinates of target distance
+            [False]  # env steps
+        ])
+
+    @property
+    def start_pos(self):
+        if self.simplified:
+            return self._start_pos[1::2]
+        else:
+            return self._start_pos
+
+    @property
+    def goal_pos(self):
+        # TODO: @Max I think the default value of returning to the start is reasonable here
+        raise ValueError("Goal position is not available and has to be learnt based on the environment.")
+
+    @property
+    def dt(self) -> Union[float, int]:
+        return self.env.dt
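The active_obs property of this new wrapper is a boolean mask over the flat observation vector; everything is currently masked out (see the TODO). A small self-contained sketch of how such a mask can be applied, with an assumed observation layout of 7 cos values, 7 sin values, and the step counter:

import numpy as np

# Assumed observation layout matching the mask above: 7 cos, 7 sin, step counter.
obs = np.hstack([np.cos(np.zeros(7)), np.sin(np.zeros(7)), [0.0]])

active_obs = np.hstack([
    [False] * 7,  # cos
    [False] * 7,  # sin
    [False]       # env steps
])

# Boolean indexing keeps only the entries marked active; here that is nothing,
# so no observation features are exposed as context to the movement primitive.
context_features = obs[active_obs]
print(context_features.shape)  # (0,)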
@@ -0,0 +1,15 @@
+from typing import Tuple, Union
+
+import numpy as np
+
+from mp_env_api.envs.positional_env_wrapper import PositionalEnvWrapper
+
+
+class BallInACupPositionalEnvWrapper(PositionalEnvWrapper):
+    @property
+    def current_pos(self) -> Union[float, int, np.ndarray, Tuple]:
+        return self.sim.data.qpos[0:7].copy()
+
+    @property
+    def current_vel(self) -> Union[float, int, np.ndarray, Tuple]:
+        return self.sim.data.qvel[0:7].copy()
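This second new file exposes the robot's joint state for positional control. A short sketch of what those two properties return, using fake qpos/qvel buffers in place of a mujoco_py simulation (sizes and values are made up):

import numpy as np

qpos = np.arange(9, dtype=float)   # first 7 entries: arm joints; rest: ball/cup (assumed layout)
qvel = np.zeros(9)

# Same slicing and defensive copy as in the wrapper: controllers get the 7 arm
# joints and cannot mutate the simulator state through the returned array.
current_pos = qpos[0:7].copy()
current_vel = qvel[0:7].copy()
current_pos[0] = 99.0
print(qpos[0])  # still 0.0 -- the copy protects the underlying buffer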
@@ -1,11 +1,11 @@
-import gym.envs.mujoco
-import mujoco_py.builder
-from gym import utils, spaces
 import os
-import numpy as np
+
+import gym.envs.mujoco
 import gym.envs.mujoco as mujoco_env
-from mp_env_api.envs.mp_env import MpEnv
-from mp_env_api.envs.positional_env import PositionalEnv
+import mujoco_py.builder
+import numpy as np
+from gym import utils

 from mp_env_api.mp_wrappers.detpmp_wrapper import DetPMPWrapper
 from mp_env_api.utils.policies import PDControllerExtend
@@ -18,13 +18,13 @@ def make_detpmp_env(**kwargs):
     return DetPMPWrapper(_env, **kwargs)


-class ALRBallInACupPDEnv(mujoco_env.MujocoEnv, PositionalEnv, MpEnv, utils.EzPickle):
+class ALRBallInACupPDEnv(mujoco_env.MujocoEnv, utils.EzPickle):
     def __init__(self, frame_skip=4, apply_gravity_comp=True, simplified: bool = False,
                  reward_type: str = None, context: np.ndarray = None):
+        utils.EzPickle.__init__(**locals())
         self._steps = 0

-        self.xml_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets",
-                                     "biac_base" + ".xml")
+        self.xml_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "biac_base.xml")

         self.max_ctrl = np.array([150., 125., 40., 60., 5., 5., 2.])

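ALRBallInACupPDEnv no longer inherits from PositionalEnv and MpEnv; those roles are taken over by the new wrapper classes above. A hedged composition sketch, assuming the mp_env_api wrappers take the wrapped environment as their first constructor argument (as DetPMPWrapper does in make_detpmp_env) and delegate attribute access to it; neither assumption is shown in this diff:

from alr_envs.mujoco.ball_in_a_cup.ball_in_a_cup_mp_wrapper import ALRBallInACupMPWrapper


def make_mp_env(make_base_env):
    # make_base_env: any callable returning an ALRBallInACupPDEnv instance.
    base_env = make_base_env()
    # Assumed gym.Wrapper-style construction; attribute delegation is what would
    # let start_pos/active_obs in the wrapper read base_env.simplified and
    # base_env._start_pos.
    return ALRBallInACupMPWrapper(base_env)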
@@ -39,6 +39,7 @@ class ALRBallInACupPDEnv(mujoco_env.MujocoEnv, PositionalEnv, MpEnv, utils.EzPickle):
         self._start_vel = np.zeros(7)

         self.sim_time = 8  # seconds
+        self._dt = 0.02
         self.ep_length = 4000  # based on 8 seconds with dt = 0.02 int(self.sim_time / self.dt)
         if reward_type == "no_context":
             from alr_envs.mujoco.ball_in_a_cup.ball_in_a_cup_reward_simple import BallInACupReward
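A quick arithmetic check on the constants in this hunk (the values are taken from the diff; the intent behind the hard-coded 4000 is not stated in the commit):

sim_time = 8    # seconds, as set above
dt = 0.02       # the new _dt attribute
print(int(sim_time / dt))  # 400; the hard-coded ep_length of 4000 would instead
                           # correspond to 8 seconds at a 0.002 s step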
@@ -51,15 +52,12 @@ class ALRBallInACupPDEnv(mujoco_env.MujocoEnv, PositionalEnv, MpEnv, utils.EzPickle):
         self.reward_function = reward_function(self)

         mujoco_env.MujocoEnv.__init__(self, self.xml_path, frame_skip)
-        utils.EzPickle.__init__(self)

     @property
-    def start_pos(self):
-        if self.simplified:
-            return self._start_pos[1::2]
-        else:
-            return self._start_pos
+    def dt(self):
+        return self._dt

+    # TODO: @Max is this even needed?
     @property
     def start_vel(self):
         if self.simplified:
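Exposing dt as a property gives the wrappers a control timestep to work with. A minimal sketch of the kind of computation this enables (the 3.5 s rollout duration is an arbitrary assumed value):

duration = 3.5            # seconds of movement-primitive rollout (assumed)
dt = 0.02                 # what ALRBallInACupPDEnv.dt now returns
n_steps = int(duration / dt)
print(n_steps)            # 175 environment steps for the generated trajectory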
@@ -67,14 +65,6 @@ class ALRBallInACupPDEnv(mujoco_env.MujocoEnv, PositionalEnv, MpEnv, utils.EzPickle):
         else:
             return self._start_vel

-    @property
-    def current_pos(self):
-        return self.sim.data.qpos[0:7].copy()
-
-    @property
-    def current_vel(self):
-        return self.sim.data.qvel[0:7].copy()
-
     # def _set_action_space(self):
     #     if self.simplified:
     #         bounds = self.model.actuator_ctrlrange.copy().astype(np.float32)[1::2]
@@ -115,10 +105,11 @@ class ALRBallInACupPDEnv(mujoco_env.MujocoEnv, PositionalEnv, MpEnv, utils.EzPickle):
        # a = tmp

        if self.apply_gravity_comp:
-            a = a + self.sim.data.qfrc_bias[:len(a)].copy() / self.model.actuator_gear[:, 0]
+            a += self.sim.data.qfrc_bias[:len(a)].copy() / self.model.actuator_gear[:, 0]

+        crash = False
        try:
            self.do_simulation(a, self.frame_skip)
-            crash = False
        except mujoco_py.builder.MujocoException:
            crash = True
        # joint_cons_viol = self.check_traj_in_joint_limits()
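The gravity-compensation line keeps its arithmetic but now writes the result in place. A self-contained sketch of what that line computes, with made-up values standing in for the MuJoCo quantities (qfrc_bias holds the joint-space bias forces, i.e. gravity, Coriolis and centrifugal terms, and actuator_gear[:, 0] is the gear ratio mapping actuator commands to joint torques):

import numpy as np

a = np.zeros(7)                                  # desired actuator command
qfrc_bias = np.array([0.0, -35.1, 0.0, 12.4, 0.0, 0.8, 0.0])  # made-up bias forces
actuator_gear = np.ones((7, 6)) * 50.0           # MuJoCo stores gear as (nu, 6); only column 0 is used

# Same arithmetic as the hunk: convert the bias force into actuator units and
# add it to the command before it is passed to do_simulation.
a += qfrc_bias[:len(a)].copy() / actuator_gear[:, 0]
print(a)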
@@ -158,16 +149,6 @@ class ALRBallInACupPDEnv(mujoco_env.MujocoEnv, PositionalEnv, MpEnv, utils.EzPickle):
             [self._steps],
         ])

-    # TODO
-    @property
-    def active_obs(self):
-        return np.hstack([
-            [False] * 7,  # cos
-            [False] * 7,  # sin
-            # [True] * 2,  # x-y coordinates of target distance
-            [False]  # env steps
-        ])
-
     # These functions are for the task with 3 joint actuations
     def extend_des_pos(self, des_pos):
         des_pos_full = self._start_pos.copy()
@@ -222,4 +203,3 @@ if __name__ == "__main__":
             break

     env.close()
-