Updated Ball in a cup example for new wrappers

parent dffa3e3682
commit c0e036b2e5
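The diff below moves the movement-primitive (MP) and positional interfaces out of the environment classes and into standalone wrapper classes from mp_env_api. A minimal sketch of how the pieces are meant to compose after this change; only the class names and the wrapper module path are taken from the diff, while the import path of ALRBallInACupEnv, the constructor argument and the wrapping call itself are assumptions:

import numpy as np

from alr_envs.mujoco.ball_in_a_cup.ball_in_a_cup import ALRBallInACupEnv  # module path assumed
from alr_envs.mujoco.ball_in_a_cup.ball_in_a_cup_mp_wrapper import ALRBallInACupMPWrapper

# Assumed composition: the MuJoCo env provides the dynamics, the MP wrapper exposes
# start_pos, dt and the active_obs mask to the movement-primitive machinery.
env = ALRBallInACupMPWrapper(ALRBallInACupEnv(reward_type="no_context"))

print(env.dt)          # simulation timestep forwarded from the wrapped env
print(env.start_pos)   # start joint configuration used by the MP
print(env.active_obs)  # boolean mask selecting the observations passed to the MP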
@@ -7,10 +7,10 @@ from alr_envs.mujoco import alr_mujoco_env
class ALRBallInACupEnv(alr_mujoco_env.AlrMujocoEnv, utils.EzPickle):
    def __init__(self, n_substeps=4, apply_gravity_comp=True, simplified: bool = False,
                 reward_type: str = None, context: np.ndarray = None):
        utils.EzPickle.__init__(**locals())
        self._steps = 0

        self.xml_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets",
                                     "biac_base" + ".xml")
        self.xml_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "biac_base.xml")

        self._q_pos = []
        self._q_vel = []
@@ -22,7 +22,6 @@ class ALRBallInACupEnv(alr_mujoco_env.AlrMujocoEnv, utils.EzPickle):

        self.context = context

        utils.EzPickle.__init__(self)
        alr_mujoco_env.AlrMujocoEnv.__init__(self,
                                             self.xml_path,
                                             apply_gravity_comp=apply_gravity_comp,
@@ -194,4 +193,3 @@ if __name__ == "__main__":
            break

    env.close()
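One detail in the hunks above worth a note: the new utils.EzPickle.__init__(**locals()) call at the top of __init__ takes over from the bare utils.EzPickle.__init__(self) call further down, so that the constructor arguments are captured for pickling. A minimal sketch of the pattern with an illustrative toy class (the class and argument names here are not from the repository):

import numpy as np
from gym import utils


class ToyEnv(utils.EzPickle):
    def __init__(self, n_substeps=4, context: np.ndarray = None):
        # At this point locals() is {"self": self, "n_substeps": 4, "context": None},
        # so EzPickle records the constructor arguments and can rebuild the object
        # with the same arguments when it is unpickled.
        utils.EzPickle.__init__(**locals())
        self.n_substeps = n_substeps
        self.context = context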
alr_envs/mujoco/ball_in_a_cup/ball_in_a_cup_mp_wrapper.py (new file, 34 lines)
@@ -0,0 +1,34 @@
from typing import Union

import numpy as np

from mp_env_api.envs.mp_env_wrapper import MPEnvWrapper


class ALRBallInACupMPWrapper(MPEnvWrapper):

    @property
    def active_obs(self):
        # TODO: @Max Filter observations correctly
        return np.hstack([
            [False] * 7,  # cos
            [False] * 7,  # sin
            # [True] * 2,  # x-y coordinates of target distance
            [False]  # env steps
        ])

    @property
    def start_pos(self):
        if self.simplified:
            return self._start_pos[1::2]
        else:
            return self._start_pos

    @property
    def goal_pos(self):
        # TODO: @Max I think the default value of returning to the start is reasonable here
        raise ValueError("Goal position is not available and has to be learnt based on the environment.")

    @property
    def dt(self) -> Union[float, int]:
        return self.env.dt
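A short sketch of how the active_obs mask above is meant to be used: it is a boolean mask over the flat observation vector, selecting which entries are exposed to the movement-primitive side (going by the TODO, the exact filtering is still open). The observation layout follows the comments in the property; the example values and the re-enabled target-distance entries are purely illustrative:

import numpy as np

# Hypothetical flat observation: 7 cos terms, 7 sin terms, 2 target-distance terms, step count.
obs = np.arange(17.0)

# If the commented-out target-distance entries were enabled, the mask would keep only those two:
mask = np.hstack([[False] * 7, [False] * 7, [True] * 2, [False]])
print(obs[mask])  # -> [14. 15.], the x-y target-distance entries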
@@ -0,0 +1,15 @@
from typing import Tuple, Union

import numpy as np

from mp_env_api.envs.positional_env_wrapper import PositionalEnvWrapper


class BallInACupPositionalEnvWrapper(PositionalEnvWrapper):
    @property
    def current_pos(self) -> Union[float, int, np.ndarray, Tuple]:
        return self.sim.data.qpos[0:7].copy()

    @property
    def current_vel(self) -> Union[float, int, np.ndarray, Tuple]:
        return self.sim.data.qvel[0:7].copy()
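The positional wrapper above supplies the feedback signals that a PD trajectory-tracking controller (e.g. the PDControllerExtend imported in the next file) needs. A minimal sketch of such a tracking law; the gains and setpoints are made up, only current_pos / current_vel correspond to the wrapper's properties:

import numpy as np

# Illustrative PD gains for the 7 actuated joints (not the repository's values).
p_gains = np.array([200.0, 300.0, 100.0, 100.0, 10.0, 10.0, 2.5])
d_gains = np.array([7.0, 15.0, 5.0, 2.5, 0.3, 0.3, 0.05])


def pd_torques(des_pos, des_vel, current_pos, current_vel):
    # Classic PD tracking: torque = Kp * position error + Kd * velocity error.
    return p_gains * (des_pos - current_pos) + d_gains * (des_vel - current_vel)


# current_pos / current_vel would come from the wrapper, i.e. the first seven
# qpos / qvel entries of the MuJoCo simulation.
torques = pd_torques(np.zeros(7), np.zeros(7), np.zeros(7), np.zeros(7))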
@@ -1,11 +1,11 @@
import gym.envs.mujoco
import mujoco_py.builder
from gym import utils, spaces
import os
import numpy as np

import gym.envs.mujoco
import gym.envs.mujoco as mujoco_env
from mp_env_api.envs.mp_env import MpEnv
from mp_env_api.envs.positional_env import PositionalEnv
import mujoco_py.builder
import numpy as np
from gym import utils

from mp_env_api.mp_wrappers.detpmp_wrapper import DetPMPWrapper
from mp_env_api.utils.policies import PDControllerExtend

@@ -18,13 +18,13 @@ def make_detpmp_env(**kwargs):
    return DetPMPWrapper(_env, **kwargs)


class ALRBallInACupPDEnv(mujoco_env.MujocoEnv, PositionalEnv, MpEnv, utils.EzPickle):
class ALRBallInACupPDEnv(mujoco_env.MujocoEnv, utils.EzPickle):
    def __init__(self, frame_skip=4, apply_gravity_comp=True, simplified: bool = False,
                 reward_type: str = None, context: np.ndarray = None):
        utils.EzPickle.__init__(**locals())
        self._steps = 0

        self.xml_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets",
                                     "biac_base" + ".xml")
        self.xml_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "biac_base.xml")

        self.max_ctrl = np.array([150., 125., 40., 60., 5., 5., 2.])

@@ -39,6 +39,7 @@ class ALRBallInACupPDEnv(mujoco_env.MujocoEnv, PositionalEnv, MpEnv, utils.EzPickle):
        self._start_vel = np.zeros(7)

        self.sim_time = 8  # seconds
        self._dt = 0.02
        self.ep_length = 4000  # based on 8 seconds with dt = 0.02 int(self.sim_time / self.dt)
        if reward_type == "no_context":
            from alr_envs.mujoco.ball_in_a_cup.ball_in_a_cup_reward_simple import BallInACupReward
@@ -51,15 +52,12 @@ class ALRBallInACupPDEnv(mujoco_env.MujocoEnv, PositionalEnv, MpEnv, utils.EzPickle):
        self.reward_function = reward_function(self)

        mujoco_env.MujocoEnv.__init__(self, self.xml_path, frame_skip)
        utils.EzPickle.__init__(self)

    @property
    def start_pos(self):
        if self.simplified:
            return self._start_pos[1::2]
        else:
            return self._start_pos
    def dt(self):
        return self._dt

    # TODO: @Max is this even needed?
    @property
    def start_vel(self):
        if self.simplified:
@@ -67,14 +65,6 @@ class ALRBallInACupPDEnv(mujoco_env.MujocoEnv, PositionalEnv, MpEnv, utils.EzPickle):
        else:
            return self._start_vel

    @property
    def current_pos(self):
        return self.sim.data.qpos[0:7].copy()

    @property
    def current_vel(self):
        return self.sim.data.qvel[0:7].copy()

    # def _set_action_space(self):
    #     if self.simplified:
    #         bounds = self.model.actuator_ctrlrange.copy().astype(np.float32)[1::2]
@@ -115,10 +105,11 @@ class ALRBallInACupPDEnv(mujoco_env.MujocoEnv, PositionalEnv, MpEnv, utils.EzPickle):
        # a = tmp

        if self.apply_gravity_comp:
            a = a + self.sim.data.qfrc_bias[:len(a)].copy() / self.model.actuator_gear[:, 0]
            a += self.sim.data.qfrc_bias[:len(a)].copy() / self.model.actuator_gear[:, 0]

        crash = False
        try:
            self.do_simulation(a, self.frame_skip)
            crash = False
        except mujoco_py.builder.MujocoException:
            crash = True
        # joint_cons_viol = self.check_traj_in_joint_limits()
@@ -158,16 +149,6 @@ class ALRBallInACupPDEnv(mujoco_env.MujocoEnv, PositionalEnv, MpEnv, utils.EzPickle):
            [self._steps],
        ])

    # TODO
    @property
    def active_obs(self):
        return np.hstack([
            [False] * 7,  # cos
            [False] * 7,  # sin
            # [True] * 2,  # x-y coordinates of target distance
            [False]  # env steps
        ])

    # These functions are for the task with 3 joint actuations
    def extend_des_pos(self, des_pos):
        des_pos_full = self._start_pos.copy()
@@ -222,4 +203,3 @@ if __name__ == "__main__":
            break

    env.close()
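For reference, the gravity-compensation line kept in step() above works by converting MuJoCo's bias forces into actuator control units. A sketch of the idea as a standalone function; the toy inputs only illustrate the shapes, real values come from sim.data.qfrc_bias and model.actuator_gear inside the env:

import numpy as np


def gravity_compensation(action, qfrc_bias, actuator_gear_col0):
    # qfrc_bias: Coriolis, centrifugal and gravity forces per joint (sim.data.qfrc_bias).
    # actuator_gear_col0: first gear column (model.actuator_gear[:, 0]), which scales ctrl
    # values to joint forces; dividing by it converts the bias forces back to ctrl units.
    return action + qfrc_bias[:len(action)] / actuator_gear_col0


# Toy numbers just to show the shapes involved.
a = gravity_compensation(np.zeros(7), np.arange(9.0), np.ones(7))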